modelscope/modelscope/msdatasets/utils/hf_datasets_util.py
Xingjun.Wang 7d11b77112 Fix trust_remote_code (#1462)
1. Set `trust_remote_code` to `True` by default in datasets module
2. Set `trust_remote_code` to `True` by default in PolyLM pipeline
2025-08-14 10:56:16 +08:00

# noqa: isort:skip_file, yapf: disable
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
import importlib
import contextlib
import inspect
import os
import warnings
from dataclasses import dataclass, field, fields
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal, Any, ClassVar
from urllib.parse import urlencode
import requests
from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
DownloadConfig, DownloadManager, DownloadMode, Features,
IterableDataset, IterableDatasetDict, Split,
VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf)
from datasets.features import features
from datasets.features.features import _FEATURE_TYPES
from datasets.data_files import (
FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError,
_get_data_files_patterns, _is_inside_unrequested_special_dir,
_is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, sanitize_patterns)
from datasets.download.streaming_download_manager import (
_prepare_path_and_storage_options, xbasename, xjoin)
from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError
from datasets.info import DatasetInfosDict
from datasets.load import (
ALL_ALLOWED_EXTENSIONS, BuilderConfigsParameters,
CachedDatasetModuleFactory, DatasetModule,
HubDatasetModuleFactoryWithoutScript,
HubDatasetModuleFactoryWithParquetExport,
HubDatasetModuleFactoryWithScript, LocalDatasetModuleFactoryWithoutScript,
LocalDatasetModuleFactoryWithScript, PackagedDatasetModuleFactory,
create_builder_configs_from_metadata_configs, get_dataset_builder_class,
import_main_class, infer_module_for_data_files, files_to_hash,
_get_importable_file_path, resolve_trust_remote_code, _create_importable_file, _load_importable_file,
init_dynamic_modules)
from datasets.naming import camelcase_to_snakecase
from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
_MODULE_TO_EXTENSIONS,
_PACKAGED_DATASETS_MODULES)
from datasets.utils import file_utils
from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled,
cached_path, is_local_path,
is_relative_path,
relative_to_absolute_path)
from datasets.utils.info_utils import is_small_dataset
from datasets.utils.metadata import MetadataConfigs
from datasets.utils.py_utils import get_imports
from datasets.utils.track import tracked_str
from fsspec import filesystem
from fsspec.core import _un_chain
from fsspec.utils import stringify_path
from huggingface_hub import (DatasetCard, DatasetCardData)
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
from packaging import version
from modelscope import HubApi
from modelscope.hub.utils.utils import get_endpoint
from modelscope.msdatasets.utils.hf_file_utils import get_from_cache_ms
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.constant import DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET
from modelscope.utils.import_utils import has_attr_in_class
from modelscope.utils.logger import get_logger
logger = get_logger()
ExpandDatasetProperty_T = Literal[
'author',
'cardData',
'citation',
'createdAt',
'disabled',
'description',
'downloads',
'downloadsAllTime',
'gated',
'lastModified',
'likes',
'paperswithcode_id',
'private',
'siblings',
'sha',
'tags',
]
# Patch datasets features
@dataclass(repr=False)
class ListMs(SequenceHf):
"""Feature type for large list data composed of child feature data type.
It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
Args:
feature ([`FeatureType`]):
Child feature data type of each item within the large list.
        length (`int`, *optional*, defaults to -1):
Length of the list if it is fixed.
Defaults to -1 which means an arbitrary length.
"""
feature: Any
length: int = -1
id: Optional[str] = field(default=None, repr=False)
# Automatically constructed
pa_type: ClassVar[Any] = None
_type: str = field(default='List', init=False, repr=False)
def __repr__(self):
if self.length != -1:
return f'{type(self).__name__}({self.feature}, length={self.length})'
else:
return f'{type(self).__name__}({self.feature})'
_FEATURE_TYPES['List'] = ListMs
def generate_from_dict_ms(obj: Any):
"""Regenerate the nested feature object from a deserialized dict.
We use the '_type' fields to get the dataclass name to load.
generate_from_dict is the recursive helper for Features.from_dict, and allows for a convenient constructor syntax
to define features from deserialized JSON dictionaries. This function is used in particular when deserializing
a :class:`DatasetInfo` that was dumped to a JSON object. This acts as an analogue to
:meth:`Features.from_arrow_schema` and handles the recursive field-by-field instantiation, but doesn't require any
mapping to/from pyarrow, except for the fact that it takes advantage of the mapping of pyarrow primitive dtypes
that :class:`Value` automatically performs.
"""
# Nested structures: we allow dict, list/tuples, sequences
if isinstance(obj, list):
return [generate_from_dict_ms(value) for value in obj]
# Otherwise we have a dict or a dataclass
if '_type' not in obj or isinstance(obj['_type'], dict):
return {key: generate_from_dict_ms(value) for key, value in obj.items()}
obj = dict(obj)
_type = obj.pop('_type')
class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
if class_type is None:
raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")
if class_type == LargeList:
feature = obj.pop('feature')
return LargeList(generate_from_dict_ms(feature), **obj)
if class_type == ListMs:
feature = obj.pop('feature')
return ListMs(generate_from_dict_ms(feature), **obj)
if class_type == SequenceHf: # backward compatibility, this translates to a List or a dict
feature = obj.pop('feature')
return SequenceHf(feature=generate_from_dict_ms(feature), **obj)
field_names = {f.name for f in fields(class_type)}
return class_type(**{k: v for k, v in obj.items() if k in field_names})
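
# Illustrative sketch (not part of the patching logic itself): with `ListMs` registered
# under the 'List' key above, a features dict serialized with the newer `_type: 'List'`
# layout can be rebuilt by the patched deserializer. The field names below are made up
# for the example.
#
#   >>> feats = generate_from_dict_ms(
#   ...     {'tokens': {'_type': 'List', 'feature': {'_type': 'Value', 'dtype': 'string'}}})
#   >>> feats['tokens']   # a ListMs wrapping Value('string')
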
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
# for temp val
revision = None
if url_or_filename.startswith('hf://'):
revision, url_or_filename = url_or_filename.split('@', 1)[-1].split('/', 1)
if is_relative_path(url_or_filename):
# append the relative path to the base_path
# url_or_filename = url_or_path_join(self._base_path, url_or_filename)
revision = revision or DEFAULT_DATASET_REVISION
# Note: make sure the FilePath is the last param
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
params: str = urlencode(params)
url_or_filename = self._base_path + params
out = cached_path(url_or_filename, download_config=download_config)
out = tracked_str(out)
out.set_origin(url_or_filename)
return out
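
# Rough sketch of the rewrite above (illustrative values): an input such as
#   hf://datasets/namespace/dataset_name@master/data/train.csv
# is split into revision='master' and the repo-relative path 'data/train.csv'; the encoded
# query string 'Source=SDK&Revision=master&FilePath=data%2Ftrain.csv' is then appended to
# `self._base_path` before the result is handed to `cached_path`.
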
def _dataset_info(
self,
repo_id: str,
*,
revision: Optional[str] = None,
timeout: Optional[float] = None,
files_metadata: bool = False,
token: Optional[Union[bool, str]] = None,
expand: Optional[List[ExpandDatasetProperty_T]] = None,
) -> HfDatasetInfo:
"""
    Get info on one specific dataset (in this patched version, the info is assembled from the ModelScope hub).
Dataset can be private if you pass an acceptable token.
Args:
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
revision (`str`, *optional*):
The revision of the dataset repository from which to get the
information.
        timeout (`float`, *optional*):
            Timeout in seconds for the request to the Hub.
files_metadata (`bool`, *optional*):
Whether or not to retrieve metadata for files in the repository
(size, LFS metadata, etc). Defaults to `False`.
token (`bool` or `str`, *optional*):
A valid authentication token (see https://huggingface.co/settings/token).
If `None` or `True` and machine is logged in (through `huggingface-cli login`
or [`~huggingface_hub.login`]), token will be retrieved from the cache.
If `False`, token is not sent in the request header.
Returns:
[`hf_api.DatasetInfo`]: The dataset repository information.
<Tip>
Raises the following errors:
- [`~utils.RepositoryNotFoundError`]
If the repository to download from cannot be found. This may be because it doesn't exist,
or because it is set to `private` and you do not have access.
- [`~utils.RevisionNotFoundError`]
If the revision to download from cannot be found.
</Tip>
"""
# Note: refer to `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
repo_info_iter = self.list_repo_tree(
repo_id=repo_id,
path_in_repo='/',
revision=revision,
recursive=False,
expand=expand,
token=token,
repo_type=REPO_TYPE_DATASET,
)
# Update data_info
data_info = dict({})
data_info['id'] = repo_id
data_info['private'] = False
data_info['author'] = repo_id.split('/')[0] if repo_id else None
data_info['sha'] = revision
data_info['lastModified'] = None
data_info['gated'] = False
data_info['disabled'] = False
data_info['downloads'] = 0
data_info['likes'] = 0
data_info['tags'] = []
data_info['cardData'] = []
data_info['createdAt'] = None
# e.g. {'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}}
data_siblings = []
for info_item in repo_info_iter:
if isinstance(info_item, RepoFile):
data_siblings.append(
dict(
rfilename=info_item.rfilename,
blobId=info_item.blob_id,
size=info_item.size,
)
)
data_info['siblings'] = data_siblings
return HfDatasetInfo(**data_info)
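
# Hedged usage sketch, valid once `HfApi.dataset_info` has been patched inside
# `load_dataset_with_ctx` (the repo id is hypothetical):
#
#   >>> info = HfApi().dataset_info('namespace/dataset_name', revision='master')
#   >>> [s.rfilename for s in info.siblings]   # file names assembled from the ModelScope hub
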
def _list_repo_tree(
self,
repo_id: str,
path_in_repo: Optional[str] = None,
*,
recursive: bool = True,
expand: bool = False,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
) -> Iterable[Union[RepoFile, RepoFolder]]:
_api = HubApi(timeout=3 * 60, max_retries=3)
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
# List all files in the repo
page_number = 1
page_size = 100
while True:
try:
dataset_files = _api.get_dataset_files(
repo_id=repo_id,
revision=revision or DEFAULT_DATASET_REVISION,
root_path=path_in_repo or '/',
recursive=recursive,
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
)
except Exception as e:
            logger.error(f'Failed to list files for dataset {repo_id}: {e}')
break
for file_info_d in dataset_files:
path_info = {}
path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file'
path_info['path'] = file_info_d['Path']
path_info['size'] = file_info_d['Size']
path_info['oid'] = file_info_d['Sha256']
yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)
if len(dataset_files) < page_size:
break
page_number += 1
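
# Hedged iteration sketch over the patched listing (hypothetical repo id); entries are
# `RepoFile`/`RepoFolder` objects built from the paginated `get_dataset_files` responses:
#
#   >>> for entry in HfApi().list_repo_tree('namespace/dataset_name',
#   ...                                     recursive=True, repo_type='dataset'):
#   ...     print(type(entry).__name__, entry.path)
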
def _get_paths_info(
self,
repo_id: str,
paths: Union[List[str], str],
*,
expand: bool = False,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
) -> List[Union[RepoFile, RepoFolder]]:
# Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
repo_info_iter = self.list_repo_tree(
repo_id=repo_id,
recursive=False,
expand=expand,
revision=revision,
repo_type=repo_type,
token=token,
)
return [item_info for item_info in repo_info_iter]
def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str):
_api = HubApi()
_namespace, _dataset_name = repo_id.split('/')
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
if download_config and download_config.download_desc is None:
download_config.download_desc = f'Downloading [{path_in_repo}]'
try:
url_or_filename = _api.get_dataset_file_url(
file_name=path_in_repo,
dataset_name=_dataset_name,
namespace=_namespace,
revision=revision,
extension_filter=False,
endpoint=endpoint
)
repo_file_path = cached_path(
url_or_filename=url_or_filename, download_config=download_config)
except FileNotFoundError as e:
repo_file_path = ''
logger.error(e)
return repo_file_path
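
# Minimal sketch (hypothetical repo id): fetch the dataset card into the cache; an empty
# string is returned when the file cannot be found.
#
#   >>> readme_path = _download_repo_file(
#   ...     repo_id='namespace/dataset_name',
#   ...     path_in_repo='README.md',
#   ...     download_config=DownloadConfig(cache_dir=MS_DATASETS_CACHE),
#   ...     revision=DEFAULT_DATASET_REVISION)
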
def get_fs_token_paths(
urlpath,
storage_options=None,
protocol=None,
):
if isinstance(urlpath, (list, tuple, set)):
if not urlpath:
raise ValueError('empty urlpath sequence')
urlpath0 = stringify_path(list(urlpath)[0])
else:
urlpath0 = stringify_path(urlpath)
storage_options = storage_options or {}
if protocol:
storage_options['protocol'] = protocol
chain = _un_chain(urlpath0, storage_options or {})
inkwargs = {}
# Reverse iterate the chain, creating a nested target_* structure
for i, ch in enumerate(reversed(chain)):
urls, nested_protocol, kw = ch
if i == len(chain) - 1:
inkwargs = dict(**kw, **inkwargs)
continue
inkwargs['target_options'] = dict(**kw, **inkwargs)
inkwargs['target_protocol'] = nested_protocol
inkwargs['fo'] = urls
paths, protocol, _ = chain[0]
fs = filesystem(protocol, **inkwargs)
return fs
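
# Rough usage sketch: a plain local pattern resolves to fsspec's local filesystem, while
# chained URLs (e.g. 'zip://::hf://...') are unwound via `_un_chain` into nested
# target_protocol/target_options keyword arguments.
#
#   >>> fs = get_fs_token_paths('/tmp/data/*.csv')
#   >>> fs.protocol   # 'file' (newer fsspec versions report it together with 'local')
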
def _resolve_pattern(
pattern: str,
base_path: str,
allowed_extensions: Optional[List[str]] = None,
download_config: Optional[DownloadConfig] = None,
) -> List[str]:
"""
Resolve the paths and URLs of the data files from the pattern passed by the user.
You can use patterns to resolve multiple local files. Here are a few examples:
- *.csv to match all the CSV files at the first level
- **.csv to match all the CSV files at any level
- data/* to match all the files inside "data"
- data/** to match all the files inside "data" and its subdirectories
The patterns are resolved using the fsspec glob.
glob.glob, Path.glob, Path.match or fnmatch do not support ** with a prefix/suffix other than a forward slash /.
For instance, this means **.json is the same as *.json. On the contrary, the fsspec glob has no limits regarding the ** prefix/suffix, # noqa: E501
resulting in **.json being equivalent to **/*.json.
More generally:
- '*' matches any character except a forward-slash (to match just the file or directory name)
- '**' matches any character including a forward-slash /
Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
The same applies to special directories that start with a double underscore like "__pycache__".
You can still include one if the pattern explicitly mentions it:
- to include a hidden file: "*/.hidden.txt" or "*/.*"
- to include a hidden directory: ".hidden/*" or ".*/*"
- to include a special directory: "__special__/*" or "__*/*"
Example::
>>> from datasets.data_files import resolve_pattern
>>> base_path = "."
>>> resolve_pattern("docs/**/*.py", base_path)
        ['/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
Args:
pattern (str): Unix pattern or paths or URLs of the data files to resolve.
The paths can be absolute or relative to base_path.
Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
base_path (str): Base path to use when resolving relative paths.
allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
Returns:
List[str]: List of paths or URLs to the local or remote files that match the patterns.
"""
if is_relative_path(pattern):
pattern = xjoin(base_path, pattern)
elif is_local_path(pattern):
base_path = os.path.splitdrive(pattern)[0] + os.sep
else:
base_path = ''
# storage_options {'hf': {'token': None, 'endpoint': 'https://huggingface.co'}}
pattern, storage_options = _prepare_path_and_storage_options(
pattern, download_config=download_config)
fs = get_fs_token_paths(pattern, storage_options=storage_options)
fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker
fs_pattern = pattern.split('::')[0].split('://')[-1]
files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
protocol_prefix = protocol + '://' if protocol != 'file' else ''
glob_kwargs = {}
if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'):
# 10 times faster glob with detail=True (ignores costly info like lastCommit)
glob_kwargs['expand_info'] = False
try:
tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
except FileNotFoundError:
raise DataFilesNotFoundError(f"Unable to find '{pattern}'")
matched_paths = [
filepath if filepath.startswith(protocol_prefix) else protocol_prefix
+ filepath for filepath, info in tmp_file_paths.items()
if info['type'] == 'file' and (
xbasename(filepath) not in files_to_ignore)
and not _is_inside_unrequested_special_dir(
os.path.relpath(filepath, fs_base_path),
os.path.relpath(fs_pattern, fs_base_path)) and # noqa: W504
not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir( # noqa: W504
os.path.relpath(filepath, fs_base_path),
os.path.relpath(fs_pattern, fs_base_path))
] # ignore .ipynb and __pycache__, but keep /../
if allowed_extensions is not None:
out = [
filepath for filepath in matched_paths
if any('.' + suffix in allowed_extensions
for suffix in xbasename(filepath).split('.')[1:])
]
if len(out) < len(matched_paths):
invalid_matched_files = list(set(matched_paths) - set(out))
logger.info(
f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: "
f'{invalid_matched_files}')
else:
out = matched_paths
if not out:
error_msg = f"Unable to find '{pattern}'"
if allowed_extensions is not None:
error_msg += f' with any supported extension {list(allowed_extensions)}'
raise FileNotFoundError(error_msg)
return out
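
# Hedged example in addition to the docstring above: restrict the resolved files to a
# white-list of extensions (paths and extensions are illustrative). Files that match the
# pattern but not the extensions are logged and skipped rather than raising.
#
#   >>> _resolve_pattern('data/**', base_path='.', allowed_extensions=['.csv'])
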
def _get_data_patterns(
base_path: str,
download_config: Optional[DownloadConfig] = None) -> Dict[str,
List[str]]:
"""
    Get the default pattern from a directory by testing all the supported patterns.
    The first pattern that returns a non-empty list of data files is returned.
Some examples of supported patterns:
Input:
my_dataset_repository/
├── README.md
└── dataset.csv
Output:
{"train": ["**"]}
Input:
my_dataset_repository/
├── README.md
├── train.csv
└── test.csv
my_dataset_repository/
├── README.md
└── data/
├── train.csv
└── test.csv
my_dataset_repository/
├── README.md
├── train_0.csv
├── train_1.csv
├── train_2.csv
├── train_3.csv
├── test_0.csv
└── test_1.csv
Output:
{'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
Input:
my_dataset_repository/
├── README.md
└── data/
├── train/
│ ├── shard_0.csv
│ ├── shard_1.csv
│ ├── shard_2.csv
│ └── shard_3.csv
└── test/
├── shard_0.csv
└── shard_1.csv
Output:
{'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
Input:
my_dataset_repository/
├── README.md
└── data/
├── train-00000-of-00003.csv
├── train-00001-of-00003.csv
├── train-00002-of-00003.csv
├── test-00000-of-00001.csv
├── random-00000-of-00003.csv
├── random-00001-of-00003.csv
└── random-00002-of-00003.csv
Output:
{'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}
In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
"""
resolver = partial(
_resolve_pattern, base_path=base_path, download_config=download_config)
try:
return _get_data_files_patterns(resolver)
except FileNotFoundError:
raise EmptyDatasetError(
f"The directory at {base_path} doesn't contain any data files"
) from None
def get_module_without_script(self) -> DatasetModule:
# hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
# self.name,
# revision=self.revision,
# token=self.download_config.token,
# timeout=100.0,
# )
# even if metadata_configs is not None (which means that we will resolve files for each config later)
# we cannot skip resolving all files because we need to infer module name by files extensions
# revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime
revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
'/')
repo_id: str = self.name
download_config = self.download_config.copy()
dataset_readme_path = _download_repo_file(
repo_id=repo_id,
path_in_repo='README.md',
download_config=download_config,
revision=revision)
dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data if dataset_readme_path else DatasetCardData()
subset_name: str = download_config.storage_options.get('name', None)
metadata_configs = MetadataConfigs.from_dataset_card_data(
dataset_card_data)
dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
# we need a set of data files to find which dataset builder to use
# because we need to infer module name by files extensions
if self.data_files is not None:
patterns = sanitize_patterns(self.data_files)
elif metadata_configs and 'data_files' in next(
iter(metadata_configs.values())):
if subset_name is not None:
subset_data_files = metadata_configs[subset_name]['data_files']
else:
subset_data_files = next(iter(metadata_configs.values()))['data_files']
patterns = sanitize_patterns(subset_data_files)
else:
patterns = _get_data_patterns(
base_path, download_config=self.download_config)
data_files = DataFilesDict.from_patterns(
patterns,
base_path=base_path,
allowed_extensions=ALL_ALLOWED_EXTENSIONS,
download_config=self.download_config,
)
module_name, default_builder_kwargs = infer_module_for_data_files(
data_files=data_files,
path=self.name,
download_config=self.download_config,
)
if hasattr(data_files, 'filter'):
data_files = data_files.filter(extensions=_MODULE_TO_EXTENSIONS[module_name])
else:
data_files = data_files.filter_extensions(_MODULE_TO_EXTENSIONS[module_name])
module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
if metadata_configs:
supports_metadata = module_name in {'imagefolder', 'audiofolder'}
create_builder_signature = inspect.signature(create_builder_configs_from_metadata_configs)
in_args = {
'module_path': module_path,
'metadata_configs': metadata_configs,
'base_path': base_path,
'default_builder_kwargs': default_builder_kwargs,
'download_config': self.download_config,
}
if 'supports_metadata' in create_builder_signature.parameters:
in_args['supports_metadata'] = supports_metadata
builder_configs, default_config_name = create_builder_configs_from_metadata_configs(**in_args)
else:
builder_configs: List[BuilderConfig] = [
import_main_class(module_path).BUILDER_CONFIG_CLASS(
data_files=data_files,
**default_builder_kwargs,
)
]
default_config_name = None
_api = HubApi()
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
builder_kwargs = {
# "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"),
'base_path':
HubApi().get_file_base_path(repo_id=repo_id, endpoint=endpoint),
'repo_id':
self.name,
'dataset_name':
camelcase_to_snakecase(Path(self.name).name),
'data_files': data_files,
}
download_config = self.download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = 'Downloading metadata'
# Note: `dataset_infos.json` is deprecated and can cause an error during loading if it exists
if default_config_name is None and len(dataset_infos) == 1:
default_config_name = next(iter(dataset_infos))
hash = revision
return DatasetModule(
module_path,
hash,
builder_kwargs,
dataset_infos=dataset_infos,
builder_configs_parameters=BuilderConfigsParameters(
metadata_configs=metadata_configs,
builder_configs=builder_configs,
default_config_name=default_config_name,
),
)
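
# Note (illustrative): the subset and revision used above are carried in
# `download_config.storage_options`, which `DatasetsWrapperHF.dataset_module_factory`
# populates before this method runs, e.g.
#
#   >>> download_config.storage_options.update({'name': 'default', 'revision': 'master'})
#
# so that the matching config's `data_files` patterns from the dataset card are used
# ('default' is a hypothetical subset name).
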
def _download_additional_modules(
name: str,
dataset_name: str,
namespace: str,
revision: str,
        imports: List[Tuple[str, str, str, str]],
download_config: Optional[DownloadConfig]
) -> List[Tuple[str, str]]:
"""
    Download the additional modules required by a module <name>.py at URL (or local path) <base_path>/<name>.py
The imports must have been parsed first using ``get_imports``.
If some modules need to be installed with pip, an error is raised showing how to install them.
    This function returns the list of downloaded modules as tuples (import_name, module_file_path).
The downloaded modules can then be moved into an importable directory
with ``_copy_script_and_other_resources_in_importable_dir``.
"""
local_imports = []
library_imports = []
download_config = download_config.copy()
if download_config.download_desc is None:
download_config.download_desc = 'Downloading extra modules'
for import_type, import_name, import_path, sub_directory in imports:
if import_type == 'library':
library_imports.append((import_name, import_path)) # Import from a library
continue
if import_name == name:
raise ValueError(
f'Error in the {name} script, importing relative {import_name} module '
f'but {import_name} is the name of the script. '
f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
f'comment pointing to the original relative import file path.'
)
if import_type == 'internal':
_api = HubApi()
# url_or_filename = url_or_path_join(base_path, import_path + ".py")
file_name = import_path + '.py'
url_or_filename = _api.get_dataset_file_url(file_name=file_name,
dataset_name=dataset_name,
namespace=namespace,
revision=revision,)
elif import_type == 'external':
url_or_filename = import_path
else:
raise ValueError('Wrong import_type')
local_import_path = cached_path(
url_or_filename,
download_config=download_config,
)
if sub_directory is not None:
local_import_path = os.path.join(local_import_path, sub_directory)
local_imports.append((import_name, local_import_path))
# Check library imports
needs_to_be_installed = {}
for library_import_name, library_import_path in library_imports:
try:
lib = importlib.import_module(library_import_name) # noqa F841
except ImportError:
if library_import_name not in needs_to_be_installed or library_import_path != library_import_name:
needs_to_be_installed[library_import_name] = library_import_path
if needs_to_be_installed:
_dependencies_str = 'dependencies' if len(needs_to_be_installed) > 1 else 'dependency'
_them_str = 'them' if len(needs_to_be_installed) > 1 else 'it'
if 'sklearn' in needs_to_be_installed.keys():
needs_to_be_installed['sklearn'] = 'scikit-learn'
if 'Bio' in needs_to_be_installed.keys():
needs_to_be_installed['Bio'] = 'biopython'
raise ImportError(
f'To be able to use {name}, you need to install the following {_dependencies_str}: '
f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install "
f"{' '.join(needs_to_be_installed.values())}' for instance."
)
return local_imports
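
# Shape of the `imports` argument, as produced by `datasets.utils.py_utils.get_imports`
# on the dataset script (values are illustrative, not taken from a real repo):
#
#   [('library', 'pandas', 'pandas', None),            # must already be installed
#    ('internal', 'my_utils', 'my_utils', None),       # fetched from the same repo via HubApi
#    ('external', 'helper', 'https://example.com/helper.py', None)]
#
# Missing library imports raise an ImportError that suggests the corresponding
# `pip install` command.
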
def get_module_with_script(self) -> DatasetModule:
repo_id: str = self.name
_namespace, _dataset_name = repo_id.split('/')
revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
script_file_name = f'{_dataset_name}.py'
local_script_path = _download_repo_file(
repo_id=repo_id,
path_in_repo=script_file_name,
download_config=self.download_config,
revision=revision,
)
if not local_script_path:
raise FileNotFoundError(
f'Cannot find {script_file_name} in {repo_id} at revision {revision}. '
f'Please create {script_file_name} in the repo.'
)
dataset_infos_path = None
# try:
# dataset_infos_url: str = _api.get_dataset_file_url(
# file_name='dataset_infos.json',
# dataset_name=_dataset_name,
# namespace=_namespace,
# revision=self.revision,
# extension_filter=False,
# )
# dataset_infos_path = cached_path(
# url_or_filename=dataset_infos_url, download_config=self.download_config)
# except Exception as e:
# logger.info(f'Cannot find dataset_infos.json: {e}')
# dataset_infos_path = None
dataset_readme_path = _download_repo_file(
repo_id=repo_id,
path_in_repo='README.md',
download_config=self.download_config,
revision=revision
)
imports = get_imports(local_script_path)
local_imports = _download_additional_modules(
name=repo_id,
dataset_name=_dataset_name,
namespace=_namespace,
revision=revision,
imports=imports,
download_config=self.download_config,
)
additional_files = []
if dataset_infos_path:
additional_files.append((config.DATASETDICT_INFOS_FILENAME, dataset_infos_path))
if dataset_readme_path:
additional_files.append((config.REPOCARD_FILENAME, dataset_readme_path))
# copy the script and the files in an importable directory
dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
hash = files_to_hash([local_script_path] + [loc[1] for loc in local_imports])
importable_file_path = _get_importable_file_path(
dynamic_modules_path=dynamic_modules_path,
module_namespace='datasets',
subdirectory_name=hash,
name=repo_id,
)
if not os.path.exists(importable_file_path):
trust_remote_code = resolve_trust_remote_code(trust_remote_code=self.trust_remote_code, repo_id=self.name)
if trust_remote_code:
            logger.warning(
                f'trust_remote_code=True: code from {repo_id} will be executed. '
                'Make sure you trust this external code.')
_create_importable_file(
local_path=local_script_path,
local_imports=local_imports,
additional_files=additional_files,
dynamic_modules_path=dynamic_modules_path,
module_namespace='datasets',
subdirectory_name=hash,
name=repo_id,
download_mode=self.download_mode,
)
else:
raise ValueError(
f'Loading {repo_id} requires you to execute the dataset script in that'
' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
' set the option `trust_remote_code=True` to remove this error.'
)
module_path, hash = _load_importable_file(
dynamic_modules_path=dynamic_modules_path,
module_namespace='datasets',
subdirectory_name=hash,
name=repo_id,
)
# make the new module to be noticed by the import system
importlib.invalidate_caches()
builder_kwargs = {
# "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
'base_path': HubApi().get_file_base_path(repo_id=repo_id),
'repo_id': repo_id,
}
return DatasetModule(module_path, hash, builder_kwargs)
class DatasetsWrapperHF:
@staticmethod
def load_dataset(
path: str,
name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None,
split: Optional[Union[str, Split]] = None,
cache_dir: Optional[str] = None,
features: Optional[Features] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
verification_mode: Optional[Union[VerificationMode, str]] = None,
keep_in_memory: Optional[bool] = None,
save_infos: bool = False,
revision: Optional[Union[str, Version]] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token='deprecated',
task='deprecated',
streaming: bool = False,
num_proc: Optional[int] = None,
storage_options: Optional[Dict] = None,
trust_remote_code: bool = False,
dataset_info_only: Optional[bool] = False,
**config_kwargs,
) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset,
dict]:
if use_auth_token != 'deprecated':
warnings.warn(
"'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
"You can remove this warning by passing 'token=<use_auth_token>' instead.",
FutureWarning,
)
token = use_auth_token
if task != 'deprecated':
warnings.warn(
"'task' was deprecated in version 2.13.0 and will be removed in 3.0.0.\n",
FutureWarning,
)
else:
task = None
if data_files is not None and not data_files:
raise ValueError(
f"Empty 'data_files': '{data_files}'. It should be either non-empty or None (default)."
)
if Path(path, config.DATASET_STATE_JSON_FILENAME).exists(
):
raise ValueError(
'You are trying to load a dataset that was saved using `save_to_disk`. '
'Please use `load_from_disk` instead.')
if streaming and num_proc is not None:
raise NotImplementedError(
'Loading a streaming dataset in parallel with `num_proc` is not implemented. '
'To parallelize streaming, you can wrap the dataset with a PyTorch DataLoader '
'using `num_workers` > 1 instead.')
download_mode = DownloadMode(download_mode
or DownloadMode.REUSE_DATASET_IF_EXISTS)
verification_mode = VerificationMode((
verification_mode or VerificationMode.BASIC_CHECKS
) if not save_infos else VerificationMode.ALL_CHECKS)
if trust_remote_code:
            logger.warning(
                f'trust_remote_code=True: code from {path} will be executed. '
                'Make sure you trust this external code.')
# Create a dataset builder
builder_instance = DatasetsWrapperHF.load_dataset_builder(
path=path,
name=name,
data_dir=data_dir,
data_files=data_files,
cache_dir=cache_dir,
features=features,
download_config=download_config,
download_mode=download_mode,
revision=revision,
token=token,
storage_options=storage_options,
trust_remote_code=trust_remote_code,
_require_default_config_name=name is None,
**config_kwargs,
)
# Note: Only for preview mode
if dataset_info_only:
ret_dict = {}
# Get dataset config info from python script
if isinstance(path, str) and path.endswith('.py') and os.path.exists(path):
from datasets import get_dataset_config_names
subset_list = get_dataset_config_names(path)
ret_dict = {_subset: [] for _subset in subset_list}
return ret_dict
if builder_instance is None or not hasattr(builder_instance,
'builder_configs'):
logger.error(f'No builder_configs found for {path} dataset.')
return ret_dict
_tmp_builder_configs = builder_instance.builder_configs
for tmp_config_name, tmp_builder_config in _tmp_builder_configs.items():
tmp_config_name = str(tmp_config_name)
if hasattr(tmp_builder_config, 'data_files') and tmp_builder_config.data_files is not None:
ret_dict[tmp_config_name] = [str(item) for item in list(tmp_builder_config.data_files.keys())]
else:
ret_dict[tmp_config_name] = []
return ret_dict
# Return iterable dataset in case of streaming
if streaming:
return builder_instance.as_streaming_dataset(split=split)
# Some datasets are already processed on the HF google storage
# Don't try downloading from Google storage for the packaged datasets as text, json, csv or pandas
# try_from_hf_gcs = path not in _PACKAGED_DATASETS_MODULES
# Download and prepare data
builder_instance.download_and_prepare(
download_config=download_config,
download_mode=download_mode,
verification_mode=verification_mode,
num_proc=num_proc,
storage_options=storage_options,
# base_path=builder_instance.base_path,
# file_format=builder_instance.name or 'arrow',
)
# Build dataset for splits
keep_in_memory = (
keep_in_memory if keep_in_memory is not None else is_small_dataset(
builder_instance.info.dataset_size))
ds = builder_instance.as_dataset(
split=split,
verification_mode=verification_mode,
in_memory=keep_in_memory)
# Rename and cast features to match task schema
if task is not None:
# To avoid issuing the same warning twice
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
ds = ds.prepare_for_task(task)
if save_infos:
builder_instance._save_infos()
try:
_api = HubApi()
if is_relative_path(path) and path.count('/') == 1:
_namespace, _dataset_name = path.split('/')
endpoint = _api.get_endpoint_for_read(
repo_id=path, repo_type=REPO_TYPE_DATASET)
_api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint)
except Exception as e:
logger.warning(f'Could not record download statistics: {e}')
return ds
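
    # Hedged usage sketch (hypothetical repo id); in practice this is reached through
    # `load_dataset_with_ctx`, which installs the monkey patches first:
    #
    #   >>> ds = DatasetsWrapperHF.load_dataset('namespace/dataset_name',
    #   ...                                     split='train', streaming=False)
    #
    # With `dataset_info_only=True` the call stops after building the builder instance and
    # returns a {config_name: [split names]} preview mapping instead of a dataset.
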
@staticmethod
def load_dataset_builder(
path: str,
name: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[str, Sequence[str],
Mapping[str, Union[str,
Sequence[str]]]]] = None,
cache_dir: Optional[str] = None,
features: Optional[Features] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
revision: Optional[Union[str, Version]] = None,
token: Optional[Union[bool, str]] = None,
use_auth_token='deprecated',
storage_options: Optional[Dict] = None,
trust_remote_code: Optional[bool] = None,
_require_default_config_name=True,
**config_kwargs,
) -> DatasetBuilder:
if use_auth_token != 'deprecated':
warnings.warn(
"'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
"You can remove this warning by passing 'token=<use_auth_token>' instead.",
FutureWarning,
)
token = use_auth_token
download_mode = DownloadMode(download_mode
or DownloadMode.REUSE_DATASET_IF_EXISTS)
if token is not None:
download_config = download_config.copy(
) if download_config else DownloadConfig()
download_config.token = token
if storage_options is not None:
download_config = download_config.copy(
) if download_config else DownloadConfig()
download_config.storage_options.update(storage_options)
if trust_remote_code:
            logger.warning(
                f'trust_remote_code=True: code from {path} will be executed. '
                'Make sure you trust this external code.')
dataset_module = DatasetsWrapperHF.dataset_module_factory(
path,
revision=revision,
download_config=download_config,
download_mode=download_mode,
data_dir=data_dir,
data_files=data_files,
cache_dir=cache_dir,
trust_remote_code=trust_remote_code,
_require_default_config_name=_require_default_config_name,
_require_custom_configs=bool(config_kwargs),
name=name,
)
# Get dataset builder class from the processing script
builder_kwargs = dataset_module.builder_kwargs
data_dir = builder_kwargs.pop('data_dir', data_dir)
data_files = builder_kwargs.pop('data_files', data_files)
config_name = builder_kwargs.pop(
'config_name', name
or dataset_module.builder_configs_parameters.default_config_name)
dataset_name = builder_kwargs.pop('dataset_name', None)
info = dataset_module.dataset_infos.get(
config_name) if dataset_module.dataset_infos else None
if (path in _PACKAGED_DATASETS_MODULES and data_files is None
and dataset_module.builder_configs_parameters.
builder_configs[0].data_files is None):
error_msg = f'Please specify the data files or data directory to load for the {path} dataset builder.'
example_extensions = [
extension for extension in _EXTENSION_TO_MODULE
if _EXTENSION_TO_MODULE[extension] == path
]
if example_extensions:
error_msg += f'\nFor example `data_files={{"train": "path/to/data/train/*.{example_extensions[0]}"}}`'
raise ValueError(error_msg)
builder_cls = get_dataset_builder_class(
dataset_module, dataset_name=dataset_name)
builder_instance: DatasetBuilder = builder_cls(
cache_dir=cache_dir,
dataset_name=dataset_name,
config_name=config_name,
data_dir=data_dir,
data_files=data_files,
hash=dataset_module.hash,
info=info,
features=features,
token=token,
storage_options=storage_options,
**builder_kwargs, # contains base_path
**config_kwargs,
)
builder_instance._use_legacy_cache_dir_if_possible(dataset_module)
return builder_instance
@staticmethod
def dataset_module_factory(
path: str,
revision: Optional[Union[str, Version]] = None,
download_config: Optional[DownloadConfig] = None,
download_mode: Optional[Union[DownloadMode, str]] = None,
dynamic_modules_path: Optional[str] = None,
data_dir: Optional[str] = None,
data_files: Optional[Union[Dict, List, str, DataFilesDict]] = None,
cache_dir: Optional[str] = None,
trust_remote_code: Optional[bool] = None,
_require_default_config_name=True,
_require_custom_configs=False,
**download_kwargs,
) -> DatasetModule:
subset_name: str = download_kwargs.pop('name', None)
revision = revision or DEFAULT_DATASET_REVISION
if download_config is None:
download_config = DownloadConfig(**download_kwargs)
download_config.storage_options.update({'name': subset_name})
download_config.storage_options.update({'revision': revision})
if download_config and download_config.cache_dir is None:
download_config.cache_dir = MS_DATASETS_CACHE
download_mode = DownloadMode(download_mode
or DownloadMode.REUSE_DATASET_IF_EXISTS)
download_config.extract_compressed_file = True
download_config.force_extract = True
download_config.force_download = download_mode == DownloadMode.FORCE_REDOWNLOAD
filename = list(
filter(lambda x: x,
path.replace(os.sep, '/').split('/')))[-1]
if not filename.endswith('.py'):
filename = filename + '.py'
combined_path = os.path.join(path, filename)
# We have several ways to get a dataset builder:
#
# - if path is the name of a packaged dataset module
# -> use the packaged module (json, csv, etc.)
#
# - if os.path.join(path, name) is a local python file
# -> use the module from the python file
# - if path is a local directory (but no python file)
# -> use a packaged module (csv, text etc.) based on content of the directory
#
# - if path has one "/" and is dataset repository on the HF hub with a python file
# -> the module from the python file in the dataset repository
# - if path has one "/" and is dataset repository on the HF hub without a python file
# -> use a packaged module (csv, text etc.) based on content of the repository
if trust_remote_code:
            logger.warning(
                f'trust_remote_code=True: code from {path} will be executed. '
                'Make sure you trust this external code.')
# Try packaged
if path in _PACKAGED_DATASETS_MODULES:
return PackagedDatasetModuleFactory(
path,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
download_mode=download_mode,
).get_module()
# Try locally
elif path.endswith(filename):
if os.path.isfile(path):
return LocalDatasetModuleFactoryWithScript(
path,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
else:
raise FileNotFoundError(
f"Couldn't find a dataset script at {relative_to_absolute_path(path)}"
)
elif os.path.isfile(combined_path):
return LocalDatasetModuleFactoryWithScript(
combined_path,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
elif os.path.isdir(path):
return LocalDatasetModuleFactoryWithoutScript(
path,
data_dir=data_dir,
data_files=data_files,
download_mode=download_mode).get_module()
# Try remotely
elif is_relative_path(path) and path.count('/') == 1:
try:
_raise_if_offline_mode_is_enabled()
try:
dataset_info = HfApi().dataset_info(
repo_id=path,
revision=revision,
token=download_config.token,
timeout=100.0,
)
except Exception as e: # noqa catch any exception of hf_hub and consider that the dataset doesn't exist
if isinstance(
e,
( # noqa: E131
OfflineModeIsEnabled, # noqa: E131
requests.exceptions.
ConnectTimeout, # noqa: E131, E261
requests.exceptions.ConnectionError, # noqa: E131
), # noqa: E131
):
raise ConnectionError(
f"Couldn't reach '{path}' on the Hub ({type(e).__name__})"
)
elif '404' in str(e):
msg = f"Dataset '{path}' doesn't exist on the Hub"
raise DatasetNotFoundError(
msg
+ f" at revision '{revision}'" if revision else msg
)
elif '401' in str(e):
msg = f"Dataset '{path}' doesn't exist on the Hub"
msg = msg + f" at revision '{revision}'" if revision else msg
raise DatasetNotFoundError(
msg + '. If the repo is private or gated, '
'make sure to log in with `huggingface-cli login`.'
)
else:
raise e
dataset_readme_path = _download_repo_file(
repo_id=path,
path_in_repo='README.md',
download_config=download_config,
revision=revision,
)
commit_hash = os.path.basename(os.path.dirname(dataset_readme_path))
if filename in [
sibling.rfilename for sibling in dataset_info.siblings
]: # contains a dataset script
# fs = HfFileSystem(
# endpoint=config.HF_ENDPOINT,
# token=download_config.token)
# TODO
can_load_config_from_parquet_export = False
# if _require_custom_configs:
# can_load_config_from_parquet_export = False
# elif _require_default_config_name:
# with fs.open(
# f'datasets/{path}/{filename}',
# 'r',
# revision=revision,
# encoding='utf-8') as f:
# can_load_config_from_parquet_export = 'DEFAULT_CONFIG_NAME' not in f.read(
# )
# else:
# can_load_config_from_parquet_export = True
if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
# If the parquet export is ready (parquet files + info available for the current sha),
# we can use it instead
# This fails when the dataset has multiple configs and a default config and
# the user didn't specify a configuration name (_require_default_config_name=True).
try:
if has_attr_in_class(HubDatasetModuleFactoryWithParquetExport, 'revision'):
return HubDatasetModuleFactoryWithParquetExport(
path,
revision=revision,
download_config=download_config).get_module()
return HubDatasetModuleFactoryWithParquetExport(
path,
commit_hash=commit_hash,
download_config=download_config).get_module()
except Exception as e:
logger.error(e)
# Otherwise we must use the dataset script if the user trusts it
# To be adapted to the old version of datasets
if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
return HubDatasetModuleFactoryWithScript(
path,
revision=revision,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
return HubDatasetModuleFactoryWithScript(
path,
commit_hash=commit_hash,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
else:
# To be adapted to the old version of datasets
if has_attr_in_class(HubDatasetModuleFactoryWithoutScript, 'revision'):
return HubDatasetModuleFactoryWithoutScript(
path,
revision=revision,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
download_mode=download_mode,
).get_module()
return HubDatasetModuleFactoryWithoutScript(
path,
commit_hash=commit_hash,
data_dir=data_dir,
data_files=data_files,
download_config=download_config,
download_mode=download_mode,
).get_module()
except Exception as e1:
# All the attempts failed, before raising the error we should check if the module is already cached
logger.error(f'>> Error loading {path}: {e1}')
try:
return CachedDatasetModuleFactory(
path,
dynamic_modules_path=dynamic_modules_path,
cache_dir=cache_dir).get_module()
except Exception:
# If it's not in the cache, then it doesn't exist.
if isinstance(e1, OfflineModeIsEnabled):
raise ConnectionError(
f"Couldn't reach the Hugging Face Hub for dataset '{path}': {e1}"
) from None
if isinstance(e1,
(DataFilesNotFoundError,
DatasetNotFoundError, EmptyDatasetError)):
raise e1 from None
if isinstance(e1, FileNotFoundError):
raise FileNotFoundError(
f"Couldn't find a dataset script at {relative_to_absolute_path(combined_path)} or "
f'any data file in the same directory. '
f"Couldn't find '{path}' on the Hugging Face Hub either: {type(e1).__name__}: {e1}"
) from None
raise e1 from None
else:
raise FileNotFoundError(
f"Couldn't find a dataset script at {relative_to_absolute_path(combined_path)} or "
f'any data file in the same directory.')
@contextlib.contextmanager
def load_dataset_with_ctx(*args, **kwargs):
# Keep the original functions
hf_endpoint_origin = config.HF_ENDPOINT
get_from_cache_origin = file_utils.get_from_cache
# Compatible with datasets 2.18.0
_download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
else DownloadManager._download_single
dataset_info_origin = HfApi.dataset_info
list_repo_tree_origin = HfApi.list_repo_tree
get_paths_info_origin = HfApi.get_paths_info
resolve_pattern_origin = data_files.resolve_pattern
get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
generate_from_dict_origin = features.generate_from_dict
# Monkey patching with modelscope functions
config.HF_ENDPOINT = get_endpoint()
file_utils.get_from_cache = get_from_cache_ms
# Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
DownloadManager._download = _download_ms
else:
DownloadManager._download_single = _download_ms
HfApi.dataset_info = _dataset_info
HfApi.list_repo_tree = _list_repo_tree
HfApi.get_paths_info = _get_paths_info
data_files.resolve_pattern = _resolve_pattern
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
features.generate_from_dict = generate_from_dict_ms
streaming = kwargs.get('streaming', False)
try:
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
yield dataset_res
finally:
# Restore the original functions
config.HF_ENDPOINT = hf_endpoint_origin
file_utils.get_from_cache = get_from_cache_origin
features.generate_from_dict = generate_from_dict_origin
# Keep the context during the streaming iteration
if not streaming:
config.HF_ENDPOINT = hf_endpoint_origin
file_utils.get_from_cache = get_from_cache_origin
# Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
DownloadManager._download = _download_origin
else:
DownloadManager._download_single = _download_origin
HfApi.dataset_info = dataset_info_origin
HfApi.list_repo_tree = list_repo_tree_origin
HfApi.get_paths_info = get_paths_info_origin
data_files.resolve_pattern = resolve_pattern_origin
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
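
# Hedged end-to-end sketch (hypothetical repo id): the context manager applies the
# ModelScope patches, loads the dataset, and restores the original hooks afterwards
# (most of them only for non-streaming loads, so that streaming iteration keeps working).
#
#   >>> with load_dataset_with_ctx('namespace/dataset_name', split='train') as ds:
#   ...     print(next(iter(ds)))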