diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py
index ac66e11c..5c8599b0 100644
--- a/modelscope/hub/api.py
+++ b/modelscope/hub/api.py
@@ -15,7 +15,9 @@ from http import HTTPStatus
from http.cookiejar import CookieJar
from os.path import expanduser
from typing import Dict, List, Optional, Tuple, Union
+from urllib.parse import urlencode
+import json
import pandas as pd
import requests
from requests import Session
@@ -31,7 +33,8 @@ from modelscope.hub.constants import (API_HTTP_CLIENT_TIMEOUT,
MODELSCOPE_CLOUD_ENVIRONMENT,
MODELSCOPE_CLOUD_USERNAME,
MODELSCOPE_REQUEST_ID, ONE_YEAR_SECONDS,
- REQUESTS_API_HTTP_METHOD, Licenses,
+ REQUESTS_API_HTTP_METHOD,
+ DatasetVisibility, Licenses,
ModelVisibility)
from modelscope.hub.errors import (InvalidParameter, NotExistError,
NotLoginException, NoValidRevisionError,
@@ -647,6 +650,44 @@ class HubApi:
files.append(file)
return files
+ def create_dataset(self,
+ dataset_name: str,
+ namespace: str,
+ chinese_name: Optional[str] = '',
+ license: Optional[str] = Licenses.APACHE_V2,
+ visibility: Optional[int] = DatasetVisibility.PUBLIC,
+ description: Optional[str] = '') -> str:
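+        """Create a new dataset repository on the ModelScope Hub.
+
+        Args:
+            dataset_name (str): Name of the dataset repo to create.
+            namespace (str): Namespace (user or organization) that will own the repo.
+            chinese_name (str, optional): Chinese display name.
+            license (str, optional): License identifier, defaults to Apache-2.0.
+            visibility (int, optional): One of ``DatasetVisibility.PRIVATE / INTERNAL / PUBLIC``.
+            description (str, optional): Free-form description of the dataset.
+
+        Returns:
+            str: URL of the created dataset repo, i.e. ``{endpoint}/datasets/{namespace}/{dataset_name}``.
+
+        Example (a minimal sketch; assumes you have already logged in, e.g. via ``HubApi().login(access_token)``):
+            >>> api = HubApi()
+            >>> url = api.create_dataset('my_dataset', 'my_namespace', visibility=DatasetVisibility.PRIVATE)
+        """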
+
+ if dataset_name is None or namespace is None:
+ raise InvalidParameter('dataset_name and namespace are required!')
+
+ cookies = ModelScopeConfig.get_cookies()
+ if cookies is None:
+ raise ValueError('Token does not exist, please login first.')
+
+ path = f'{self.endpoint}/api/v1/datasets'
+ files = {
+ 'Name': (None, dataset_name),
+ 'ChineseName': (None, chinese_name),
+ 'Owner': (None, namespace),
+ 'License': (None, license),
+ 'Visibility': (None, visibility),
+ 'Description': (None, description)
+ }
+
+ r = self.session.post(
+ path,
+ files=files,
+ cookies=cookies,
+ headers=self.builder_headers(self.headers),
+ )
+
+ handle_http_post_error(r, path, files)
+ raise_on_error(r.json())
+ dataset_repo_url = f'{self.endpoint}/datasets/{namespace}/{dataset_name}'
+ logger.info(f'Create dataset success: {dataset_repo_url}')
+ return dataset_repo_url
+
def list_datasets(self):
path = f'{self.endpoint}/api/v1/datasets'
params = {}
@@ -667,6 +708,47 @@ class HubApi:
dataset_type = resp['Data']['Type']
return dataset_id, dataset_type
+ def get_dataset_infos(self,
+ dataset_hub_id: str,
+ revision: str,
+ files_metadata: bool = False,
+ timeout: float = 100,
+ recursive: str = 'True'):
+        """
+        Get dataset info: the file tree of the dataset repo at the given revision,
+        as returned by the ModelScope Hub ``repo/tree`` endpoint.
+        """
+ datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+ params = {'Revision': revision, 'Root': None, 'Recursive': recursive}
+ cookies = ModelScopeConfig.get_cookies()
+ if files_metadata:
+ params['blobs'] = True
+ r = self.session.get(datahub_url, params=params, cookies=cookies, timeout=timeout)
+ resp = r.json()
+ datahub_raise_on_error(datahub_url, resp, r)
+
+ return resp
+
+ def list_repo_tree(self,
+ dataset_name: str,
+ namespace: str,
+ revision: str,
+ root_path: str,
+ recursive: bool = True):
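+        """List the file tree of a dataset repo on the ModelScope Hub.
+
+        Returns the raw JSON response; the entries under ``Data.Files`` describe the
+        blobs and trees found under ``root_path`` at the given ``revision``.
+        """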
+
+ dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
+ dataset_name=dataset_name, namespace=namespace)
+
+ recursive = 'True' if recursive else 'False'
+ datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_hub_id}/repo/tree'
+ params = {'Revision': revision, 'Root': root_path, 'Recursive': recursive}
+ cookies = ModelScopeConfig.get_cookies()
+
+ r = self.session.get(datahub_url, params=params, cookies=cookies)
+ resp = r.json()
+ datahub_raise_on_error(datahub_url, resp, r)
+
+ return resp
+
def get_dataset_meta_file_list(self, dataset_name: str, namespace: str, dataset_id: str, revision: str):
""" Get the meta file-list of the dataset. """
datahub_url = f'{self.endpoint}/api/v1/datasets/{dataset_id}/repo/tree?Revision={revision}'
@@ -735,7 +817,6 @@ class HubApi:
Fetch the meta-data files from the url, e.g. csv/jsonl files.
"""
import hashlib
- import json
from tqdm import tqdm
out_path = os.path.join(out_path, hashlib.md5(url.encode(encoding='UTF-8')).hexdigest())
if mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(out_path):
@@ -774,7 +855,7 @@ class HubApi:
else:
with_header = False
chunk_df = pd.DataFrame(chunk)
- chunk_df.to_csv(f, index=False, header=with_header)
+ chunk_df.to_csv(f, index=False, header=with_header, escapechar='\\')
iter_num += 1
else:
# csv or others
@@ -789,11 +870,28 @@ class HubApi:
file_name: str,
dataset_name: str,
namespace: str,
- revision: Optional[str] = DEFAULT_DATASET_REVISION):
- if file_name and os.path.splitext(file_name)[-1] in META_FILES_FORMAT:
- file_name = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?' \
- f'Revision={revision}&FilePath={file_name}'
- return file_name
+ revision: Optional[str] = DEFAULT_DATASET_REVISION,
+ extension_filter: Optional[bool] = True):
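+        """Build the download URL of a single file in a dataset repo.
+
+        The URL has the form
+        ``{endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?Source=SDK&Revision=...&FilePath=...``,
+        with ``FilePath`` kept as the last query parameter. ``extension_filter`` is
+        retained for backward compatibility and is currently ignored.
+        """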
+
+ if not file_name or not dataset_name or not namespace:
+ raise ValueError('Args (file_name, dataset_name, namespace) cannot be empty!')
+
+ # Note: make sure the FilePath is the last parameter in the url
+ params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': file_name}
+ params: str = urlencode(params)
+ file_url = f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?{params}'
+
+ return file_url
+
def get_dataset_access_config(
self,
@@ -931,7 +1029,7 @@ class HubApi:
datahub_raise_on_error(url, resp, r)
return resp['Data']
- def dataset_download_statistics(self, dataset_name: str, namespace: str, use_streaming: bool) -> None:
+ def dataset_download_statistics(self, dataset_name: str, namespace: str, use_streaming: bool = False) -> None:
is_ci_test = os.getenv('CI_TEST') == 'True'
if dataset_name and namespace and not is_ci_test and not use_streaming:
try:
@@ -964,6 +1062,10 @@ class HubApi:
return {MODELSCOPE_REQUEST_ID: str(uuid.uuid4().hex),
**headers}
+ def get_file_base_path(self, namespace: str, dataset_name: str) -> str:
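+        """Return the base URL for file downloads of a dataset repo; callers append
+        ``Revision`` and ``FilePath`` query parameters to it."""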
+ return f'{self.endpoint}/api/v1/datasets/{namespace}/{dataset_name}/repo?'
+
class ModelScopeConfig:
path_credential = expanduser(DEFAULT_CREDENTIALS_PATH)
diff --git a/modelscope/hub/constants.py b/modelscope/hub/constants.py
index 362f323d..9b443b71 100644
--- a/modelscope/hub/constants.py
+++ b/modelscope/hub/constants.py
@@ -47,3 +47,9 @@ class ModelVisibility(object):
PRIVATE = 1
INTERNAL = 3
PUBLIC = 5
+
+
+class DatasetVisibility(object):
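+    """Visibility levels of a dataset repo on the ModelScope Hub (mirrors ModelVisibility)."""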
+ PRIVATE = 1
+ INTERNAL = 3
+ PUBLIC = 5
diff --git a/modelscope/msdatasets/meta/data_meta_manager.py b/modelscope/msdatasets/meta/data_meta_manager.py
index 3f1e6572..4eb9942b 100644
--- a/modelscope/msdatasets/meta/data_meta_manager.py
+++ b/modelscope/msdatasets/meta/data_meta_manager.py
@@ -92,6 +92,10 @@ class DataMetaManager(object):
data_meta_config.meta_cache_dir = meta_cache_dir
data_meta_config.dataset_scripts = dataset_scripts
data_meta_config.dataset_formation = dataset_formation
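+        # If the resolved dataset scripts include a Python loading script (*.py),
+        # record the first one as ``dataset_py_script`` for later use.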
+ if '.py' in dataset_scripts:
+ tmp_py_scripts = dataset_scripts['.py']
+ if len(tmp_py_scripts) > 0:
+ data_meta_config.dataset_py_script = tmp_py_scripts[0]
# Set dataset_context_config
self.dataset_context_config.data_meta_config = data_meta_config
diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py
index b720ada6..7d99a7cb 100644
--- a/modelscope/msdatasets/ms_dataset.py
+++ b/modelscope/msdatasets/ms_dataset.py
@@ -13,7 +13,6 @@ from datasets.utils.file_utils import is_relative_path
from modelscope.hub.repository import DatasetRepository
from modelscope.msdatasets.context.dataset_context_config import \
DatasetContextConfig
-from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
from modelscope.msdatasets.data_loader.data_loader_manager import (
LocalDataLoaderManager, LocalDataLoaderType, RemoteDataLoaderManager,
RemoteDataLoaderType)
@@ -22,14 +21,16 @@ from modelscope.msdatasets.dataset_cls import (ExternalDataset,
from modelscope.msdatasets.dataset_cls.custom_datasets.builder import \
build_custom_dataset
from modelscope.msdatasets.utils.delete_utils import DatasetDeleteManager
+from modelscope.msdatasets.utils.hf_datasets_util import \
+ load_dataset as hf_load_dataset_wrapper
from modelscope.msdatasets.utils.upload_utils import DatasetUploadManager
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.config import Config, ConfigDict
from modelscope.utils.config_ds import MS_DATASETS_CACHE
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
DEFAULT_DATASET_REVISION, ConfigFields,
- DownloadMode, Hubs, ModeKeys, Tasks,
- UploadMode, VirgoDatasetConfig)
+ DatasetFormations, DownloadMode, Hubs,
+ ModeKeys, Tasks, UploadMode)
from modelscope.utils.import_utils import is_tf_available, is_torch_available
from modelscope.utils.logger import get_logger
@@ -167,6 +168,7 @@ class MsDataset:
stream_batch_size: Optional[int] = 1,
custom_cfg: Optional[Config] = Config(),
token: Optional[str] = None,
+ dataset_info_only: Optional[bool] = False,
**config_kwargs,
) -> Union[dict, 'MsDataset', NativeIterableDataset]:
"""Load a MsDataset from the ModelScope Hub, Hugging Face Hub, urls, or a local dataset.
@@ -196,6 +198,7 @@ class MsDataset:
custom_cfg (str, Optional): Model configuration, this can be used for custom datasets.
see https://modelscope.cn/docs/Configuration%E8%AF%A6%E8%A7%A3
token (str, Optional): SDK token of ModelScope.
+ dataset_info_only (bool, Optional): If set to True, only return the dataset config and info (dict).
**config_kwargs (additional keyword arguments): Keyword arguments to be passed
Returns:
@@ -279,19 +282,51 @@ class MsDataset:
return dataset_inst
# Load from the modelscope hub
elif hub == Hubs.modelscope:
- remote_dataloader_manager = RemoteDataLoaderManager(
- dataset_context_config)
- dataset_inst = remote_dataloader_manager.load_dataset(
- RemoteDataLoaderType.MS_DATA_LOADER)
- dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
- if isinstance(dataset_inst, MsDataset):
- dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
- if custom_cfg:
- dataset_inst.to_custom_dataset(
- custom_cfg=custom_cfg, **config_kwargs)
- dataset_inst.is_custom = True
- return dataset_inst
+
+ # Get dataset type from ModelScope Hub; dataset_type->4: General Dataset
+ from modelscope.hub.api import HubApi
+ _api = HubApi()
+ dataset_id_on_hub, dataset_type = _api.get_dataset_id_and_type(
+ dataset_name=dataset_name, namespace=namespace)
+
+ logger.info(f'dataset_type: {dataset_type}')
+
+ # Load from the ModelScope Hub for type=4 (general)
+ if str(dataset_type) == str(DatasetFormations.general.value):
+ return hf_load_dataset_wrapper(
+ path=namespace + '/' + dataset_name,
+ name=subset_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ split=split,
+ cache_dir=cache_dir,
+ features=None,
+ download_config=None,
+ download_mode=download_mode.value,
+ revision=version,
+ token=token,
+ streaming=use_streaming,
+ dataset_info_only=dataset_info_only,
+ **config_kwargs)
+            else:
+ remote_dataloader_manager = RemoteDataLoaderManager(
+ dataset_context_config)
+ dataset_inst = remote_dataloader_manager.load_dataset(
+ RemoteDataLoaderType.MS_DATA_LOADER)
+ dataset_inst = MsDataset.to_ms_dataset(
+ dataset_inst, target=target)
+ if isinstance(dataset_inst, MsDataset):
+ dataset_inst._dataset_context_config = remote_dataloader_manager.dataset_context_config
+ if custom_cfg:
+ dataset_inst.to_custom_dataset(
+ custom_cfg=custom_cfg, **config_kwargs)
+ dataset_inst.is_custom = True
+ return dataset_inst
+
elif hub == Hubs.virgo:
+ from modelscope.msdatasets.data_loader.data_loader import VirgoDownloader
+ from modelscope.utils.constant import VirgoDatasetConfig
# Rewrite the namespace, version and cache_dir for virgo dataset.
if namespace == DEFAULT_DATASET_NAMESPACE:
dataset_context_config.namespace = VirgoDatasetConfig.default_virgo_namespace
@@ -323,6 +358,10 @@ class MsDataset:
chunksize: Optional[int] = 1,
filter_hidden_files: Optional[bool] = True,
upload_mode: Optional[UploadMode] = UploadMode.OVERWRITE) -> None:
+        # Deprecated: this method may be removed in future releases;
+        # please use the git command line to upload datasets instead.
"""Upload dataset file or directory to the ModelScope Hub. Please log in to the ModelScope Hub first.
Args:
@@ -346,6 +385,10 @@ class MsDataset:
None
"""
+ warnings.warn(
+ 'upload is deprecated, please use git command line to upload the dataset.',
+ DeprecationWarning)
+
if not object_name:
raise ValueError('object_name cannot be empty!')
@@ -393,6 +436,10 @@ class MsDataset:
None
"""
+ warnings.warn(
+            'This method is deprecated and may be removed in future releases, please use the git command line instead.',
+ DeprecationWarning)
+
_repo = DatasetRepository(
repo_work_dir=dataset_work_dir,
dataset_id=dataset_id,
diff --git a/modelscope/msdatasets/utils/dataset_utils.py b/modelscope/msdatasets/utils/dataset_utils.py
index b40915eb..6d939ef1 100644
--- a/modelscope/msdatasets/utils/dataset_utils.py
+++ b/modelscope/msdatasets/utils/dataset_utils.py
@@ -212,7 +212,10 @@ def get_dataset_files(subset_split_into: dict,
csv_delimiter = context_config.config_kwargs.get('delimiter', ',')
csv_df = pd.read_csv(
- meta_csv_file_path, iterator=False, delimiter=csv_delimiter)
+ meta_csv_file_path,
+ iterator=False,
+ delimiter=csv_delimiter,
+ escapechar='\\')
target_col = csv_df.columns[csv_df.columns.str.contains(
':FILE')].to_list()
if len(target_col) == 0:
diff --git a/modelscope/msdatasets/utils/hf_datasets_util.py b/modelscope/msdatasets/utils/hf_datasets_util.py
new file mode 100644
index 00000000..8b067fda
--- /dev/null
+++ b/modelscope/msdatasets/utils/hf_datasets_util.py
@@ -0,0 +1,1339 @@
+# noqa: isort:skip_file, yapf: disable
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
+import importlib
+import os
+import warnings
+from functools import partial
+from pathlib import Path
+from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple
+
+from urllib.parse import urlencode
+
+import requests
+from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
+ DownloadConfig, DownloadManager, DownloadMode, Features,
+ IterableDataset, IterableDatasetDict, Split,
+ VerificationMode, Version, config, data_files)
+from datasets.data_files import (
+ FILES_TO_IGNORE, DataFilesDict, DataFilesList, EmptyDatasetError,
+ _get_data_files_patterns, _is_inside_unrequested_special_dir,
+ _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, get_metadata_patterns, sanitize_patterns)
+from datasets.download.streaming_download_manager import (
+ _prepare_path_and_storage_options, xbasename, xjoin)
+from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError
+from datasets.info import DatasetInfosDict
+from datasets.load import (
+ ALL_ALLOWED_EXTENSIONS, BuilderConfigsParameters,
+ CachedDatasetModuleFactory, DatasetModule,
+ HubDatasetModuleFactoryWithoutScript,
+ HubDatasetModuleFactoryWithParquetExport,
+ HubDatasetModuleFactoryWithScript, LocalDatasetModuleFactoryWithoutScript,
+ LocalDatasetModuleFactoryWithScript, PackagedDatasetModuleFactory,
+ create_builder_configs_from_metadata_configs, get_dataset_builder_class,
+ import_main_class, infer_module_for_data_files, files_to_hash,
+ _get_importable_file_path, resolve_trust_remote_code, _create_importable_file, _load_importable_file,
+ init_dynamic_modules)
+from datasets.naming import camelcase_to_snakecase
+from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
+ _MODULE_SUPPORTS_METADATA,
+ _MODULE_TO_EXTENSIONS,
+ _PACKAGED_DATASETS_MODULES)
+from datasets.utils import _datasets_server, file_utils
+from datasets.utils.file_utils import (OfflineModeIsEnabled,
+ _raise_if_offline_mode_is_enabled,
+ cached_path, is_local_path,
+ is_relative_path,
+ relative_to_absolute_path)
+from datasets.utils.info_utils import is_small_dataset
+from datasets.utils.metadata import MetadataConfigs
+from datasets.utils.py_utils import get_imports
+from datasets.utils.track import tracked_str
+from fsspec import filesystem
+from fsspec.core import _un_chain
+from fsspec.utils import stringify_path
+from huggingface_hub import (DatasetCard, DatasetCardData, HfFileSystem)
+from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
+from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
+from packaging import version
+
+from modelscope import HubApi
+from modelscope.hub.utils.utils import get_endpoint
+from modelscope.msdatasets.utils.hf_file_utils import get_from_cache_ms
+from modelscope.utils.config_ds import MS_DATASETS_CACHE
+from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
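+# The helpers below monkey-patch selected `datasets` / `huggingface_hub` entry
+# points (file_utils.get_from_cache, DownloadManager._download, HfApi.dataset_info,
+# HfApi.list_repo_tree, HfApi.get_paths_info, the module factories, ...) so that
+# metadata lookups and file downloads are resolved against the ModelScope Hub
+# endpoint while reusing the upstream `datasets` loading pipeline.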
+config.HF_ENDPOINT = get_endpoint()
+
+
+file_utils.get_from_cache = get_from_cache_ms
+
+
+def _download(self, url_or_filename: str,
+ download_config: DownloadConfig) -> str:
+ url_or_filename = str(url_or_filename)
+    # Temporary holder for the revision parsed out of an hf:// URL below.
+ revision = None
+ if url_or_filename.startswith('hf://'):
+ revision, url_or_filename = url_or_filename.split('@', 1)[-1].split('/', 1)
+ if is_relative_path(url_or_filename):
+ # append the relative path to the base_path
+ # url_or_filename = url_or_path_join(self._base_path, url_or_filename)
+ revision = revision or 'master'
+ # Note: make sure the FilePath is the last param
+ params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
+ params: str = urlencode(params)
+ url_or_filename = self._base_path + params
+
+ out = cached_path(url_or_filename, download_config=download_config)
+ out = tracked_str(out)
+ out.set_origin(url_or_filename)
+ return out
+
+
+DownloadManager._download = _download
+
+
+def _dataset_info(
+ self,
+ repo_id: str,
+ *,
+ revision: Optional[str] = None,
+ timeout: Optional[float] = None,
+ files_metadata: bool = False,
+ token: Optional[Union[bool, str]] = None,
+) -> HfDatasetInfo:
+ """
+ Get info on one specific dataset on huggingface.co.
+
+ Dataset can be private if you pass an acceptable token.
+
+ Args:
+ repo_id (`str`):
+ A namespace (user or an organization) and a repo name separated
+ by a `/`.
+ revision (`str`, *optional*):
+ The revision of the dataset repository from which to get the
+ information.
+ timeout (`float`, *optional*):
+            Timeout (in seconds) for the request to the Hub.
+ files_metadata (`bool`, *optional*):
+ Whether or not to retrieve metadata for files in the repository
+ (size, LFS metadata, etc). Defaults to `False`.
+ token (`bool` or `str`, *optional*):
+ A valid authentication token (see https://huggingface.co/settings/token).
+ If `None` or `True` and machine is logged in (through `huggingface-cli login`
+ or [`~huggingface_hub.login`]), token will be retrieved from the cache.
+ If `False`, token is not sent in the request header.
+
+ Returns:
+ [`hf_api.DatasetInfo`]: The dataset repository information.
+
+ Raises the following errors:
+
+ - [`~utils.RepositoryNotFoundError`]
+ If the repository to download from cannot be found. This may be because it doesn't exist,
+ or because it is set to `private` and you do not have access.
+ - [`~utils.RevisionNotFoundError`]
+ If the revision to download from cannot be found.
+
+
+ """
+ _api = HubApi()
+ _namespace, _dataset_name = repo_id.split('/')
+ dataset_hub_id, dataset_type = _api.get_dataset_id_and_type(
+ dataset_name=_dataset_name, namespace=_namespace)
+
+ revision: str = revision or 'master'
+ data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id,
+ revision=revision,
+ files_metadata=files_metadata,
+ timeout=timeout)
+
+ # Parse data
+ data_d: dict = data['Data']
+ data_file_list: list = data_d['Files']
+ # commit_info: dict = data_d['LatestCommitter']
+
+    # Update data. TODO: align fields with HfDatasetInfo.
+ data['id'] = repo_id
+ data['private'] = False
+ data['author'] = repo_id.split('/')[0] if repo_id else None
+ data['sha'] = revision
+ data['lastModified'] = None
+ data['gated'] = False
+ data['disabled'] = False
+ data['downloads'] = 0
+ data['likes'] = 0
+ data['tags'] = []
+ data['cardData'] = []
+ data['createdAt'] = None
+
+ # e.g. {'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}}
+ data['siblings'] = []
+ for file_info_d in data_file_list:
+ file_info = {
+ 'rfilename': file_info_d['Path'],
+ 'blobId': file_info_d['Id'],
+ 'size': file_info_d['Size'],
+ 'type': 'directory' if file_info_d['Type'] == 'tree' else 'file',
+ 'lfs': {
+ 'size': file_info_d['Size'],
+ 'sha256': file_info_d['Sha256'],
+ 'pointerSize': 0
+ }
+ }
+ data['siblings'].append(file_info)
+
+ return HfDatasetInfo(**data)
+
+
+HfApi.dataset_info = _dataset_info
+
+
+def _list_repo_tree(
+ self,
+ repo_id: str,
+ path_in_repo: Optional[str] = None,
+ *,
+ recursive: bool = True,
+ expand: bool = False,
+ revision: Optional[str] = None,
+ repo_type: Optional[str] = None,
+ token: Optional[Union[bool, str]] = None,
+) -> Iterable[Union[RepoFile, RepoFolder]]:
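+    """Drop-in replacement for ``HfApi.list_repo_tree`` that resolves the repo tree
+    through the ModelScope Hub and yields ``RepoFile`` / ``RepoFolder`` entries."""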
+
+ _api = HubApi()
+
+ if is_relative_path(repo_id) and repo_id.count('/') == 1:
+ _namespace, _dataset_name = repo_id.split('/')
+ elif is_relative_path(repo_id) and repo_id.count('/') == 0:
+ logger.warning(f'Got a relative path: {repo_id} without namespace, '
+ f'Use default namespace: {DEFAULT_DATASET_NAMESPACE}')
+ _namespace, _dataset_name = DEFAULT_DATASET_NAMESPACE, repo_id
+ else:
+ raise ValueError(f'Invalid repo_id: {repo_id} !')
+
+ data: dict = _api.list_repo_tree(dataset_name=_dataset_name,
+ namespace=_namespace,
+ revision=revision or 'master',
+ root_path=path_in_repo or None,
+                                  recursive=recursive,
+ )
+ # Parse data
+ # Type: 'tree' or 'blob'
+ data_d: dict = data['Data']
+ data_file_list: list = data_d['Files']
+ # commit_info: dict = data_d['LatestCommitter']
+
+    for file_info_d in data_file_list:
+        path_info = {
+            'type': 'directory' if file_info_d['Type'] == 'tree' else 'file',
+            'path': file_info_d['Path'],
+            'size': file_info_d['Size'],
+            'oid': file_info_d['Sha256'],
+        }
+
+        if path_info['type'] == 'file':
+            yield RepoFile(**path_info)
+        else:
+            yield RepoFolder(**path_info)
+
+
+HfApi.list_repo_tree = _list_repo_tree
+
+
+def _get_paths_info(
+ self,
+ repo_id: str,
+ paths: Union[List[str], str],
+ *,
+ expand: bool = False,
+ revision: Optional[str] = None,
+ repo_type: Optional[str] = None,
+ token: Optional[Union[bool, str]] = None,
+) -> List[Union[RepoFile, RepoFolder]]:
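+    """Drop-in replacement for ``HfApi.get_paths_info`` backed by the ModelScope Hub.
+    Note: only the top-level ``README.md`` entry is returned at present."""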
+
+ _api = HubApi()
+ _namespace, _dataset_name = repo_id.split('/')
+ dataset_hub_id, dataset_type = _api.get_dataset_id_and_type(
+ dataset_name=_dataset_name, namespace=_namespace)
+
+ revision: str = revision or 'master'
+ data = _api.get_dataset_infos(dataset_hub_id=dataset_hub_id,
+ revision=revision,
+ files_metadata=False,
+ recursive='False')
+ data_d: dict = data['Data']
+ data_file_list: list = data_d['Files']
+
+ return [
+ RepoFile(path=item_d['Name'],
+ size=item_d['Size'],
+ oid=item_d['Revision'],
+ lfs=None, # TODO: lfs type to be supported
+ last_commit=None, # TODO: lfs type to be supported
+ security=None
+ ) for item_d in data_file_list if item_d['Name'] == 'README.md'
+ ]
+
+
+HfApi.get_paths_info = _get_paths_info
+
+
+def get_fs_token_paths(
+ urlpath,
+ storage_options=None,
+ protocol=None,
+):
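+    """Resolve an fsspec filesystem for ``urlpath``; a trimmed-down variant of
+    ``fsspec.core.get_fs_token_paths`` that only returns the filesystem object."""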
+ if isinstance(urlpath, (list, tuple, set)):
+ if not urlpath:
+ raise ValueError('empty urlpath sequence')
+ urlpath0 = stringify_path(list(urlpath)[0])
+ else:
+ urlpath0 = stringify_path(urlpath)
+ storage_options = storage_options or {}
+ if protocol:
+ storage_options['protocol'] = protocol
+ chain = _un_chain(urlpath0, storage_options or {})
+ inkwargs = {}
+ # Reverse iterate the chain, creating a nested target_* structure
+ for i, ch in enumerate(reversed(chain)):
+ urls, nested_protocol, kw = ch
+ if i == len(chain) - 1:
+ inkwargs = dict(**kw, **inkwargs)
+ continue
+ inkwargs['target_options'] = dict(**kw, **inkwargs)
+ inkwargs['target_protocol'] = nested_protocol
+ inkwargs['fo'] = urls
+ paths, protocol, _ = chain[0]
+ fs = filesystem(protocol, **inkwargs)
+
+ return fs
+
+
+def _resolve_pattern(
+ pattern: str,
+ base_path: str,
+ allowed_extensions: Optional[List[str]] = None,
+ download_config: Optional[DownloadConfig] = None,
+) -> List[str]:
+ """
+ Resolve the paths and URLs of the data files from the pattern passed by the user.
+
+ You can use patterns to resolve multiple local files. Here are a few examples:
+ - *.csv to match all the CSV files at the first level
+ - **.csv to match all the CSV files at any level
+ - data/* to match all the files inside "data"
+ - data/** to match all the files inside "data" and its subdirectories
+
+ The patterns are resolved using the fsspec glob.
+
+ glob.glob, Path.glob, Path.match or fnmatch do not support ** with a prefix/suffix other than a forward slash /.
+ For instance, this means **.json is the same as *.json. On the contrary, the fsspec glob has no limits regarding the ** prefix/suffix, # noqa: E501
+ resulting in **.json being equivalent to **/*.json.
+
+ More generally:
+ - '*' matches any character except a forward-slash (to match just the file or directory name)
+ - '**' matches any character including a forward-slash /
+
+ Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
+ The same applies to special directories that start with a double underscore like "__pycache__".
+    You can still include one if the pattern explicitly mentions it:
+ - to include a hidden file: "*/.hidden.txt" or "*/.*"
+ - to include a hidden directory: ".hidden/*" or ".*/*"
+ - to include a special directory: "__special__/*" or "__*/*"
+
+ Example::
+
+ >>> from datasets.data_files import resolve_pattern
+ >>> base_path = "."
+ >>> resolve_pattern("docs/**/*.py", base_path)
+        ['/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
+
+ Args:
+ pattern (str): Unix pattern or paths or URLs of the data files to resolve.
+ The paths can be absolute or relative to base_path.
+ Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
+ base_path (str): Base path to use when resolving relative paths.
+ allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
+ For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
+ Returns:
+ List[str]: List of paths or URLs to the local or remote files that match the patterns.
+ """
+ if is_relative_path(pattern):
+ pattern = xjoin(base_path, pattern)
+ elif is_local_path(pattern):
+ base_path = os.path.splitdrive(pattern)[0] + os.sep
+ else:
+ base_path = ''
+ # storage_options: {'hf': {'token': None, 'endpoint': 'https://huggingface.co'}}
+ pattern, storage_options = _prepare_path_and_storage_options(
+ pattern, download_config=download_config)
+ fs = get_fs_token_paths(pattern, storage_options=storage_options)
+ fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker
+ fs_pattern = pattern.split('::')[0].split('://')[-1]
+ files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
+ protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
+ protocol_prefix = protocol + '://' if protocol != 'file' else ''
+ glob_kwargs = {}
+ if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'):
+ # 10 times faster glob with detail=True (ignores costly info like lastCommit)
+ glob_kwargs['expand_info'] = False
+
+ tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
+
+ matched_paths = [
+ filepath if filepath.startswith(protocol_prefix) else protocol_prefix
+ + filepath for filepath, info in tmp_file_paths.items()
+ if info['type'] == 'file' and (
+ xbasename(filepath) not in files_to_ignore)
+ and not _is_inside_unrequested_special_dir(
+ os.path.relpath(filepath, fs_base_path),
+ os.path.relpath(fs_pattern, fs_base_path)) and # noqa: W504
+ not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir( # noqa: W504
+ os.path.relpath(filepath, fs_base_path),
+ os.path.relpath(fs_pattern, fs_base_path))
+ ] # ignore .ipynb and __pycache__, but keep /../
+ if allowed_extensions is not None:
+ out = [
+ filepath for filepath in matched_paths
+ if any('.' + suffix in allowed_extensions
+ for suffix in xbasename(filepath).split('.')[1:])
+ ]
+ if len(out) < len(matched_paths):
+ invalid_matched_files = list(set(matched_paths) - set(out))
+ logger.info(
+ f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: "
+ f'{invalid_matched_files}')
+ else:
+ out = matched_paths
+ if not out:
+ error_msg = f"Unable to find '{pattern}'"
+ if allowed_extensions is not None:
+ error_msg += f' with any supported extension {list(allowed_extensions)}'
+ raise FileNotFoundError(error_msg)
+ return out
+
+
+data_files.resolve_pattern = _resolve_pattern
+
+
+def _get_data_patterns(
+ base_path: str,
+        download_config: Optional[DownloadConfig] = None) -> Dict[str, List[str]]:
+ """
+ Get the default pattern from a directory testing all the supported patterns.
+    The first pattern to return a non-empty list of data files is returned.
+
+ Some examples of supported patterns:
+
+ Input:
+
+ my_dataset_repository/
+ ├── README.md
+ └── dataset.csv
+
+ Output:
+
+ {"train": ["**"]}
+
+ Input:
+
+ my_dataset_repository/
+ ├── README.md
+ ├── train.csv
+ └── test.csv
+
+ my_dataset_repository/
+ ├── README.md
+ └── data/
+ ├── train.csv
+ └── test.csv
+
+ my_dataset_repository/
+ ├── README.md
+ ├── train_0.csv
+ ├── train_1.csv
+ ├── train_2.csv
+ ├── train_3.csv
+ ├── test_0.csv
+ └── test_1.csv
+
+ Output:
+
+ {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
+ 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
+ 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
+ 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
+
+ Input:
+
+ my_dataset_repository/
+ ├── README.md
+ └── data/
+ ├── train/
+ │ ├── shard_0.csv
+ │ ├── shard_1.csv
+ │ ├── shard_2.csv
+ │ └── shard_3.csv
+ └── test/
+ ├── shard_0.csv
+ └── shard_1.csv
+
+ Output:
+
+ {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
+ 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
+ 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
+ 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
+
+ Input:
+
+ my_dataset_repository/
+ ├── README.md
+ └── data/
+ ├── train-00000-of-00003.csv
+ ├── train-00001-of-00003.csv
+ ├── train-00002-of-00003.csv
+ ├── test-00000-of-00001.csv
+ ├── random-00000-of-00003.csv
+ ├── random-00001-of-00003.csv
+ └── random-00002-of-00003.csv
+
+ Output:
+
+ {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
+ 'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
+ 'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}
+
+ In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
+ """
+ resolver = partial(
+ _resolve_pattern, base_path=base_path, download_config=download_config)
+ try:
+ return _get_data_files_patterns(resolver)
+ except FileNotFoundError:
+ raise EmptyDatasetError(
+ f"The directory at {base_path} doesn't contain any data files"
+ ) from None
+
+
+def get_module_without_script(self) -> DatasetModule:
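+    """ModelScope replacement for ``HubDatasetModuleFactoryWithoutScript.get_module``:
+    resolves the README, metadata configs and data files of a script-free dataset
+    repo through the ModelScope Hub API."""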
+ _ms_api = HubApi()
+ _repo_id: str = self.name
+ _namespace, _dataset_name = _repo_id.split('/')
+
+ # hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
+ # self.name,
+ # revision=self.revision,
+ # token=self.download_config.token,
+ # timeout=100.0,
+ # )
+ # even if metadata_configs is not None (which means that we will resolve files for each config later)
+ # we cannot skip resolving all files because we need to infer module name by files extensions
+ # revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime
+ revision = self.revision or 'master'
+ base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
+ '/')
+
+ download_config = self.download_config.copy()
+ if download_config.download_desc is None:
+ download_config.download_desc = 'Downloading readme'
+ try:
+ url_or_filename = _ms_api.get_dataset_file_url(
+ file_name='README.md',
+ dataset_name=_dataset_name,
+ namespace=_namespace,
+ revision=revision,
+ extension_filter=False,
+ )
+
+ dataset_readme_path = cached_path(
+ url_or_filename=url_or_filename, download_config=download_config)
+ dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data
+ except FileNotFoundError:
+ dataset_card_data = DatasetCardData()
+
+ subset_name: str = download_config.storage_options.get('name', None)
+
+ metadata_configs = MetadataConfigs.from_dataset_card_data(
+ dataset_card_data)
+ dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
+ # we need a set of data files to find which dataset builder to use
+ # because we need to infer module name by files extensions
+ if self.data_files is not None:
+ patterns = sanitize_patterns(self.data_files)
+ elif metadata_configs and 'data_files' in next(
+ iter(metadata_configs.values())):
+
+ if subset_name is not None:
+ subset_data_files = metadata_configs[subset_name]['data_files']
+ else:
+ subset_data_files = next(iter(metadata_configs.values()))['data_files']
+ patterns = sanitize_patterns(subset_data_files)
+ else:
+ patterns = _get_data_patterns(
+ base_path, download_config=self.download_config)
+
+ data_files = DataFilesDict.from_patterns(
+ patterns,
+ base_path=base_path,
+ allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+ download_config=self.download_config,
+ )
+ module_name, default_builder_kwargs = infer_module_for_data_files(
+ data_files=data_files,
+ path=self.name,
+ download_config=self.download_config,
+ )
+ data_files = data_files.filter_extensions(
+ _MODULE_TO_EXTENSIONS[module_name])
+ # Collect metadata files if the module supports them
+ supports_metadata = module_name in _MODULE_SUPPORTS_METADATA
+ if self.data_files is None and supports_metadata:
+ try:
+ metadata_patterns = get_metadata_patterns(
+ base_path, download_config=self.download_config)
+ except FileNotFoundError:
+ metadata_patterns = None
+ if metadata_patterns is not None:
+ metadata_data_files_list = DataFilesList.from_patterns(
+ metadata_patterns,
+ download_config=self.download_config,
+ base_path=base_path)
+ if metadata_data_files_list:
+ data_files = DataFilesDict({
+ split: data_files_list + metadata_data_files_list
+ for split, data_files_list in data_files.items()
+ })
+
+ module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
+
+ if metadata_configs:
+ builder_configs, default_config_name = create_builder_configs_from_metadata_configs(
+ module_path,
+ metadata_configs,
+ base_path=base_path,
+ supports_metadata=supports_metadata,
+ default_builder_kwargs=default_builder_kwargs,
+ download_config=self.download_config,
+ )
+ else:
+ builder_configs: List[BuilderConfig] = [
+ import_main_class(module_path).BUILDER_CONFIG_CLASS(
+ data_files=data_files,
+ **default_builder_kwargs,
+ )
+ ]
+ default_config_name = None
+ builder_kwargs = {
+ # "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"),
+ 'base_path':
+ _ms_api.get_file_base_path(
+ namespace=_namespace,
+ dataset_name=_dataset_name,
+ ),
+ 'repo_id':
+ self.name,
+ 'dataset_name':
+ camelcase_to_snakecase(Path(self.name).name),
+ 'data_files': data_files,
+ }
+ download_config = self.download_config.copy()
+ if download_config.download_desc is None:
+ download_config.download_desc = 'Downloading metadata'
+
+ # Note: `dataset_infos.json` is deprecated and can cause an error during loading if it exists
+
+ if default_config_name is None and len(dataset_infos) == 1:
+ default_config_name = next(iter(dataset_infos))
+
+ hash = revision
+ return DatasetModule(
+ module_path,
+ hash,
+ builder_kwargs,
+ dataset_infos=dataset_infos,
+ builder_configs_parameters=BuilderConfigsParameters(
+ metadata_configs=metadata_configs,
+ builder_configs=builder_configs,
+ default_config_name=default_config_name,
+ ),
+ )
+
+
+HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
+
+
+def _download_additional_modules(
+ name: str,
+ dataset_name: str,
+ namespace: str,
+ revision: str,
+ imports: Tuple[str, str, str, str],
+ download_config: Optional[DownloadConfig]
+) -> List[Tuple[str, str]]:
+ """
+    Download the additional modules imported by a dataset script ``<name>.py``
+    (hosted at a URL or a local path). The imports must have been parsed first
+    using ``get_imports``.
+
+    If some modules need to be installed with pip, an error is raised showing how to install them.
+    This function returns the list of downloaded modules as tuples (import_name, module_file_path).
+
+ The downloaded modules can then be moved into an importable directory
+ with ``_copy_script_and_other_resources_in_importable_dir``.
+ """
+ local_imports = []
+ library_imports = []
+ download_config = download_config.copy()
+ if download_config.download_desc is None:
+ download_config.download_desc = 'Downloading extra modules'
+ for import_type, import_name, import_path, sub_directory in imports:
+ if import_type == 'library':
+ library_imports.append((import_name, import_path)) # Import from a library
+ continue
+
+ if import_name == name:
+ raise ValueError(
+ f'Error in the {name} script, importing relative {import_name} module '
+ f'but {import_name} is the name of the script. '
+ f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
+ f'comment pointing to the original relative import file path.'
+ )
+ if import_type == 'internal':
+ _api = HubApi()
+ # url_or_filename = url_or_path_join(base_path, import_path + ".py")
+ file_name = import_path + '.py'
+ url_or_filename = _api.get_dataset_file_url(file_name=file_name,
+ dataset_name=dataset_name,
+ namespace=namespace,
+ revision=revision,)
+ elif import_type == 'external':
+ url_or_filename = import_path
+ else:
+ raise ValueError('Wrong import_type')
+
+ local_import_path = cached_path(
+ url_or_filename,
+ download_config=download_config,
+ )
+ if sub_directory is not None:
+ local_import_path = os.path.join(local_import_path, sub_directory)
+ local_imports.append((import_name, local_import_path))
+
+ # Check library imports
+ needs_to_be_installed = {}
+ for library_import_name, library_import_path in library_imports:
+ try:
+ lib = importlib.import_module(library_import_name) # noqa F841
+ except ImportError:
+ if library_import_name not in needs_to_be_installed or library_import_path != library_import_name:
+ needs_to_be_installed[library_import_name] = library_import_path
+ if needs_to_be_installed:
+ _dependencies_str = 'dependencies' if len(needs_to_be_installed) > 1 else 'dependency'
+ _them_str = 'them' if len(needs_to_be_installed) > 1 else 'it'
+ if 'sklearn' in needs_to_be_installed.keys():
+ needs_to_be_installed['sklearn'] = 'scikit-learn'
+ if 'Bio' in needs_to_be_installed.keys():
+ needs_to_be_installed['Bio'] = 'biopython'
+ raise ImportError(
+ f'To be able to use {name}, you need to install the following {_dependencies_str}: '
+ f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install "
+ f"{' '.join(needs_to_be_installed.values())}' for instance."
+ )
+ return local_imports
+
+
+def get_module_with_script(self) -> DatasetModule:
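+    """ModelScope replacement for ``HubDatasetModuleFactoryWithScript.get_module``:
+    downloads the dataset script, README and any locally imported modules from the
+    ModelScope Hub, then imports the script as a dynamic `datasets` module."""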
+ if config.HF_DATASETS_TRUST_REMOTE_CODE and self.trust_remote_code is None:
+ warnings.warn(
+ f'The repository for {self.name} contains custom code which must be executed to correctly '
+ f'load the dataset. You can inspect the repository content at https://hf.co/datasets/{self.name}\n'
+ f'You can avoid this message in future by passing the argument `trust_remote_code=True`.\n'
+ f'Passing `trust_remote_code=True` will be mandatory '
+ f'to load this dataset from the next major release of `datasets`.',
+ FutureWarning,
+ )
+ # get script and other files
+ # local_path = self.download_loading_script()
+ # dataset_infos_path = self.download_dataset_infos_file()
+ # dataset_readme_path = self.download_dataset_readme_file()
+
+ _api = HubApi()
+ _dataset_name: str = self.name.split('/')[-1]
+ _namespace: str = self.name.split('/')[0]
+
+ script_file_name = f'{_dataset_name}.py'
+ script_url: str = _api.get_dataset_file_url(
+ file_name=script_file_name,
+ dataset_name=_dataset_name,
+ namespace=_namespace,
+ revision=self.revision,
+ extension_filter=False,
+ )
+ local_script_path = cached_path(
+ url_or_filename=script_url, download_config=self.download_config)
+
+ dataset_infos_path = None
+    # Note: `dataset_infos.json` is deprecated and is intentionally not fetched here.
+
+ dataset_readme_url: str = _api.get_dataset_file_url(
+ file_name='README.md',
+ dataset_name=_dataset_name,
+ namespace=_namespace,
+ revision=self.revision,
+ extension_filter=False,
+ )
+ dataset_readme_path = cached_path(
+ url_or_filename=dataset_readme_url, download_config=self.download_config)
+
+ imports = get_imports(local_script_path)
+ local_imports = _download_additional_modules(
+ name=self.name,
+ dataset_name=_dataset_name,
+ namespace=_namespace,
+ revision=self.revision,
+ imports=imports,
+ download_config=self.download_config,
+ )
+ additional_files = []
+ if dataset_infos_path:
+ additional_files.append((config.DATASETDICT_INFOS_FILENAME, dataset_infos_path))
+ if dataset_readme_path:
+ additional_files.append((config.REPOCARD_FILENAME, dataset_readme_path))
+ # copy the script and the files in an importable directory
+ dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
+ hash = files_to_hash([local_script_path] + [loc[1] for loc in local_imports])
+ importable_file_path = _get_importable_file_path(
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash,
+ name=self.name,
+ )
+ if not os.path.exists(importable_file_path):
+ trust_remote_code = resolve_trust_remote_code(self.trust_remote_code, self.name)
+ if trust_remote_code:
+ _create_importable_file(
+ local_path=local_script_path,
+ local_imports=local_imports,
+ additional_files=additional_files,
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash,
+ name=self.name,
+ download_mode=self.download_mode,
+ )
+ else:
+ raise ValueError(
+ f'Loading {self.name} requires you to execute the dataset script in that'
+ ' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
+ ' set the option `trust_remote_code=True` to remove this error.'
+ )
+ module_path, hash = _load_importable_file(
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash,
+ name=self.name,
+ )
+ # make the new module to be noticed by the import system
+ importlib.invalidate_caches()
+ builder_kwargs = {
+ # "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
+ 'base_path': _api.get_file_base_path(namespace=_namespace, dataset_name=_dataset_name),
+ 'repo_id': self.name,
+ }
+ return DatasetModule(module_path, hash, builder_kwargs)
+
+
+HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
+
+
+class DatasetsWrapperHF:
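+    """Thin re-implementation of ``datasets.load_dataset`` and its builder/module
+    factories with ModelScope-specific behaviour: the ``MS_DATASETS_CACHE`` cache
+    directory, download-statistics reporting, and a ``dataset_info_only`` preview
+    mode that returns ``{config_name: [split, ...]}`` without downloading data."""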
+
+ @staticmethod
+ def load_dataset(
+ path: str,
+ name: Optional[str] = None,
+ data_dir: Optional[str] = None,
+ data_files: Optional[Union[str, Sequence[str],
+ Mapping[str, Union[str,
+ Sequence[str]]]]] = None,
+ split: Optional[Union[str, Split]] = None,
+ cache_dir: Optional[str] = None,
+ features: Optional[Features] = None,
+ download_config: Optional[DownloadConfig] = None,
+ download_mode: Optional[Union[DownloadMode, str]] = None,
+ verification_mode: Optional[Union[VerificationMode, str]] = None,
+ ignore_verifications='deprecated',
+ keep_in_memory: Optional[bool] = None,
+ save_infos: bool = False,
+ revision: Optional[Union[str, Version]] = None,
+ token: Optional[Union[bool, str]] = None,
+ use_auth_token='deprecated',
+ task='deprecated',
+ streaming: bool = False,
+ num_proc: Optional[int] = None,
+ storage_options: Optional[Dict] = None,
+ trust_remote_code: bool = None,
+ dataset_info_only: Optional[bool] = False,
+ **config_kwargs,
+ ) -> Union[DatasetDict, Dataset, IterableDatasetDict, IterableDataset,
+ dict]:
+
+ if use_auth_token != 'deprecated':
+ warnings.warn(
+ "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
+                "You can remove this warning by passing 'token=<use_auth_token>' instead.",
+ FutureWarning,
+ )
+ token = use_auth_token
+ if ignore_verifications != 'deprecated':
+ verification_mode = VerificationMode.NO_CHECKS if ignore_verifications else VerificationMode.ALL_CHECKS
+ warnings.warn(
+ "'ignore_verifications' was deprecated in favor of 'verification_mode' "
+ 'in version 2.9.1 and will be removed in 3.0.0.\n'
+ f"You can remove this warning by passing 'verification_mode={verification_mode.value}' instead.",
+ FutureWarning,
+ )
+ if task != 'deprecated':
+ warnings.warn(
+ "'task' was deprecated in version 2.13.0 and will be removed in 3.0.0.\n",
+ FutureWarning,
+ )
+ else:
+ task = None
+ if data_files is not None and not data_files:
+ raise ValueError(
+ f"Empty 'data_files': '{data_files}'. It should be either non-empty or None (default)."
+ )
+        if Path(path, config.DATASET_STATE_JSON_FILENAME).exists():
+ raise ValueError(
+ 'You are trying to load a dataset that was saved using `save_to_disk`. '
+ 'Please use `load_from_disk` instead.')
+
+ if streaming and num_proc is not None:
+ raise NotImplementedError(
+ 'Loading a streaming dataset in parallel with `num_proc` is not implemented. '
+ 'To parallelize streaming, you can wrap the dataset with a PyTorch DataLoader '
+ 'using `num_workers` > 1 instead.')
+
+ download_mode = DownloadMode(download_mode
+ or DownloadMode.REUSE_DATASET_IF_EXISTS)
+ verification_mode = VerificationMode((
+ verification_mode or VerificationMode.BASIC_CHECKS
+ ) if not save_infos else VerificationMode.ALL_CHECKS)
+
+ # Create a dataset builder
+ builder_instance = DatasetsWrapperHF.load_dataset_builder(
+ path=path,
+ name=name,
+ data_dir=data_dir,
+ data_files=data_files,
+ cache_dir=cache_dir,
+ features=features,
+ download_config=download_config,
+ download_mode=download_mode,
+ revision=revision,
+ token=token,
+ storage_options=storage_options,
+ trust_remote_code=trust_remote_code,
+ _require_default_config_name=name is None,
+ **config_kwargs,
+ )
+
+ # Note: Only for preview mode
+ if dataset_info_only:
+ ret_dict = {}
+ # Get dataset config info from python script
+ if isinstance(path, str) and path.endswith('.py') and os.path.exists(path):
+ from datasets import get_dataset_config_names
+ subset_list = get_dataset_config_names(path)
+ ret_dict = {_subset: [] for _subset in subset_list}
+ return ret_dict
+
+ if builder_instance is None or not hasattr(builder_instance,
+ 'builder_configs'):
+ logger.error(f'No builder_configs found for {path} dataset.')
+ return ret_dict
+
+ _tmp_builder_configs = builder_instance.builder_configs
+ for tmp_config_name, tmp_builder_config in _tmp_builder_configs.items():
+ tmp_config_name = str(tmp_config_name)
+ if hasattr(tmp_builder_config, 'data_files') and tmp_builder_config.data_files is not None:
+ ret_dict[tmp_config_name] = [str(item) for item in list(tmp_builder_config.data_files.keys())]
+ else:
+ ret_dict[tmp_config_name] = []
+ return ret_dict
+
+ # Return iterable dataset in case of streaming
+ if streaming:
+ return builder_instance.as_streaming_dataset(split=split)
+
+ # Some datasets are already processed on the HF google storage
+ # Don't try downloading from Google storage for the packaged datasets as text, json, csv or pandas
+ # try_from_hf_gcs = path not in _PACKAGED_DATASETS_MODULES
+
+ # Download and prepare data
+ builder_instance.download_and_prepare(
+ download_config=download_config,
+ download_mode=download_mode,
+ verification_mode=verification_mode,
+ try_from_hf_gcs=False,
+ num_proc=num_proc,
+ storage_options=storage_options,
+ # base_path=builder_instance.base_path,
+ # file_format=builder_instance.name or 'arrow',
+ )
+
+ # Build dataset for splits
+ keep_in_memory = (
+ keep_in_memory if keep_in_memory is not None else is_small_dataset(
+ builder_instance.info.dataset_size))
+ ds = builder_instance.as_dataset(
+ split=split,
+ verification_mode=verification_mode,
+ in_memory=keep_in_memory)
+ # Rename and cast features to match task schema
+ if task is not None:
+ # To avoid issuing the same warning twice
+ with warnings.catch_warnings():
+ warnings.simplefilter('ignore', FutureWarning)
+ ds = ds.prepare_for_task(task)
+ if save_infos:
+ builder_instance._save_infos()
+
+ try:
+ _api = HubApi()
+ if is_relative_path(path) and path.count('/') == 1:
+ _namespace, _dataset_name = path.split('/')
+ _api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace)
+ except Exception as e:
+ logger.warning(f'Could not record download statistics: {e}')
+
+ return ds
+
+ @staticmethod
+ def load_dataset_builder(
+ path: str,
+ name: Optional[str] = None,
+ data_dir: Optional[str] = None,
+ data_files: Optional[Union[str, Sequence[str],
+ Mapping[str, Union[str,
+ Sequence[str]]]]] = None,
+ cache_dir: Optional[str] = None,
+ features: Optional[Features] = None,
+ download_config: Optional[DownloadConfig] = None,
+ download_mode: Optional[Union[DownloadMode, str]] = None,
+ revision: Optional[Union[str, Version]] = None,
+ token: Optional[Union[bool, str]] = None,
+ use_auth_token='deprecated',
+ storage_options: Optional[Dict] = None,
+ trust_remote_code: Optional[bool] = None,
+ _require_default_config_name=True,
+ **config_kwargs,
+ ) -> DatasetBuilder:
+
+ if use_auth_token != 'deprecated':
+ warnings.warn(
+ "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
+                "You can remove this warning by passing 'token=<use_auth_token>' instead.",
+ FutureWarning,
+ )
+ token = use_auth_token
+ download_mode = DownloadMode(download_mode
+ or DownloadMode.REUSE_DATASET_IF_EXISTS)
+ if token is not None:
+ download_config = download_config.copy(
+ ) if download_config else DownloadConfig()
+ download_config.token = token
+ if storage_options is not None:
+ download_config = download_config.copy(
+ ) if download_config else DownloadConfig()
+ download_config.storage_options.update(storage_options)
+
+ dataset_module = DatasetsWrapperHF.dataset_module_factory(
+ path,
+ revision=revision,
+ download_config=download_config,
+ download_mode=download_mode,
+ data_dir=data_dir,
+ data_files=data_files,
+ cache_dir=cache_dir,
+ trust_remote_code=trust_remote_code,
+ _require_default_config_name=_require_default_config_name,
+ _require_custom_configs=bool(config_kwargs),
+ name=name,
+ )
+ # Get dataset builder class from the processing script
+ builder_kwargs = dataset_module.builder_kwargs
+ data_dir = builder_kwargs.pop('data_dir', data_dir)
+ data_files = builder_kwargs.pop('data_files', data_files)
+ config_name = builder_kwargs.pop(
+ 'config_name', name
+ or dataset_module.builder_configs_parameters.default_config_name)
+ dataset_name = builder_kwargs.pop('dataset_name', None)
+ info = dataset_module.dataset_infos.get(
+ config_name) if dataset_module.dataset_infos else None
+
+ if (path in _PACKAGED_DATASETS_MODULES and data_files is None
+ and dataset_module.builder_configs_parameters.
+ builder_configs[0].data_files is None):
+ error_msg = f'Please specify the data files or data directory to load for the {path} dataset builder.'
+ example_extensions = [
+ extension for extension in _EXTENSION_TO_MODULE
+ if _EXTENSION_TO_MODULE[extension] == path
+ ]
+ if example_extensions:
+ error_msg += f'\nFor example `data_files={{"train": "path/to/data/train/*.{example_extensions[0]}"}}`'
+ raise ValueError(error_msg)
+
+ builder_cls = get_dataset_builder_class(
+ dataset_module, dataset_name=dataset_name)
+
+ builder_instance: DatasetBuilder = builder_cls(
+ cache_dir=cache_dir,
+ dataset_name=dataset_name,
+ config_name=config_name,
+ data_dir=data_dir,
+ data_files=data_files,
+ hash=dataset_module.hash,
+ info=info,
+ features=features,
+ token=token,
+ storage_options=storage_options,
+ **builder_kwargs, # contains base_path
+ **config_kwargs,
+ )
+ builder_instance._use_legacy_cache_dir_if_possible(dataset_module)
+
+ return builder_instance
+
+ @staticmethod
+ def dataset_module_factory(
+ path: str,
+ revision: Optional[Union[str, Version]] = None,
+ download_config: Optional[DownloadConfig] = None,
+ download_mode: Optional[Union[DownloadMode, str]] = None,
+ dynamic_modules_path: Optional[str] = None,
+ data_dir: Optional[str] = None,
+ data_files: Optional[Union[Dict, List, str, DataFilesDict]] = None,
+ cache_dir: Optional[str] = None,
+ trust_remote_code: Optional[bool] = None,
+ _require_default_config_name=True,
+ _require_custom_configs=False,
+ **download_kwargs,
+ ) -> DatasetModule:
+
+ subset_name: str = download_kwargs.pop('name', None)
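+        # The subset/config name is stashed in ``download_config.storage_options``
+        # below so that ``get_module_without_script`` can resolve data files for
+        # that config only.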
+ if download_config is None:
+ download_config = DownloadConfig(**download_kwargs)
+ download_config.storage_options.update({'name': subset_name})
+
+ if download_config and download_config.cache_dir is None:
+ download_config.cache_dir = MS_DATASETS_CACHE
+
+ download_mode = DownloadMode(download_mode
+ or DownloadMode.REUSE_DATASET_IF_EXISTS)
+ download_config.extract_compressed_file = True
+ download_config.force_extract = True
+ download_config.force_download = download_mode == DownloadMode.FORCE_REDOWNLOAD
+
+ filename = list(
+ filter(lambda x: x,
+ path.replace(os.sep, '/').split('/')))[-1]
+ if not filename.endswith('.py'):
+ filename = filename + '.py'
+ combined_path = os.path.join(path, filename)
+
+ # We have several ways to get a dataset builder:
+ #
+ # - if path is the name of a packaged dataset module
+ # -> use the packaged module (json, csv, etc.)
+ #
+ # - if os.path.join(path, name) is a local python file
+ # -> use the module from the python file
+ # - if path is a local directory (but no python file)
+ # -> use a packaged module (csv, text etc.) based on content of the directory
+ #
+ # - if path has one "/" and is dataset repository on the HF hub with a python file
+ # -> the module from the python file in the dataset repository
+ # - if path has one "/" and is dataset repository on the HF hub without a python file
+ # -> use a packaged module (csv, text etc.) based on content of the repository
+
+ # Try packaged
+ if path in _PACKAGED_DATASETS_MODULES:
+ return PackagedDatasetModuleFactory(
+ path,
+ data_dir=data_dir,
+ data_files=data_files,
+ download_config=download_config,
+ download_mode=download_mode,
+ ).get_module()
+ # Try locally
+ elif path.endswith(filename):
+ if os.path.isfile(path):
+ return LocalDatasetModuleFactoryWithScript(
+ path,
+ download_mode=download_mode,
+ dynamic_modules_path=dynamic_modules_path,
+ trust_remote_code=trust_remote_code,
+ ).get_module()
+ else:
+ raise FileNotFoundError(
+ f"Couldn't find a dataset script at {relative_to_absolute_path(path)}"
+ )
+ elif os.path.isfile(combined_path):
+ return LocalDatasetModuleFactoryWithScript(
+ combined_path,
+ download_mode=download_mode,
+ dynamic_modules_path=dynamic_modules_path,
+ trust_remote_code=trust_remote_code,
+ ).get_module()
+ elif os.path.isdir(path):
+ return LocalDatasetModuleFactoryWithoutScript(
+ path,
+ data_dir=data_dir,
+ data_files=data_files,
+ download_mode=download_mode).get_module()
+ # Try remotely
+ elif is_relative_path(path) and path.count('/') <= 1:
+ try:
+ _raise_if_offline_mode_is_enabled()
+
+ try:
+ dataset_info = HfApi().dataset_info(
+ repo_id=path,
+ revision=revision,
+ token=download_config.token,
+ timeout=100.0,
+ )
+ except Exception as e: # noqa catch any exception of hf_hub and consider that the dataset doesn't exist
+ if isinstance(
+ e,
+ ( # noqa: E131
+ OfflineModeIsEnabled, # noqa: E131
+ requests.exceptions.
+ ConnectTimeout, # noqa: E131, E261
+ requests.exceptions.ConnectionError, # noqa: E131
+ ), # noqa: E131
+ ):
+ raise ConnectionError(
+ f"Couldn't reach '{path}' on the Hub ({type(e).__name__})"
+ )
+ elif '404' in str(e):
+ msg = f"Dataset '{path}' doesn't exist on the Hub"
+ raise DatasetNotFoundError(
+ msg
+ + f" at revision '{revision}'" if revision else msg
+ )
+ elif '401' in str(e):
+ msg = f"Dataset '{path}' doesn't exist on the Hub"
+ msg = msg + f" at revision '{revision}'" if revision else msg
+ raise DatasetNotFoundError(
+ msg + '. If the repo is private or gated, '
+ 'make sure to log in with `huggingface-cli login`.'
+ )
+ else:
+ raise e
+ if filename in [
+ sibling.rfilename for sibling in dataset_info.siblings
+ ]: # contains a dataset script
+
+ # TODO: detecting whether the config can be loaded from the parquet export
+ # (as upstream `datasets` does by reading the script through HfFileSystem and
+ # checking for DEFAULT_CONFIG_NAME) is not supported here yet, so we always
+ # fall back to the script-based factory below.
+ can_load_config_from_parquet_export = False
+ if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
+ # If the parquet export is ready (parquet files + info available for the current sha),
+ # we can use it instead
+ # This fails when the dataset has multiple configs and a default config and
+ # the user didn't specify a configuration name (_require_default_config_name=True).
+ try:
+ return HubDatasetModuleFactoryWithParquetExport(
+ path,
+ download_config=download_config,
+ revision=dataset_info.sha).get_module()
+ except _datasets_server.DatasetsServerError:
+ pass
+ # Otherwise we must use the dataset script if the user trusts it
+ return HubDatasetModuleFactoryWithScript(
+ path,
+ revision=revision,
+ download_config=download_config,
+ download_mode=download_mode,
+ dynamic_modules_path=dynamic_modules_path,
+ trust_remote_code=trust_remote_code,
+ ).get_module()
+ else:
+ return HubDatasetModuleFactoryWithoutScript(
+ path,
+ revision=revision,
+ data_dir=data_dir,
+ data_files=data_files,
+ download_config=download_config,
+ download_mode=download_mode,
+ ).get_module()
+ except Exception as e1:
+ # All the attempts failed, before raising the error we should check if the module is already cached
+ try:
+ return CachedDatasetModuleFactory(
+ path,
+ dynamic_modules_path=dynamic_modules_path,
+ cache_dir=cache_dir).get_module()
+ except Exception:
+ # If it's not in the cache, then it doesn't exist.
+ if isinstance(e1, OfflineModeIsEnabled):
+ raise ConnectionError(
+ f"Couldn't reach the Hugging Face Hub for dataset '{path}': {e1}"
+ ) from None
+ if isinstance(e1,
+ (DataFilesNotFoundError,
+ DatasetNotFoundError, EmptyDatasetError)):
+ raise e1 from None
+ if isinstance(e1, FileNotFoundError):
+ raise FileNotFoundError(
+ f"Couldn't find a dataset script at {relative_to_absolute_path(combined_path)} or "
+ f'any data file in the same directory. '
+ f"Couldn't find '{path}' on the Hugging Face Hub either: {type(e1).__name__}: {e1}"
+ ) from None
+ raise e1 from None
+ else:
+ raise FileNotFoundError(
+ f"Couldn't find a dataset script at {relative_to_absolute_path(combined_path)} or "
+ f'any data file in the same directory.')
+
+
+load_dataset = DatasetsWrapperHF.load_dataset
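+
+# A minimal usage sketch (the repo id and subset name below are illustrative only;
+# the wrapper is expected to mirror the `datasets.load_dataset` interface):
+#
+#   ds = load_dataset('namespace/dataset_name', name='subset_name', split='train')
+#   print(next(iter(ds)))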
diff --git a/modelscope/msdatasets/utils/hf_file_utils.py b/modelscope/msdatasets/utils/hf_file_utils.py
new file mode 100644
index 00000000..fea2506a
--- /dev/null
+++ b/modelscope/msdatasets/utils/hf_file_utils.py
@@ -0,0 +1,237 @@
+# noqa: isort:skip_file, yapf: disable
+# Copyright (c) Alibaba, Inc. and its affiliates.
+# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
+
+import json
+import os
+import re
+import shutil
+import warnings
+from contextlib import contextmanager
+from functools import partial
+from pathlib import Path
+from urllib.parse import urljoin, urlparse
+import requests
+
+from datasets import config
+from datasets.utils.file_utils import (
+    hash_url_to_filename, get_authentication_headers_for_url, ftp_head,
+    fsspec_head, http_head, _raise_if_offline_mode_is_enabled, ftp_get,
+    fsspec_get, http_get)
+from filelock import FileLock
+
+from modelscope.utils.config_ds import MS_DATASETS_CACHE
+from modelscope.utils.logger import get_logger
+from modelscope.hub.api import HubApi, ModelScopeConfig
+
+logger = get_logger()
+
+
+def get_from_cache_ms(
+ url,
+ cache_dir=None,
+ force_download=False,
+ proxies=None,
+ etag_timeout=100,
+ resume_download=False,
+ user_agent=None,
+ local_files_only=False,
+ use_etag=True,
+ max_retries=0,
+ token=None,
+ use_auth_token='deprecated',
+ ignore_url_params=False,
+ storage_options=None,
+ download_desc=None,
+) -> str:
+ """
+ Given a URL, look for the corresponding file in the local cache.
+ If it's not there, download it. Then return the path to the cached file.
+
+ Returns:
+ Local path (string).
+
+ Raises:
+ FileNotFoundError: in case of a non-recoverable file
+ (non-existent or no cache on disk).
+ ConnectionError: in case of an unreachable url
+ and no cache on disk.
+ """
+ if use_auth_token != 'deprecated':
+ warnings.warn(
+ "'use_auth_token' was deprecated in favor of 'token' in version 2.14.0 and will be removed in 3.0.0.\n"
+ f"You can remove this warning by passing 'token={use_auth_token}' instead.",
+ FutureWarning,
+ )
+ token = use_auth_token
+ if cache_dir is None:
+ cache_dir = MS_DATASETS_CACHE
+ if isinstance(cache_dir, Path):
+ cache_dir = str(cache_dir)
+
+ os.makedirs(cache_dir, exist_ok=True)
+
+ if ignore_url_params:
+ # strip all query parameters and #fragments from the URL
+ cached_url = urljoin(url, urlparse(url).path)
+ else:
+ cached_url = url # additional parameters may be added to the given URL
+
+ connected = False
+ response = None
+ cookies = None
+ etag = None
+ head_error = None
+ scheme = None
+
+ # Try a first time to find the file on the local file system without eTag (None)
+ # if we don't ask for 'force_download' then we spare a request
+ filename = hash_url_to_filename(cached_url, etag=None)
+ cache_path = os.path.join(cache_dir, filename)
+
+ if os.path.exists(cache_path) and not force_download and not use_etag:
+ return cache_path
+
+ # Prepare headers for authentication
+ headers = get_authentication_headers_for_url(url, token=token)
+ if user_agent is not None:
+ headers['user-agent'] = user_agent
+
+ # We don't have the file locally or we need an eTag
+ if not local_files_only:
+ scheme = urlparse(url).scheme
+ if scheme == 'ftp':
+ connected = ftp_head(url)
+ elif scheme not in ('http', 'https'):
+ response = fsspec_head(url, storage_options=storage_options)
+ # s3fs uses "ETag", gcsfs uses "etag"
+ etag = (response.get('ETag', None) or response.get('etag', None)) if use_etag else None
+ connected = True
+ try:
+ cookies = ModelScopeConfig.get_cookies()
+ response = http_head(
+ url,
+ allow_redirects=True,
+ proxies=proxies,
+ timeout=etag_timeout,
+ max_retries=max_retries,
+ headers=headers,
+ cookies=cookies,
+ )
+ if response.status_code == 200: # ok
+ etag = response.headers.get('ETag') if use_etag else None
+ for k, v in response.cookies.items():
+ # In some edge cases, we need to get a confirmation token
+ if k.startswith('download_warning') and 'drive.google.com' in url:
+ url += '&confirm=' + v
+ cookies = response.cookies
+ connected = True
+ # Fix Google Drive URL to avoid Virus scan warning
+ if 'drive.google.com' in url and 'confirm=' not in url:
+ url += '&confirm=t'
+ # In some edge cases, head request returns 400 but the connection is actually ok
+ elif (
+ (response.status_code == 400 and 'firebasestorage.googleapis.com' in url)
+ or (response.status_code == 405 and 'drive.google.com' in url)
+ or (
+ response.status_code == 403
+ and (
+ re.match(r'^https?://github.com/.*?/.*?/releases/download/.*?/.*?$', url)
+ or re.match(r'^https://.*?s3.*?amazonaws.com/.*?$', response.url)
+ )
+ )
+ or (response.status_code == 403 and 'ndownloader.figstatic.com' in url)
+ ):
+ connected = True
+ logger.info(f"Couldn't get ETag version for url {url}")
+ elif response.status_code == 401 and config.HF_ENDPOINT in url and token is None:
+ raise ConnectionError(
+ f'Unauthorized for URL {url}. '
+ f'Please use the parameter `token=True` after logging in with `huggingface-cli login`'
+ )
+ except (OSError, requests.exceptions.Timeout) as e:
+ # not connected
+ head_error = e
+
+ # connected is False: we don't have a connection, the url doesn't exist, or it is otherwise inaccessible.
+ # try to get the last downloaded one
+ if not connected:
+ if os.path.exists(cache_path) and not force_download:
+ return cache_path
+ if local_files_only:
+ raise FileNotFoundError(
+ f'Cannot find the requested files in the cached path at {cache_path} and outgoing traffic has been'
+ " disabled. To enable file online look-ups, set 'local_files_only' to False."
+ )
+ elif response is not None and response.status_code == 404:
+ raise FileNotFoundError(f"Couldn't find file at {url}")
+ _raise_if_offline_mode_is_enabled(f'Tried to reach {url}')
+ if head_error is not None:
+ raise ConnectionError(f"Couldn't reach {url} ({repr(head_error)})")
+ elif response is not None:
+ raise ConnectionError(f"Couldn't reach {url} (error {response.status_code})")
+ else:
+ raise ConnectionError(f"Couldn't reach {url}")
+
+ # Try a second time
+ filename = hash_url_to_filename(cached_url, etag)
+ cache_path = os.path.join(cache_dir, filename)
+
+ if os.path.exists(cache_path) and not force_download:
+ return cache_path
+
+ # From now on, connected is True.
+ # Prevent parallel downloads of the same file with a lock.
+ lock_path = cache_path + '.lock'
+ with FileLock(lock_path):
+ # Check again: another process may have finished the download while we were waiting for the lock
+ if os.path.exists(cache_path) and not force_download:
+ return cache_path
+
+ incomplete_path = cache_path + '.incomplete'
+
+ @contextmanager
+ def temp_file_manager(mode='w+b'):
+ with open(incomplete_path, mode) as f:
+ yield f
+
+ resume_size = 0
+ if resume_download:
+ temp_file_manager = partial(temp_file_manager, mode='a+b')
+ if os.path.exists(incomplete_path):
+ resume_size = os.stat(incomplete_path).st_size
+
+ # Download to temporary file, then copy to cache path once finished.
+ # Otherwise, you get corrupt cache entries if the download gets interrupted.
+ with temp_file_manager() as temp_file:
+ logger.info(f'Downloading to {temp_file.name}')
+
+ # GET file object
+ if scheme == 'ftp':
+ ftp_get(url, temp_file)
+ elif scheme not in ('http', 'https'):
+ fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
+ else:
+ http_get(
+ url,
+ temp_file=temp_file,
+ proxies=proxies,
+ resume_size=resume_size,
+ headers=headers,
+ cookies=cookies,
+ max_retries=max_retries,
+ desc=download_desc,
+ )
+
+ logger.info(f'storing {url} in cache at {cache_path}')
+ shutil.move(temp_file.name, cache_path)
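+ # os.umask() sets a new mask and returns the previous one, so the two calls below
+ # read the current umask without leaving it modified, then apply it to the cached file.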
+ umask = os.umask(0o666)
+ os.umask(umask)
+ os.chmod(cache_path, 0o666 & ~umask)
+
+ logger.info(f'creating metadata file for {cache_path}')
+ meta = {'url': url, 'etag': etag}
+ meta_path = cache_path + '.json'
+ with open(meta_path, 'w', encoding='utf-8') as meta_file:
+ json.dump(meta, meta_file)
+
+ return cache_path
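+
+
+# A minimal usage sketch (the URL below is a placeholder, not a real endpoint;
+# actual callers pass the resolved dataset file URL and rely on the cookie-based
+# authentication handled above):
+#
+#   local_path = get_from_cache_ms(
+#       'https://example.com/path/to/data.csv',
+#       cache_dir=None,          # falls back to MS_DATASETS_CACHE
+#       force_download=False,    # reuse the cached copy when present
+#   )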
diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py
index 9921b826..62a8dbd7 100644
--- a/modelscope/utils/constant.py
+++ b/modelscope/utils/constant.py
@@ -393,9 +393,14 @@ class DatasetFormations(enum.Enum):
# formation that is compatible with official huggingface dataset, which
# organizes whole dataset into one single (zip) file.
hf_compatible = 1
+
# native modelscope formation that supports, among other things,
# multiple files in a dataset
native = 2
+
+ # general formation for datasets
+ general = 4
+
# for local meta cache mark
formation_mark_ext = '.formation_mark'
@@ -403,6 +408,7 @@ class DatasetFormations(enum.Enum):
DatasetMetaFormats = {
DatasetFormations.native: ['.json'],
DatasetFormations.hf_compatible: ['.py'],
+ DatasetFormations.general: ['.py'],
}
diff --git a/requirements/framework.txt b/requirements/framework.txt
index 8804fe8c..d4987429 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -4,6 +4,7 @@ datasets>=2.14.5
einops
filelock>=3.3.0
gast>=0.2.2
+huggingface_hub
numpy
oss2
pandas
diff --git a/tests/msdatasets/test_general_datasets.py b/tests/msdatasets/test_general_datasets.py
new file mode 100644
index 00000000..21ba3f2b
--- /dev/null
+++ b/tests/msdatasets/test_general_datasets.py
@@ -0,0 +1,103 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+import unittest
+
+from modelscope import MsDataset
+from modelscope.utils.logger import get_logger
+from modelscope.utils.test_utils import test_level
+
+logger = get_logger()
+
+# Note: these tests expect the MODELSCOPE_DOMAIN environment variable to be set to 'test.modelscope.cn'
+# TODO: test environment only; to be replaced by the online domain
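+# Illustrative shell setup for running against the test environment:
+#   export MODELSCOPE_DOMAIN=test.modelscope.cn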
+
+TEST_INNER_LEVEL = 1
+
+
+class GeneralMsDatasetTest(unittest.TestCase):
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_return_dataset_info_only(self):
+ ds = MsDataset.load(
+ 'wangxingjun778test/aya_dataset_mini', dataset_info_only=True)
+ print(f'>>output of test_return_dataset_info_only:\n {ds}')
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_fashion_mnist(self):
+ # 'inner' means the dataset is hosted in the test.modelscope.cn environment
+ ds = MsDataset.load(
+ 'xxxxtest0004/ms_test_0308_py',
+ subset_name='fashion_mnist',
+ split='train')
+ print(f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_clue(self):
+ ds = MsDataset.load(
+ 'wangxingjun778test/clue', subset_name='afqmc', split='train')
+ print(f'>>output of test_inner_clue:\n {next(iter(ds))}')
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_cats_and_dogs_mini(self):
+ ds = MsDataset.load(
+ 'wangxingjun778test/cats_and_dogs_mini', split='train')
+ print(f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_aya_dataset_mini(self):
+ # Dataset Format:
+ # data/train-xxx-of-xxx.parquet; data/test-xxx-of-xxx.parquet
+ # demographics/train-xxx-of-xxx.parquet
+
+ ds = MsDataset.load(
+ 'wangxingjun778test/aya_dataset_mini', split='train')
+ print(f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
+
+ ds = MsDataset.load(
+ 'wangxingjun778test/aya_dataset_mini', subset_name='demographics')
+ assert next(iter(ds['train']))
+ print(
+ f">>output of test_inner_aya_dataset_mini:\n {next(iter(ds['train']))}"
+ )
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_no_standard_imgs(self):
+ infos = MsDataset.load(
+ 'xxxxtest0004/png_jpg_txt_test', dataset_info_only=True)
+ assert infos['default']
+
+ ds = MsDataset.load('xxxxtest0004/png_jpg_txt_test', split='train')
+ print(f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
+ assert next(iter(ds))
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_hf_pictures(self):
+ ds = MsDataset.load('xxxxtest0004/hf_Pictures')
+ print(ds)
+ assert next(iter(ds))
+
+ @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
+ def test_inner_speech_yinpin(self):
+ ds = MsDataset.load('xxxxtest0004/hf_lj_speech_yinpin_test')
+ print(ds)
+ assert next(iter(ds))
+
+ @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+ 'skip test in current test level')
+ def test_inner_yuancheng_picture(self):
+ ds = MsDataset.load(
+ 'xxxxtest0004/yuancheng_picture',
+ subset_name='remote_images',
+ split='train')
+ print(next(iter(ds)))
+ assert next(iter(ds))
+
+
+if __name__ == '__main__':
+ unittest.main()