diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index a8565f16..5e40ec55 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -20,7 +20,8 @@ repos:
(?x)^(
examples/|
modelscope/utils/ast_index_file.py|
- modelscope/fileio/format/jsonplus.py
+ modelscope/fileio/format/jsonplus.py|
+ modelscope/msdatasets/utils/_module_factories\.py
)$
- repo: https://github.com/pre-commit/mirrors-yapf.git
rev: v0.30.0
@@ -31,7 +32,8 @@ repos:
thirdparty/|
examples/|
modelscope/utils/ast_index_file.py|
- modelscope/fileio/format/jsonplus.py
+ modelscope/fileio/format/jsonplus.py|
+ modelscope/msdatasets/utils/_module_factories\.py
)$
- repo: https://github.com/pre-commit/pre-commit-hooks.git
rev: v3.1.0
diff --git a/modelscope/msdatasets/utils/_compat.py b/modelscope/msdatasets/utils/_compat.py
new file mode 100644
index 00000000..05ec9e9b
--- /dev/null
+++ b/modelscope/msdatasets/utils/_compat.py
@@ -0,0 +1,288 @@
+# isort: skip_file
+# yapf: disable
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Compatibility shims for datasets>=4.0 script-based dataset loading.
+
+Script-based dataset loading was removed in datasets 4.0. This module
+provides minimal re-implementations of the necessary helpers so that
+ModelScope can still load datasets that ship a custom builder .py script.
+
+When running with datasets<4.0 the real implementations are simply
+re-exported from datasets.load / datasets.utils.py_utils.
+"""
+import importlib
+import os
+import sys
+from pathlib import Path
+from typing import List, Optional, Tuple
+
+from datasets import DownloadMode, config
+
+# ---------------------------------------------------------------------------
+# Try importing script-loading APIs from datasets<4.0
+# ---------------------------------------------------------------------------
+try:
+ from datasets.load import (
+ HubDatasetModuleFactoryWithScript,
+ LocalDatasetModuleFactoryWithScript,
+ resolve_trust_remote_code,
+ _get_importable_file_path,
+ _create_importable_file,
+ _load_importable_file,
+ init_dynamic_modules,
+ files_to_hash,
+ )
+ from datasets.utils.py_utils import get_imports
+
+ _HAS_SCRIPT_LOADING = True
+except ImportError:
+ _HAS_SCRIPT_LOADING = False
+ HubDatasetModuleFactoryWithScript = None # type: ignore[assignment,misc]
+ LocalDatasetModuleFactoryWithScript = None # type: ignore[assignment,misc]
+
+# ---------------------------------------------------------------------------
+# Compat implementations (only defined when datasets>=4.0)
+# ---------------------------------------------------------------------------
+if not _HAS_SCRIPT_LOADING:
+ import filecmp
+ import hashlib # noqa: F811 – only imported in this branch
+ import json as _json
+ import re
+ import shutil
+ from urllib.parse import urlparse
+
+ from datasets.packaged_modules import _hash_python_lines
+ from datasets.utils.file_utils import url_or_path_join
+ from datasets.utils.hub import hf_dataset_url # noqa: F401
+ from filelock import FileLock
+
+ def _compat_get_imports(
+ file_path: str) -> List[Tuple[str, str, str, Optional[str]]]:
+ """Parse a dataset script for import statements (ported from datasets<4.0).
+
+ The file is read as UTF-8 and scanned line by line. Every recognised
+ ``import X`` / ``from X import ...`` line contributes one 4-tuple
+ ``(kind, name, path_or_url, sub_directory)``, where ``kind`` is one of
+ ``'library'``, ``'internal'`` or ``'external'``.
+
+ Args:
+ file_path: Path of the Python script to scan.
+
+ Returns:
+ The collected import descriptions, in the order they were found.
+ """
+ with open(file_path, encoding='utf-8') as f:
+ lines = f.readlines()
+
+ imports: List[Tuple[str, str, str, Optional[str]]] = []
+ is_in_docstring = False
+ for line in lines:
+ # The pattern yields one match per `"""` on the line, so a line holding
+ # exactly one `"""` opens or closes a docstring. Only the `"""`
+ # delimiter is recognised here.
+ docstr_start_match = re.findall(r'[\s\S]*?"""[\s\S]*?', line)
+ if len(docstr_start_match) == 1:
+ is_in_docstring = not is_in_docstring
+ # Imports written inside a docstring are not collected.
+ if is_in_docstring:
+ continue
+ # Groups: (1) a leading '.' marking a relative import, (2) the first
+ # component of the module name, (3) the text left after the statement,
+ # past an optional `# From: ` marker.
+ # NOTE(review): as the pattern is written, group 3 also captures a
+ # trailing comment that is not a `# From:` marker -- confirm whether
+ # callers rely on that.
+ match = re.match(
+ r'^import\s+(\.?)([^\s\.]+)[^#\r\n]*(?:#\s+From:\s+)?([^\r\n]*)',
+ line,
+ flags=re.MULTILINE)
+ if match is None:
+ # Not an `import X` line; try the `from X import ...` form.
+ match = re.match(
+ r'^from\s+(\.?)([^\s\.]+)(?:[^\s]*)\s+import\s+[^#\r\n]*(?:#\s+From:\s+)?([^\r\n]*)',
+ line,
+ flags=re.MULTILINE)
+ if match is None:
+ continue
+ # Relative import: recorded once per module name.
+ if match.group(1):
+ if any(imp[1] == match.group(2) for imp in imports):
+ continue
+ # With trailing text the module is fetched from that location
+ # ('external'); otherwise it is looked up by its own name ('internal').
+ if match.group(3):
+ url_path = match.group(3)
+ url_path, sub_directory = _compat_convert_github_url(
+ url_path)
+ imports.append(
+ ('external', match.group(2), url_path, sub_directory))
+ elif match.group(2):
+ imports.append(
+ ('internal', match.group(2), match.group(2), None))
+ else:
+ # Absolute import: a library, stored with the trailing text as its
+ # path when present, else with the module name itself.
+ if match.group(3):
+ imports.append(
+ ('library', match.group(2), match.group(3), None))
+ else:
+ imports.append(
+ ('library', match.group(2), match.group(2), None))
+ return imports
+
+ def _compat_convert_github_url(url_path: str) -> Tuple[str, Optional[str]]:
+     """Turn a GitHub web URL into a URL that can be downloaded directly.
+
+     URLs that are not ``http``/``https``/``s3`` URLs on ``github.com`` are
+     returned unchanged.
+
+     Args:
+         url_path: URL found in a ``# From:`` import comment.
+
+     Returns:
+         ``(url, sub_directory)``. For a ``.../<owner>/<repo>/blob/...`` file
+         URL the ``blob`` segment becomes ``raw`` and ``sub_directory`` is
+         ``None``. For any other GitHub URL the result is the archive zip of
+         the branch named after ``/tree/`` (``master`` when absent) together
+         with the ``<repo>-<branch>`` directory name.
+
+     Raises:
+         ValueError: If a ``blob`` URL does not end in ``.py``, or the
+             repository part of the path is not exactly ``<owner>/<repo>``.
+     """
+     parsed = urlparse(url_path)
+     if parsed.scheme not in ('http', 'https',
+                              's3') or parsed.netloc != 'github.com':
+         return url_path, None
+
+     # Path layout of a file URL: ['', owner, repo, 'blob', ref, ...].
+     # Only that exact segment is inspected, so an owner or repository whose
+     # name merely contains "blob" is neither misdetected nor rewritten.
+     segments = parsed.path.split('/')
+     if len(segments) > 3 and segments[3] == 'blob':
+         if not url_path.endswith('.py'):
+             raise ValueError(
+                 f'External import from github at {url_path} should point to a .py file'
+             )
+         segments[3] = 'raw'
+         return parsed._replace(path='/'.join(segments)).geturl(), None
+
+     github_path = parsed.path[1:]
+     if '/tree/' in github_path:
+         repo_info, branch = github_path.split('/tree/')
+     else:
+         repo_info, branch = github_path, 'master'
+     repo_owner, repo_name = repo_info.split('/')
+     archive_url = f'https://github.com/{repo_owner}/{repo_name}/archive/{branch}.zip'
+     return archive_url, f'{repo_name}-{branch}'
+
+ # -- dynamic module management ----------------------------------------
+
+ def _compat_init_dynamic_modules(
+ name: str = config.MODULE_NAME_FOR_DYNAMIC_MODULES,
+ hf_modules_cache=None,
+ ) -> str:
+ """Prepare the package directory that holds dynamically created modules.
+
+ Args:
+ name: Name of the sub-package created inside the modules cache.
+ hf_modules_cache: Cache root; ``config.HF_MODULES_CACHE`` is used when
+ this is falsy.
+
+ Returns:
+ Path of the ``name`` directory inside the modules cache.
+ """
+ hf_modules_cache = str(hf_modules_cache or config.HF_MODULES_CACHE)
+ # The cache root is appended to sys.path when it is not already listed.
+ if hf_modules_cache not in sys.path:
+ sys.path.append(hf_modules_cache)
+ os.makedirs(hf_modules_cache, exist_ok=True)
+ # An empty __init__.py is written when the file is missing.
+ init_path = os.path.join(hf_modules_cache, '__init__.py')
+ if not os.path.exists(init_path):
+ with open(init_path, 'w'):
+ pass
+ # Drop the import system's cached directory listings.
+ importlib.invalidate_caches()
+ # Same treatment for the sub-package: directory plus empty __init__.py.
+ dynamic_modules_path = os.path.join(hf_modules_cache, name)
+ os.makedirs(dynamic_modules_path, exist_ok=True)
+ init_path2 = os.path.join(dynamic_modules_path, '__init__.py')
+ if not os.path.exists(init_path2):
+ with open(init_path2, 'w'):
+ pass
+ return dynamic_modules_path
+
+ def _compat_files_to_hash(file_paths) -> str:
+     """Hash the text of the given files.
+
+     A directory entry stands for every file below it whose extension is
+     ``.py`` in any letter case; any other entry is read as a file itself.
+     All lines are gathered, in that order, and handed to
+     ``_hash_python_lines``.
+     """
+     sources = []
+     for entry in file_paths:
+         if not os.path.isdir(entry):
+             sources.append(entry)
+             continue
+         sources.extend(Path(entry).rglob('*.[pP][yY]'))
+     collected = []
+     for source in sources:
+         with open(source, encoding='utf-8') as handle:
+             collected += handle.readlines()
+     return _hash_python_lines(collected)
+
+ # -- importable file management ---------------------------------------
+
+ def _compat_get_importable_file_path(
+     dynamic_modules_path: str,
+     module_namespace: str,
+     subdirectory_name: str,
+     name: str,
+ ) -> str:
+     """Return the path at which the importable copy of a script is stored.
+
+     The layout is ``<dynamic_modules_path>/<module_namespace>/<name with
+     '/' replaced by '--'>/<subdirectory_name>/<last part of name>.py``.
+     """
+     package_dir_name = name.replace('/', '--')
+     script_file_name = name.split('/')[-1] + '.py'
+     return os.path.join(dynamic_modules_path, module_namespace,
+                         package_dir_name, subdirectory_name,
+                         script_file_name)
+
+ def _compat_copy_script_and_resources(
+ name: str,
+ importable_directory_path: str,
+ subdirectory_name: str,
+ original_local_path: str,
+ local_imports: List[Tuple[str, str]],
+ additional_files: List[Tuple[str, str]],
+ download_mode,
+ ) -> str:
+ """Copy a dataset script and the files it needs into an importable package.
+
+ Args:
+ name: Module name; the script is stored as ``<name>.py``.
+ importable_directory_path: Package directory for this dataset.
+ subdirectory_name: Sub-package created inside that directory.
+ original_local_path: Location of the script to copy.
+ local_imports: ``(module name, path)`` pairs; each path is a file or a
+ directory.
+ additional_files: ``(file name, path)`` pairs copied next to the script.
+ download_mode: Compared against ``DownloadMode.FORCE_REDOWNLOAD``.
+
+ Returns:
+ Path of the copied script.
+
+ Raises:
+ ImportError: If a local import path is neither a file nor a directory.
+ """
+ importable_subdirectory = os.path.join(importable_directory_path,
+ subdirectory_name)
+ importable_file = os.path.join(importable_subdirectory, name + '.py')
+ # The lock file sits beside the package directory, not inside it.
+ lock_path = importable_directory_path + '.lock'
+ with FileLock(lock_path):
+ # A forced re-download removes the existing package directory.
+ if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(
+ importable_directory_path):
+ shutil.rmtree(importable_directory_path)
+ # Package directory and sub-package each get an empty __init__.py when
+ # that file is missing.
+ os.makedirs(importable_directory_path, exist_ok=True)
+ init_fp = os.path.join(importable_directory_path, '__init__.py')
+ if not os.path.exists(init_fp):
+ with open(init_fp, 'w'):
+ pass
+ os.makedirs(importable_subdirectory, exist_ok=True)
+ init_fp2 = os.path.join(importable_subdirectory, '__init__.py')
+ if not os.path.exists(init_fp2):
+ with open(init_fp2, 'w'):
+ pass
+ # An already present script copy is left untouched.
+ if not os.path.exists(importable_file):
+ shutil.copyfile(original_local_path, importable_file)
+ # A JSON file next to the script records where the copy came from.
+ meta_path = os.path.splitext(importable_file)[0] + '.json'
+ if not os.path.exists(meta_path):
+ meta = {
+ 'original file path': original_local_path,
+ 'local file path': importable_file
+ }
+ with open(meta_path, 'w', encoding='utf-8') as mf:
+ _json.dump(meta, mf)
+ # Local imports: a file becomes `<module name>.py`, a directory is copied
+ # as a whole; existing destinations are kept.
+ for imp_name, imp_path in local_imports:
+ if os.path.isfile(imp_path):
+ dest = os.path.join(importable_subdirectory,
+ imp_name + '.py')
+ if not os.path.exists(dest):
+ shutil.copyfile(imp_path, dest)
+ elif os.path.isdir(imp_path):
+ dest = os.path.join(importable_subdirectory, imp_name)
+ if not os.path.exists(dest):
+ shutil.copytree(imp_path, dest)
+ else:
+ raise ImportError(f'Error with local import at {imp_path}')
+ # Additional files are copied when missing or when filecmp reports that
+ # the existing copy differs from the source.
+ for file_name, original_path in additional_files:
+ dest_path = os.path.join(importable_subdirectory, file_name)
+ if not os.path.exists(dest_path) or not filecmp.cmp(
+ original_path, dest_path):
+ shutil.copyfile(original_path, dest_path)
+ return importable_file
+
+ def _compat_create_importable_file(
+     local_path: str,
+     local_imports: List[Tuple[str, str]],
+     additional_files: List[Tuple[str, str]],
+     dynamic_modules_path: str,
+     module_namespace: str,
+     subdirectory_name: str,
+     name: str,
+     download_mode,
+ ) -> None:
+     """Create the importable package for a script and fill it.
+
+     The package directory is ``<dynamic_modules_path>/<module_namespace>/
+     <name with '/' replaced by '--'>``. It is created together with an
+     ``__init__.py`` in its parent directory, after which the copying is
+     delegated to ``_compat_copy_script_and_resources``.
+     """
+     package_dir = os.path.join(dynamic_modules_path, module_namespace,
+                                name.replace('/', '--'))
+     package_path = Path(package_dir)
+     package_path.mkdir(parents=True, exist_ok=True)
+     package_path.parent.joinpath('__init__.py').touch(exist_ok=True)
+     short_name = name.split('/')[-1]
+     _compat_copy_script_and_resources(
+         name=short_name,
+         importable_directory_path=package_dir,
+         subdirectory_name=subdirectory_name,
+         original_local_path=local_path,
+         local_imports=local_imports,
+         additional_files=additional_files,
+         download_mode=download_mode,
+     )
+
+ def _compat_load_importable_file(
+     dynamic_modules_path: str,
+     module_namespace: str,
+     subdirectory_name: str,
+     name: str,
+ ) -> Tuple[str, str]:
+     """Return the dotted module path of an importable script and its hash.
+
+     The dotted path joins the base name of ``dynamic_modules_path``, the
+     namespace, ``name`` with ``/`` replaced by ``--``, the sub-directory and
+     the last part of ``name``. ``subdirectory_name`` is returned unchanged
+     as the second element.
+     """
+     components = (
+         os.path.basename(dynamic_modules_path),
+         module_namespace,
+         name.replace('/', '--'),
+         subdirectory_name,
+         name.split('/')[-1],
+     )
+     dotted_path = '.'.join(components)
+     return dotted_path, subdirectory_name
+
+ # -- trust handling ---------------------------------------------------
+
+ def _compat_resolve_trust_remote_code(trust_remote_code, repo_id: str):
+     """Return ``trust_remote_code`` when the caller made an explicit choice.
+
+     Raises:
+         ValueError: If ``trust_remote_code`` is ``None``.
+     """
+     if trust_remote_code is not None:
+         return trust_remote_code
+     message = (
+         f'The repository for {repo_id} contains custom code which must be '
+         f'executed to correctly load the dataset. You can inspect the repository '
+         f'content at the Hub.\nPlease pass the argument `trust_remote_code=True` '
+         f'to allow custom code to be run.')
+     raise ValueError(message)
+
+ # -- Assign compat functions to canonical names -----------------------
+ # The shims are bound to the same names that the ``try`` block near the
+ # top of this module imports from ``datasets`` when those imports succeed,
+ # so code importing from this module uses one set of names in both cases.
+ get_imports = _compat_get_imports # noqa: F811
+ init_dynamic_modules = _compat_init_dynamic_modules # noqa: F811
+ files_to_hash = _compat_files_to_hash # noqa: F811
+ resolve_trust_remote_code = _compat_resolve_trust_remote_code # noqa: F811
+ _get_importable_file_path = _compat_get_importable_file_path # noqa: F811
+ _create_importable_file = _compat_create_importable_file # noqa: F811
+ _load_importable_file = _compat_load_importable_file # noqa: F811
diff --git a/modelscope/msdatasets/utils/_module_factories.py b/modelscope/msdatasets/utils/_module_factories.py
new file mode 100644
index 00000000..861048ff
--- /dev/null
+++ b/modelscope/msdatasets/utils/_module_factories.py
@@ -0,0 +1,656 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+"""Dataset module factory functions and data file resolution for ModelScope.
+
+This module provides ModelScope-specific implementations of dataset module
+loading (both script-based and script-free) and data file pattern resolution.
+These functions are monkey-patched onto the ``datasets`` library internals
+by :func:`~hf_datasets_util.load_dataset_with_ctx`.
+"""
+import importlib
+import inspect
+import os
+from functools import partial
+from pathlib import Path
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+from datasets import (BuilderConfig, DownloadConfig, DownloadMode, Features,
+ Version, config, data_files)
+from datasets.data_files import (
+ FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError,
+ _get_data_files_patterns, _is_inside_unrequested_special_dir,
+ _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir,
+ sanitize_patterns)
+from datasets.download.streaming_download_manager import (
+ _prepare_path_and_storage_options, xbasename, xjoin)
+from datasets.exceptions import DataFilesNotFoundError
+from datasets.info import DatasetInfosDict
+from datasets.load import (BuilderConfigsParameters, DatasetModule,
+ create_builder_configs_from_metadata_configs,
+ get_dataset_builder_class, import_main_class,
+ infer_module_for_data_files)
+from datasets.naming import camelcase_to_snakecase
+from datasets.packaged_modules import (_MODULE_TO_EXTENSIONS,
+ _PACKAGED_DATASETS_MODULES)
+from datasets.utils.file_utils import (cached_path, is_local_path,
+ relative_to_absolute_path)
+from datasets.utils.metadata import MetadataConfigs
+from datasets.utils.track import tracked_str
+from fsspec import filesystem
+from fsspec.core import _un_chain
+from fsspec.utils import stringify_path
+from huggingface_hub import DatasetCard, DatasetCardData
+from packaging import version
+
+from modelscope import HubApi
+from modelscope.msdatasets.utils._compat import (
+ _HAS_SCRIPT_LOADING, _create_importable_file, _get_importable_file_path,
+ _load_importable_file, files_to_hash, get_imports, init_dynamic_modules,
+ resolve_trust_remote_code)
+from modelscope.utils.constant import (DEFAULT_DATASET_REVISION,
+ REPO_TYPE_DATASET)
+from modelscope.utils.file_utils import is_relative_path
+from modelscope.utils.import_utils import has_attr_in_class
+from modelscope.utils.logger import get_logger
+
+# ALL_ALLOWED_EXTENSIONS moved to datasets.packaged_modules in datasets 4.0
+try:
+ from datasets.packaged_modules import _ALL_ALLOWED_EXTENSIONS as ALL_ALLOWED_EXTENSIONS
+except ImportError:
+ from datasets.load import ALL_ALLOWED_EXTENSIONS
+
+logger = get_logger()
+
+# ---------------------------------------------------------------------------
+# Shared HubApi instance (avoids creating a new requests.Session per call)
+# ---------------------------------------------------------------------------
+_hub_api: Optional[HubApi] = None
+
+
+ def _get_hub_api() -> HubApi:
+     """Return the module-wide ``HubApi``, building it on the first call."""
+     global _hub_api
+     if _hub_api is not None:
+         return _hub_api
+     _hub_api = HubApi(timeout=3 * 60, max_retries=3)
+     return _hub_api
+
+
+# ===================================================================
+# Data file resolution
+# ===================================================================
+
+
+ def get_fs_token_paths(
+     urlpath,
+     storage_options=None,
+     protocol=None,
+ ):
+     """Build the fsspec filesystem that serves ``urlpath``.
+
+     Despite the name, only the filesystem object is returned.
+
+     Args:
+         urlpath: A path/URL, or a non-empty list, tuple or set of them; for
+             a collection only its first element is inspected.
+         storage_options: Optional options passed to ``_un_chain``. The
+             mapping is copied, so the caller's object is never modified.
+         protocol: Optional value stored under the ``'protocol'`` key of the
+             copied options.
+
+     Returns:
+         The object returned by ``fsspec.filesystem`` for the first link of
+         the URL chain, configured with the options of the remaining links.
+
+     Raises:
+         ValueError: If ``urlpath`` is an empty list, tuple or set.
+     """
+     if isinstance(urlpath, (list, tuple, set)):
+         if not urlpath:
+             raise ValueError('empty urlpath sequence')
+         urlpath0 = stringify_path(list(urlpath)[0])
+     else:
+         urlpath0 = stringify_path(urlpath)
+     # Copy first: writing 'protocol' below must not leak into the mapping
+     # owned by the caller.
+     storage_options = dict(storage_options or {})
+     if protocol:
+         storage_options['protocol'] = protocol
+     chain = _un_chain(urlpath0, storage_options)
+     inkwargs = {}
+     last_index = len(chain) - 1
+     # Walk the chain backwards, nesting each link's options as the target
+     # of the link before it; the final iteration (the first link) merges
+     # its own options at the top level.
+     for i, (urls, nested_protocol, kw) in enumerate(reversed(chain)):
+         if i == last_index:
+             inkwargs = dict(**kw, **inkwargs)
+             continue
+         inkwargs['target_options'] = dict(**kw, **inkwargs)
+         inkwargs['target_protocol'] = nested_protocol
+         inkwargs['fo'] = urls
+     _, first_protocol, _ = chain[0]
+     return filesystem(first_protocol, **inkwargs)
+
+
+ def _resolve_pattern(
+ pattern: str,
+ base_path: str,
+ allowed_extensions: Optional[List[str]] = None,
+ download_config: Optional[DownloadConfig] = None,
+ ) -> List[str]:
+ """Resolve data file paths/URLs from a user-supplied pattern.
+
+ Supports ``*``, ``**``, and fsspec-based remote patterns (e.g. ``hf://``).
+ Hidden files/directories and ``__pycache__`` are excluded by default.
+
+ Args:
+ pattern: Glob pattern; a relative pattern is joined to ``base_path``.
+ base_path: Base location used for relative patterns.
+ allowed_extensions: When given, only files having one of these
+ extensions are kept.
+ download_config: Passed on when preparing the path and its storage
+ options.
+
+ Returns:
+ The matching file paths, prefixed with ``<protocol>://`` unless the
+ protocol is ``file`` or the prefix is already present.
+
+ Raises:
+ DataFilesNotFoundError: If globbing raises ``FileNotFoundError``.
+ FileNotFoundError: If no file is left after filtering.
+ """
+ if is_relative_path(pattern):
+ pattern = xjoin(base_path, pattern)
+ elif is_local_path(pattern):
+ # Absolute local pattern: the base becomes the drive root.
+ base_path = os.path.splitdrive(pattern)[0] + os.sep
+ else:
+ base_path = ''
+ pattern, storage_options = _prepare_path_and_storage_options(
+ pattern, download_config=download_config)
+ fs = get_fs_token_paths(pattern, storage_options=storage_options)
+ # Keep only the first link of a `::` chain and strip its protocol prefix.
+ fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker
+ fs_pattern = pattern.split('::')[0].split('://')[-1]
+ # A file named exactly like the pattern's base name is never ignored.
+ files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
+ protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
+ protocol_prefix = protocol + '://' if protocol != 'file' else ''
+ glob_kwargs = {}
+ # NOTE(review): `expand_info=False` is only passed for the `hf` protocol
+ # with huggingface_hub >= 0.20.0 -- presumably the first release that
+ # accepts it; confirm.
+ if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'):
+ glob_kwargs['expand_info'] = False
+
+ try:
+ tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
+ except FileNotFoundError:
+ raise DataFilesNotFoundError(f"Unable to find '{pattern}'")
+
+ # Keep regular files only, dropping ignored names and anything the two
+ # `datasets` helpers flag as being in a special or hidden location that
+ # the pattern did not ask for.
+ matched_paths = [
+ filepath if filepath.startswith(protocol_prefix) else protocol_prefix
+ + filepath for filepath, info in tmp_file_paths.items()
+ if info['type'] == 'file' and (
+ xbasename(filepath) not in files_to_ignore)
+ and not _is_inside_unrequested_special_dir(
+ os.path.relpath(filepath, fs_base_path),
+ os.path.relpath(fs_pattern, fs_base_path)) and # noqa: W504
+ not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir( # noqa: W504
+ os.path.relpath(filepath, fs_base_path),
+ os.path.relpath(fs_pattern, fs_base_path))
+ ]
+ if allowed_extensions is not None:
+ # A file is kept when any dot-separated part after the first part of
+ # its base name, written with a leading '.', is an allowed extension.
+ out = [
+ filepath for filepath in matched_paths
+ if any('.' + suffix in allowed_extensions
+ for suffix in xbasename(filepath).split('.')[1:])
+ ]
+ if len(out) < len(matched_paths):
+ invalid_matched_files = list(set(matched_paths) - set(out))
+ logger.info(
+ f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: "
+ f'{invalid_matched_files}')
+ else:
+ out = matched_paths
+ if not out:
+ error_msg = f"Unable to find '{pattern}'"
+ if allowed_extensions is not None:
+ error_msg += f' with any supported extension {list(allowed_extensions)}'
+ raise FileNotFoundError(error_msg)
+ return out
+
+
+ def _get_data_patterns(
+     base_path: str,
+     download_config: Optional[DownloadConfig] = None
+ ) -> Dict[str, List[str]]:
+     """Return the data file patterns that apply to a dataset directory.
+
+     ``_get_data_files_patterns`` is driven with a resolver bound to
+     ``base_path`` and ``download_config``.
+
+     Raises:
+         EmptyDatasetError: If the lookup ends with ``FileNotFoundError``.
+     """
+     bound_resolver = partial(
+         _resolve_pattern, base_path=base_path, download_config=download_config)
+     try:
+         patterns = _get_data_files_patterns(bound_resolver)
+     except FileNotFoundError:
+         raise EmptyDatasetError(
+             f"The directory at {base_path} doesn't contain any data files"
+         ) from None
+     return patterns
+
+
+# ===================================================================
+# Repository file download helper
+# ===================================================================
+
+
+ def _download_repo_file(
+     repo_id: str,
+     path_in_repo: str,
+     download_config: DownloadConfig,
+     revision: str,
+ ) -> str:
+     """Download a single file from a ModelScope dataset repository.
+
+     Args:
+         repo_id: Repository id of the form ``<namespace>/<dataset_name>``.
+         path_in_repo: File name passed to the Hub API.
+         download_config: Download settings. The object is not modified; a
+             default progress description is set on a copy when needed.
+         revision: Revision of the repository to read from.
+
+     Returns:
+         The path returned by ``cached_path``, or ``''`` when the lookup or
+         download raises ``FileNotFoundError`` (the error is logged).
+     """
+     api = _get_hub_api()
+     _namespace, _dataset_name = repo_id.split('/')
+     endpoint = api.get_endpoint_for_read(
+         repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
+     if download_config and download_config.download_desc is None:
+         # Set the description on a copy: writing it to the caller's config
+         # would make every later download through that config reuse this
+         # file's description.
+         download_config = download_config.copy()
+         download_config.download_desc = f'Downloading [{path_in_repo}]'
+     try:
+         url_or_filename = api.get_dataset_file_url(
+             file_name=path_in_repo,
+             dataset_name=_dataset_name,
+             namespace=_namespace,
+             revision=revision,
+             extension_filter=False,
+             endpoint=endpoint,
+         )
+         repo_file_path = cached_path(
+             url_or_filename=url_or_filename, download_config=download_config)
+     except FileNotFoundError as e:
+         # Best effort: a missing file is reported through the empty result.
+         repo_file_path = ''
+         logger.error(e)
+     return repo_file_path
+
+
+# ===================================================================
+# Additional modules download (for script-based datasets)
+# ===================================================================
+
+
+ def _download_additional_modules(
+     name: str,
+     dataset_name: str,
+     namespace: str,
+     revision: str,
+     imports: Sequence[Tuple[str, str, str, Optional[str]]],
+     download_config: Optional[DownloadConfig],
+     trust_remote_code: Optional[bool] = False,
+ ) -> List[Tuple[str, str]]:
+     """Download additional modules referenced by a dataset builder script.
+
+     Parses the import list produced by ``get_imports`` and downloads any
+     internal (relative) or external modules. Library imports are validated
+     but not downloaded.
+
+     Args:
+         name: Name of the script, used in messages and to reject a relative
+             import of the script itself.
+         dataset_name: Dataset part of the repository id.
+         namespace: Namespace part of the repository id.
+         revision: Repository revision for internal imports.
+         imports: ``(kind, name, path_or_url, sub_directory)`` tuples.
+         download_config: Download settings; a default ``DownloadConfig`` is
+             used when ``None``. The caller's object is never modified.
+         trust_remote_code: Must be truthy when an internal or external
+             import is present.
+
+     Returns:
+         ``(module name, local path)`` for every downloaded import.
+
+     Raises:
+         ValueError: If remote code is needed but not trusted, if a relative
+             import has the script's own name, or on an unknown import kind.
+         ImportError: If a library import cannot be imported.
+     """
+     local_imports: List[Tuple[str, str]] = []
+     library_imports: List[Tuple[str, str]] = []
+
+     has_remote_code = any(
+         import_type in ('internal', 'external')
+         for import_type, _, _, _ in imports)
+     if has_remote_code and not trust_remote_code:
+         raise ValueError(
+             f'Loading {name} requires executing code from the repository. '
+             'This is disabled by default for security reasons. '
+             'If you trust the authors of this dataset, you can enable it with '
+             '`trust_remote_code=True`.')
+
+     api = _get_hub_api()
+     # The parameter is Optional: fall back to a default config instead of
+     # failing on `None.copy()`.
+     if download_config is None:
+         download_config = DownloadConfig()
+     else:
+         download_config = download_config.copy()
+     if download_config.download_desc is None:
+         download_config.download_desc = 'Downloading extra modules'
+
+     for import_type, import_name, import_path, sub_directory in imports:
+         if import_type == 'library':
+             library_imports.append((import_name, import_path))
+             continue
+         if import_name == name:
+             raise ValueError(
+                 f'Error in the {name} script, importing relative {import_name} module '
+                 f'but {import_name} is the name of the script. '
+                 f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
+                 f'comment pointing to the original relative import file path.')
+         if import_type == 'internal':
+             # Internal modules live in the same repository as the script.
+             url_or_filename = api.get_dataset_file_url(
+                 file_name=import_path + '.py',
+                 dataset_name=dataset_name,
+                 namespace=namespace,
+                 revision=revision,
+             )
+         elif import_type == 'external':
+             url_or_filename = import_path
+         else:
+             raise ValueError('Wrong import_type')
+
+         local_import_path = cached_path(
+             url_or_filename, download_config=download_config)
+         if sub_directory is not None:
+             local_import_path = os.path.join(local_import_path, sub_directory)
+         local_imports.append((import_name, local_import_path))
+
+     # Validate library imports
+     needs_to_be_installed = {}
+     for library_import_name, library_import_path in library_imports:
+         try:
+             importlib.import_module(library_import_name)
+         except ImportError:
+             if library_import_name not in needs_to_be_installed or library_import_path != library_import_name:
+                 needs_to_be_installed[
+                     library_import_name] = library_import_path
+     if needs_to_be_installed:
+         plural = len(needs_to_be_installed) > 1
+         _dependencies_str = 'dependencies' if plural else 'dependency'
+         _them_str = 'them' if plural else 'it'
+         # Import names that differ from the package name to install.
+         if 'sklearn' in needs_to_be_installed:
+             needs_to_be_installed['sklearn'] = 'scikit-learn'
+         if 'Bio' in needs_to_be_installed:
+             needs_to_be_installed['Bio'] = 'biopython'
+         raise ImportError(
+             f'To be able to use {name}, you need to install the following {_dependencies_str}: '
+             f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install "
+             f"{' '.join(needs_to_be_installed.values())}' for instance.")
+     return local_imports
+
+
+# ===================================================================
+# Module factory: script-based (Hub)
+# ===================================================================
+
+
+ def _load_script_module(
+ repo_id: str,
+ revision: str,
+ download_config: DownloadConfig,
+ download_mode=None,
+ dynamic_modules_path: Optional[str] = None,
+ trust_remote_code: Optional[bool] = None,
+ ) -> DatasetModule:
+ """Shared implementation for loading a dataset module from a Hub .py script.
+
+ Used by both ``get_module_with_script`` (monkey-patch for datasets<4.0) and
+ ``_compat_hub_script_module`` (compat shim for datasets>=4.0).
+
+ Args:
+ repo_id: Repository id of the form ``<namespace>/<dataset_name>``.
+ revision: Repository revision to download from.
+ download_config: Download settings passed to the download helpers.
+ download_mode: Passed on when the importable file is created.
+ dynamic_modules_path: Directory for importable modules; obtained from
+ ``init_dynamic_modules()`` when falsy.
+ trust_remote_code: Whether running the repository's code is allowed.
+
+ Returns:
+ A ``DatasetModule`` built from the module path, the hash and the
+ builder kwargs (``base_path`` and ``repo_id``).
+
+ Raises:
+ FileNotFoundError: If the script could not be downloaded.
+ ValueError: If the importable file has to be created but the resolved
+ trust value is falsy.
+ """
+ _namespace, _dataset_name = repo_id.split('/')
+ # The builder script is expected to be named after the dataset.
+ script_file_name = f'{_dataset_name}.py'
+
+ local_script_path = _download_repo_file(
+ repo_id=repo_id,
+ path_in_repo=script_file_name,
+ download_config=download_config,
+ revision=revision,
+ )
+ # The download helper returns an empty string when the file is missing.
+ if not local_script_path:
+ raise FileNotFoundError(
+ f'Cannot find {script_file_name} in {repo_id} at revision {revision}.'
+ )
+
+ # The README is optional: an empty result is tolerated below.
+ dataset_readme_path = _download_repo_file(
+ repo_id=repo_id,
+ path_in_repo='README.md',
+ download_config=download_config,
+ revision=revision,
+ )
+
+ imports = get_imports(local_script_path)
+ local_imports = _download_additional_modules(
+ name=repo_id,
+ dataset_name=_dataset_name,
+ namespace=_namespace,
+ revision=revision,
+ imports=imports,
+ download_config=download_config,
+ trust_remote_code=trust_remote_code,
+ )
+
+ additional_files = []
+ if dataset_readme_path:
+ additional_files.append(
+ (config.REPOCARD_FILENAME, dataset_readme_path))
+
+ dynamic_modules_path = dynamic_modules_path or init_dynamic_modules()
+ # The hash covers the script and every downloaded local import; it names
+ # the sub-directory of the importable copy.
+ hash_val = files_to_hash([local_script_path]
+ + [loc[1] for loc in local_imports])
+ importable_file_path = _get_importable_file_path(
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash_val,
+ name=repo_id,
+ )
+ # Trust is only resolved when the importable copy does not exist yet.
+ if not os.path.exists(importable_file_path):
+ trust = resolve_trust_remote_code(
+ trust_remote_code=trust_remote_code, repo_id=repo_id)
+ if trust:
+ logger.warning(
+ f'Use trust_remote_code=True. Will invoke codes from {repo_id}. '
+ 'Please make sure that you can trust the external codes.')
+ _create_importable_file(
+ local_path=local_script_path,
+ local_imports=local_imports,
+ additional_files=additional_files,
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash_val,
+ name=repo_id,
+ download_mode=download_mode,
+ )
+ else:
+ raise ValueError(
+ f'Loading {repo_id} requires executing the dataset script in that'
+ ' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
+ ' set the option `trust_remote_code=True` to remove this error.'
+ )
+ module_path, hash_val = _load_importable_file(
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash_val,
+ name=repo_id,
+ )
+ # Make files written above visible to the import system.
+ importlib.invalidate_caches()
+
+ api = _get_hub_api()
+ builder_kwargs = {
+ 'base_path': api.get_file_base_path(repo_id=repo_id),
+ 'repo_id': repo_id,
+ }
+ return DatasetModule(module_path, hash_val, builder_kwargs)
+
+
+ def get_module_with_script(self) -> DatasetModule:
+     """Monkey-patch target for ``HubDatasetModuleFactoryWithScript.get_module`` (datasets<4.0).
+
+     Reads the revision from the factory's storage options, falling back to
+     ``DEFAULT_DATASET_REVISION``, and delegates to ``_load_script_module``.
+     """
+     storage_options = self.download_config.storage_options
+     requested_revision = storage_options.get('revision', None)
+     return _load_script_module(
+         repo_id=self.name,
+         revision=requested_revision or DEFAULT_DATASET_REVISION,
+         download_config=self.download_config,
+         download_mode=self.download_mode,
+         dynamic_modules_path=self.dynamic_modules_path or None,
+         trust_remote_code=self.trust_remote_code,
+     )
+
+
+ def _compat_hub_script_module(
+     path,
+     revision=None,
+     download_config=None,
+     download_mode=None,
+     dynamic_modules_path=None,
+     trust_remote_code=None,
+ ) -> DatasetModule:
+     """Load a dataset module from a Hub repo .py script (compat for datasets>=4.0).
+
+     A falsy ``revision`` becomes ``DEFAULT_DATASET_REVISION`` and a falsy
+     ``download_config`` becomes a default ``DownloadConfig``; the work is
+     done by ``_load_script_module``.
+     """
+     if not revision:
+         revision = DEFAULT_DATASET_REVISION
+     if not download_config:
+         download_config = DownloadConfig()
+     return _load_script_module(
+         repo_id=path,
+         revision=revision,
+         download_config=download_config,
+         download_mode=download_mode,
+         dynamic_modules_path=dynamic_modules_path,
+         trust_remote_code=trust_remote_code,
+     )
+
+
+# ===================================================================
+# Module factory: script-based (local)
+# ===================================================================
+
+
+ def _compat_local_script_module(
+ path,
+ download_mode=None,
+ dynamic_modules_path=None,
+ trust_remote_code=None,
+ ) -> DatasetModule:
+ """Load a dataset module from a local .py script (compat for datasets>=4.0).
+
+ Args:
+ path: Path of the local builder script.
+ download_mode: Passed on when the importable file is created.
+ dynamic_modules_path: Directory for importable modules; obtained from
+ ``init_dynamic_modules()`` when falsy.
+ trust_remote_code: Whether running the script is allowed.
+
+ Returns:
+ A ``DatasetModule`` whose ``base_path`` builder kwarg is the resolved
+ parent directory of ``path``.
+
+ Raises:
+ ValueError: If the importable file has to be created but the resolved
+ trust value is falsy.
+ """
+ local_path = path
+ # The module is named after the script file, without its extension.
+ name = Path(path).stem
+
+ local_imports: List[Tuple[str, str]] = []
+ imports = get_imports(local_path)
+ for import_type, import_name, import_path, sub_directory in imports:
+ if import_type == 'library':
+ continue
+ if import_type == 'internal':
+ # Internal imports are looked up beside the script, first as
+ # `<name>.py`, then as a directory. An import that is neither is not
+ # recorded.
+ rel_path = os.path.join(
+ os.path.dirname(local_path), import_path + '.py')
+ if os.path.isfile(rel_path):
+ local_imports.append((import_name, rel_path))
+ elif os.path.isdir(
+ os.path.join(os.path.dirname(local_path), import_path)):
+ local_imports.append(
+ (import_name,
+ os.path.join(os.path.dirname(local_path), import_path)))
+ elif import_type == 'external':
+ # External imports are fetched with a fresh default DownloadConfig.
+ dl_config = DownloadConfig()
+ dl_config.download_desc = 'Downloading extra modules'
+ local_import_path = cached_path(
+ import_path, download_config=dl_config)
+ if sub_directory is not None:
+ local_import_path = os.path.join(local_import_path,
+ sub_directory)
+ local_imports.append((import_name, local_import_path))
+
+ dynamic_modules_path = dynamic_modules_path or init_dynamic_modules()
+ # The hash covers the script and every recorded local import.
+ hash_val = files_to_hash([local_path] + [loc[1] for loc in local_imports])
+ importable_file_path = _get_importable_file_path(
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash_val,
+ name=name,
+ )
+ # Trust is only resolved when the importable copy does not exist yet.
+ if not os.path.exists(importable_file_path):
+ trust = resolve_trust_remote_code(trust_remote_code, name)
+ if trust:
+ _create_importable_file(
+ local_path=local_path,
+ local_imports=local_imports,
+ additional_files=[],
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash_val,
+ name=name,
+ download_mode=download_mode,
+ )
+ else:
+ raise ValueError(
+ f'Loading {name} requires executing the dataset script. '
+ 'Set `trust_remote_code=True` to allow this.')
+ module_path, hash_val = _load_importable_file(
+ dynamic_modules_path=dynamic_modules_path,
+ module_namespace='datasets',
+ subdirectory_name=hash_val,
+ name=name,
+ )
+ # Make files written above visible to the import system.
+ importlib.invalidate_caches()
+ builder_kwargs = {
+ 'base_path': str(Path(path).resolve().parent),
+ }
+ return DatasetModule(module_path, hash_val, builder_kwargs)
+
+
+# ===================================================================
+# Module factory: without script (Hub)
+# ===================================================================
+
+
+ def get_module_without_script(self) -> DatasetModule:
+ """Monkey-patch target for ``HubDatasetModuleFactoryWithoutScript.get_module``.
+
+ Builds a ``DatasetModule`` for a repository without a builder script:
+ the README is downloaded for its card metadata, data file patterns are
+ chosen, a packaged module is inferred from the resolved files, and the
+ builder configs are derived from the metadata configs or from the data
+ files.
+
+ Reads ``self.name``, ``self.data_dir``, ``self.data_files`` and
+ ``self.download_config`` (including the ``revision`` and ``name`` storage
+ options).
+ """
+ revision = self.download_config.storage_options.get(
+ 'revision', None) or DEFAULT_DATASET_REVISION
+ base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
+ '/')
+
+ repo_id: str = self.name
+ # A copy is handed to the README download.
+ download_config = self.download_config.copy()
+
+ dataset_readme_path = _download_repo_file(
+ repo_id=repo_id,
+ path_in_repo='README.md',
+ download_config=download_config,
+ revision=revision,
+ )
+
+ # Without a README an empty card is used.
+ dataset_card_data = DatasetCard.load(
+ Path(dataset_readme_path
+ )).data if dataset_readme_path else DatasetCardData()
+ subset_name: str = download_config.storage_options.get('name', None)
+
+ metadata_configs = MetadataConfigs.from_dataset_card_data(
+ dataset_card_data)
+ dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
+
+ # Pattern source, in priority order: explicit data files, the card's
+ # metadata configs, then patterns discovered under the base path.
+ if self.data_files is not None:
+ patterns = sanitize_patterns(self.data_files)
+ elif metadata_configs and 'data_files' in next(
+ iter(metadata_configs.values())):
+ # NOTE(review): only the first config is checked for 'data_files', and
+ # the lookup by subset name has no fallback for a name that is absent
+ # from the card -- confirm this is intended.
+ if subset_name is not None:
+ subset_data_files = metadata_configs[subset_name]['data_files']
+ else:
+ subset_data_files = next(iter(
+ metadata_configs.values()))['data_files']
+ patterns = sanitize_patterns(subset_data_files)
+ else:
+ patterns = _get_data_patterns(
+ base_path, download_config=self.download_config)
+
+ data_files_dict = DataFilesDict.from_patterns(
+ patterns,
+ base_path=base_path,
+ allowed_extensions=ALL_ALLOWED_EXTENSIONS,
+ download_config=self.download_config,
+ )
+ module_name, default_builder_kwargs = infer_module_for_data_files(
+ data_files=data_files_dict,
+ path=self.name,
+ download_config=self.download_config,
+ )
+
+ # Keep only files with an extension of the inferred module. Two method
+ # names are tried, presumably to cover different `datasets` versions.
+ if hasattr(data_files_dict, 'filter'):
+ data_files_dict = data_files_dict.filter(
+ extensions=_MODULE_TO_EXTENSIONS[module_name])
+ else:
+ data_files_dict = data_files_dict.filter_extensions(
+ _MODULE_TO_EXTENSIONS[module_name])
+
+ module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
+
+ if metadata_configs:
+ supports_metadata = module_name in {'imagefolder', 'audiofolder'}
+ # `supports_metadata` is only passed when the installed function
+ # declares that parameter.
+ create_builder_signature = inspect.signature(
+ create_builder_configs_from_metadata_configs)
+ in_args = {
+ 'module_path': module_path,
+ 'metadata_configs': metadata_configs,
+ 'base_path': base_path,
+ 'default_builder_kwargs': default_builder_kwargs,
+ 'download_config': self.download_config,
+ }
+ if 'supports_metadata' in create_builder_signature.parameters:
+ in_args['supports_metadata'] = supports_metadata
+ builder_configs, default_config_name = create_builder_configs_from_metadata_configs(
+ **in_args)
+ else:
+ # No metadata configs: a single config built from the data files.
+ builder_configs: List[BuilderConfig] = [
+ import_main_class(module_path).BUILDER_CONFIG_CLASS(
+ data_files=data_files_dict,
+ **default_builder_kwargs,
+ )
+ ]
+ default_config_name = None
+
+ api = _get_hub_api()
+ endpoint = api.get_endpoint_for_read(
+ repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
+
+ builder_kwargs = {
+ 'base_path':
+ api.get_file_base_path(repo_id=repo_id, endpoint=endpoint),
+ 'repo_id': self.name,
+ 'dataset_name': camelcase_to_snakecase(Path(self.name).name),
+ 'data_files': data_files_dict,
+ }
+ # NOTE(review): this copy is not read again before the function returns.
+ download_config = self.download_config.copy()
+ if download_config.download_desc is None:
+ download_config.download_desc = 'Downloading metadata'
+
+ # A single dataset info entry serves as the default config name.
+ if default_config_name is None and len(dataset_infos) == 1:
+ default_config_name = next(iter(dataset_infos))
+
+ # The revision is passed where the script-based loaders pass the hash.
+ return DatasetModule(
+ module_path,
+ revision,
+ builder_kwargs,
+ dataset_infos=dataset_infos,
+ builder_configs_parameters=BuilderConfigsParameters(
+ metadata_configs=metadata_configs,
+ builder_configs=builder_configs,
+ default_config_name=default_config_name,
+ ),
+ )
diff --git a/modelscope/msdatasets/utils/hf_datasets_util.py b/modelscope/msdatasets/utils/hf_datasets_util.py
index bde8965b..1290cb26 100644
--- a/modelscope/msdatasets/utils/hf_datasets_util.py
+++ b/modelscope/msdatasets/utils/hf_datasets_util.py
@@ -1,27 +1,32 @@
# noqa: isort:skip_file, yapf: disable
# Copyright (c) Alibaba, Inc. and its affiliates.
# Copyright 2020 The HuggingFace Datasets Authors and the TensorFlow Datasets Authors.
-import importlib
+"""ModelScope dataset loading orchestration.
+
+This module provides :class:`DatasetsWrapperHF` and the
+:func:`load_dataset_with_ctx` context manager that monkey-patch the
+HuggingFace ``datasets`` library to work with the ModelScope Hub.
+
+Sub-modules:
+ _compat – backward-compat shims for datasets>=4.0 script loading
+ _module_factories – dataset module factory functions & data-file resolution
+"""
import contextlib
-import inspect
import os
import warnings
-from dataclasses import dataclass, field, fields
-from functools import partial
+from dataclasses import fields
from pathlib import Path
-from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal, Any, ClassVar
+from typing import Any, Dict, Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union
from urllib.parse import urlencode
import requests
-from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
+from datasets import (Dataset, DatasetBuilder, DatasetDict,
DownloadConfig, DownloadManager, DownloadMode, Features,
IterableDataset, IterableDatasetDict, Split,
VerificationMode, Version, config, data_files, LargeList,
Sequence as SequenceHf)
-# In datasets 4.0+, Sequence was replaced by List as a feature type.
-# Use List as the base for ListMs when available, fall back to Sequence for <4.0.
try:
from datasets import List as DatasetList
except ImportError:
@@ -29,73 +34,25 @@ except ImportError:
from datasets.features import features
from datasets.features.features import _FEATURE_TYPES
-from datasets.data_files import (
- FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError,
- _get_data_files_patterns, _is_inside_unrequested_special_dir,
- _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, sanitize_patterns)
-from datasets.download.streaming_download_manager import (
- _prepare_path_and_storage_options, xbasename, xjoin)
+from datasets.data_files import DataFilesDict, EmptyDatasetError
from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError
-from datasets.info import DatasetInfosDict
from datasets.load import (
- BuilderConfigsParameters,
CachedDatasetModuleFactory, DatasetModule,
HubDatasetModuleFactoryWithParquetExport,
PackagedDatasetModuleFactory,
- create_builder_configs_from_metadata_configs, get_dataset_builder_class,
- import_main_class, infer_module_for_data_files)
-
-# To compatible with datasets 4.0+
-try:
- from datasets.load import (
- HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript,
- LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript)
-except ImportError:
- from datasets.load import (
- HubDatasetModuleFactoryWithoutScript,
- LocalDatasetModuleFactoryWithoutScript)
-
-# Script-based dataset loading was removed in datasets 4.0.
-# These APIs are conditionally imported for backward compatibility with <4.0.
-try:
- from datasets.load import (
- HubDatasetModuleFactoryWithScript,
- LocalDatasetModuleFactoryWithScript,
- resolve_trust_remote_code,
- _get_importable_file_path, _create_importable_file,
- _load_importable_file, init_dynamic_modules,
- files_to_hash)
- from datasets.utils.py_utils import get_imports
- _HAS_SCRIPT_LOADING = True
-except ImportError:
- _HAS_SCRIPT_LOADING = False
-
-from datasets.naming import camelcase_to_snakecase
-from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
- _MODULE_TO_EXTENSIONS,
- _PACKAGED_DATASETS_MODULES)
-# ALL_ALLOWED_EXTENSIONS moved to datasets.packaged_modules in datasets 4.0
-try:
- from datasets.packaged_modules import _ALL_ALLOWED_EXTENSIONS as ALL_ALLOWED_EXTENSIONS
-except ImportError:
- from datasets.load import ALL_ALLOWED_EXTENSIONS
+ get_dataset_builder_class)
+from datasets.packaged_modules import (_EXTENSION_TO_MODULE, _PACKAGED_DATASETS_MODULES)
from datasets.utils import file_utils
from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled,
- cached_path, is_local_path,
- relative_to_absolute_path)
+ cached_path, relative_to_absolute_path)
from datasets.utils.info_utils import is_small_dataset
-from datasets.utils.metadata import MetadataConfigs
from datasets.utils.track import tracked_str
-from fsspec import filesystem
-from fsspec.core import _un_chain
-from fsspec.utils import stringify_path
-from huggingface_hub import (DatasetCard, DatasetCardData, hf_hub_url)
+from huggingface_hub import hf_hub_url
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
from huggingface_hub.hf_file_system import HfFileSystem
-from packaging import version
from modelscope import HubApi
from modelscope.hub.utils.utils import get_endpoint
@@ -106,9 +63,41 @@ from modelscope.utils.import_utils import has_attr_in_class
from modelscope.utils.file_utils import is_relative_path
from modelscope.utils.logger import get_logger
+# -- Compat layer -----------------------------------------------------------
+from modelscope.msdatasets.utils._compat import (
+ _HAS_SCRIPT_LOADING,
+ HubDatasetModuleFactoryWithScript,
+ LocalDatasetModuleFactoryWithScript,
+)
+
+# -- Module factories --------------------------------------------------------
+from modelscope.msdatasets.utils._module_factories import (
+ _resolve_pattern,
+ _download_repo_file,
+ get_module_without_script,
+ get_module_with_script,
+ _compat_local_script_module,
+ _compat_hub_script_module,
+ _get_hub_api,
+)
+
+# datasets 4.0+ dropped the "WithoutScript" suffix; alias the new names, else import the old ones
+try:
+ from datasets.load import (
+ HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript,
+ LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript)
+except ImportError:
+ from datasets.load import (
+ HubDatasetModuleFactoryWithoutScript,
+ LocalDatasetModuleFactoryWithoutScript)
+
logger = get_logger()
+# ===================================================================
+# Type definitions
+# ===================================================================
+
ExpandDatasetProperty_T = Literal[
'author',
'cardData',
@@ -129,33 +118,27 @@ ExpandDatasetProperty_T = Literal[
]
-# Patch datasets features
+# ===================================================================
+# Feature patching (generate_from_dict_ms)
+# ===================================================================
+
_NativeList = DatasetList if DatasetList is not None else SequenceHf
def generate_from_dict_ms(obj: Any):
"""Regenerate the nested feature object from a deserialized dict.
- We use the '_type' fields to get the dataclass name to load.
- generate_from_dict is the recursive helper for Features.from_dict, and allows for a convenient constructor syntax
- to define features from deserialized JSON dictionaries. This function is used in particular when deserializing
- a :class:`DatasetInfo` that was dumped to a JSON object. This acts as an analogue to
- :meth:`Features.from_arrow_schema` and handles the recursive field-by-field instantiation, but doesn't require any
- mapping to/from pyarrow, except for the fact that it takes advantage of the mapping of pyarrow primitive dtypes
- that :class:`Value` automatically performs.
+ This is a ModelScope-patched version of ``features.generate_from_dict``
+ that handles backward compatibility for legacy ``Sequence`` types in
+ datasets 4.0+ where ``Sequence`` is no longer a registered feature type.
"""
- # Nested structures: we allow dict, list/tuples, sequences
if isinstance(obj, list):
return [generate_from_dict_ms(value) for value in obj]
- # Otherwise we have a dict or a dataclass
if '_type' not in obj or isinstance(obj['_type'], dict):
return {key: generate_from_dict_ms(value) for key, value in obj.items()}
obj = dict(obj)
_type = obj.pop('_type')
- # Handle legacy 'Sequence' type for backward compatibility.
- # In datasets 4.0+, Sequence is a utility function (not a feature type),
- # so it may not be registered in _FEATURE_TYPES.
if _type == 'Sequence':
feature = obj.pop('feature')
length = obj.get('length', -1)
@@ -169,7 +152,6 @@ def generate_from_dict_ms(obj: Any):
if class_type == LargeList:
feature = obj.pop('feature')
return LargeList(generate_from_dict_ms(feature), **obj)
- # Handle the native List type (datasets 4.0+) as well as Sequence-based
if _NativeList is not None and (class_type is _NativeList or issubclass(class_type, _NativeList)):
feature = obj.pop('feature')
return _NativeList(generate_from_dict_ms(feature), **obj)
@@ -178,17 +160,22 @@ def generate_from_dict_ms(obj: Any):
return class_type(**{k: v for k, v in obj.items() if k in field_names})
+# ===================================================================
+# Download monkey-patch (_download_ms)
+# ===================================================================
+
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
+ """ModelScope replacement for ``DownloadManager._download``.
+
+ Rewrites relative paths and ``hf://`` URLs to ModelScope API endpoints.
+ """
url_or_filename = str(url_or_filename)
if url_or_filename.startswith('hf://'):
- # hf:// URLs (e.g. hf://datasets/{owner}/{name}@{revision}/{file_path})
hf_path = url_or_filename[len('hf://'):]
- # Strip leading resource type prefix (e.g. "datasets/")
for _prefix in ('datasets/', 'models/'):
if hf_path.startswith(_prefix):
hf_path = hf_path[len(_prefix):]
break
- # Extract revision and file_path from "{owner}/{name}@{revision}/{file_path}"
if '@' in hf_path:
at_idx = hf_path.index('@')
after_at = hf_path[at_idx + 1:]
@@ -203,14 +190,12 @@ def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) ->
parts = hf_path.split('/', 2)
revision = DEFAULT_DATASET_REVISION
file_path = parts[2] if len(parts) > 2 else ''
- params = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_path})
- url_or_filename = self._base_path + params
+ params_str = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_path})
+ url_or_filename = self._base_path + params_str
elif is_relative_path(url_or_filename):
revision = DEFAULT_DATASET_REVISION
- # Note: make sure the FilePath is the last param
- params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
- params: str = urlencode(params)
- url_or_filename = self._base_path + params
+ params_str = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename})
+ url_or_filename = self._base_path + params_str
out = cached_path(url_or_filename, download_config=download_config)
out = tracked_str(out)
@@ -218,6 +203,10 @@ def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) ->
return out
+# ===================================================================
+# HfApi monkey-patches (dataset_info, list_repo_tree, get_paths_info)
+# ===================================================================
+
def _dataset_info(
self,
repo_id: str,
@@ -228,45 +217,7 @@ def _dataset_info(
token: Optional[Union[bool, str]] = None,
expand: Optional[List[ExpandDatasetProperty_T]] = None,
) -> HfDatasetInfo:
- """
- Get info on one specific dataset on huggingface.co.
-
- Dataset can be private if you pass an acceptable token.
-
- Args:
- repo_id (`str`):
- A namespace (user or an organization) and a repo name separated
- by a `/`.
- revision (`str`, *optional*):
- The revision of the dataset repository from which to get the
- information.
- timeout (`float`, *optional*):
- Whether to set a timeout for the request to the Hub.
- files_metadata (`bool`, *optional*):
- Whether or not to retrieve metadata for files in the repository
- (size, LFS metadata, etc). Defaults to `False`.
- token (`bool` or `str`, *optional*):
- A valid authentication token (see https://huggingface.co/settings/token).
- If `None` or `True` and machine is logged in (through `huggingface-cli login`
- or [`~huggingface_hub.login`]), token will be retrieved from the cache.
- If `False`, token is not sent in the request header.
-
- Returns:
- [`hf_api.DatasetInfo`]: The dataset repository information.
-
-
-
- Raises the following errors:
-
- - [`~utils.RepositoryNotFoundError`]
- If the repository to download from cannot be found. This may be because it doesn't exist,
- or because it is set to `private` and you do not have access.
- - [`~utils.RevisionNotFoundError`]
- If the revision to download from cannot be found.
-
-
- """
- # Note: refer to `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
+ """ModelScope replacement for ``HfApi.dataset_info``."""
repo_info_iter = self.list_repo_tree(
repo_id=repo_id,
path_in_repo='/',
@@ -277,22 +228,21 @@ def _dataset_info(
repo_type=REPO_TYPE_DATASET,
)
- # Update data_info
- data_info = dict({})
- data_info['id'] = repo_id
- data_info['private'] = False
- data_info['author'] = repo_id.split('/')[0] if repo_id else None
- data_info['sha'] = revision
- data_info['lastModified'] = None
- data_info['gated'] = False
- data_info['disabled'] = False
- data_info['downloads'] = 0
- data_info['likes'] = 0
- data_info['tags'] = []
- data_info['cardData'] = []
- data_info['createdAt'] = None
+ data_info = {
+ 'id': repo_id,
+ 'private': False,
+ 'author': repo_id.split('/')[0] if repo_id else None,
+ 'sha': revision,
+ 'lastModified': None,
+ 'gated': False,
+ 'disabled': False,
+ 'downloads': 0,
+ 'likes': 0,
+ 'tags': [],
+ 'cardData': [],
+ 'createdAt': None,
+ }
- # e.g. {'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}}
data_siblings = []
for info_item in repo_info_iter:
if isinstance(info_item, RepoFile):
@@ -308,6 +258,8 @@ def _dataset_info(
return HfDatasetInfo(**data_info)
+# -- Repo tree cache ---------------------------------------------------------
+
_repo_tree_cache: Dict[tuple, List[Union[RepoFile, RepoFolder]]] = {}
@@ -350,7 +302,7 @@ def _list_repo_tree(
repo_type: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
) -> Iterable[Union[RepoFile, RepoFolder]]:
-
+ """ModelScope replacement for ``HfApi.list_repo_tree``."""
revision = revision or DEFAULT_DATASET_REVISION
normalized_path = path_in_repo or '/'
cache_key = (repo_id, revision, normalized_path, recursive)
@@ -366,24 +318,21 @@ def _list_repo_tree(
yield from derived
return
- _api = HubApi(timeout=3 * 60, max_retries=3)
- endpoint = _api.get_endpoint_for_read(
+ api = _get_hub_api()
+ endpoint = api.get_endpoint_for_read(
repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
_owner, _dataset_name = repo_id.split('/')
- dataset_hub_id, _ = _api.get_dataset_id_and_type(
+ dataset_hub_id, _ = api.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
results: List[Union[RepoFile, RepoFolder]] = []
page_number = 1
- # Larger page_size reduces the number of HTTP round-trips for big datasets.
- # Termination uses `not dataset_files` (empty page) so it is safe even if
- # the server silently caps the actual page size to a smaller value.
page_size = 500
max_pages = 10000
while page_number <= max_pages:
try:
- dataset_files = _api.get_dataset_files(
+ dataset_files = api.get_dataset_files(
repo_id=repo_id,
revision=revision,
root_path=normalized_path,
@@ -428,19 +377,22 @@ def _get_paths_info(
repo_type: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
) -> List[Union[RepoFile, RepoFolder]]:
-
+ """ModelScope replacement for ``HfApi.get_paths_info``."""
revision = revision or DEFAULT_DATASET_REVISION
if isinstance(paths, str):
paths = [paths]
paths_set = set(paths)
- # Search within any cached tree data (recursive root is the most comprehensive)
- root_key = (repo_id, revision, '/', True)
- root_cached = _repo_tree_cache.get(root_key)
- if root_cached is not None:
- matched = [item for item in root_cached if item.path in paths_set]
+    # Answer from cache only when the result is complete for every requested path.
+    for cache_key, cached_items in _repo_tree_cache.items():
+        if cache_key[0] != repo_id or cache_key[1] != revision:
+            continue
+        matched = [item for item in cached_items if item.path in paths_set]
+        # The recursive root listing is authoritative: a path it lacks does not exist.
+        # Any other listing is partial, so it may only answer if it covers all paths.
+        authoritative = cache_key == (repo_id, revision, '/', True)
+        if authoritative or len({item.path for item in matched}) == len(paths_set):
+            return matched
repo_info_iter = self.list_repo_tree(
repo_id=repo_id,
@@ -451,598 +403,62 @@ def _get_paths_info(
token=token,
)
- return [item_info for item_info in repo_info_iter]
+ return [item for item in repo_info_iter if item.path in paths_set]
-def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str):
- _api = HubApi()
- _namespace, _dataset_name = repo_id.split('/')
- endpoint = _api.get_endpoint_for_read(
- repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
- if download_config and download_config.download_desc is None:
- download_config.download_desc = f'Downloading [{path_in_repo}]'
- try:
- url_or_filename = _api.get_dataset_file_url(
- file_name=path_in_repo,
- dataset_name=_dataset_name,
- namespace=_namespace,
- revision=revision,
- extension_filter=False,
- endpoint=endpoint
- )
- repo_file_path = cached_path(
- url_or_filename=url_or_filename, download_config=download_config)
- except FileNotFoundError as e:
- repo_file_path = ''
- logger.error(e)
+# ===================================================================
+# HfFileSystem patch (_hf_fs_open)
+# ===================================================================
- return repo_file_path
+_hf_fs_open_original = None
-def get_fs_token_paths(
- urlpath,
- storage_options=None,
- protocol=None,
-):
- if isinstance(urlpath, (list, tuple, set)):
- if not urlpath:
- raise ValueError('empty urlpath sequence')
- urlpath0 = stringify_path(list(urlpath)[0])
- else:
- urlpath0 = stringify_path(urlpath)
- storage_options = storage_options or {}
- if protocol:
- storage_options['protocol'] = protocol
- chain = _un_chain(urlpath0, storage_options or {})
- inkwargs = {}
- # Reverse iterate the chain, creating a nested target_* structure
- for i, ch in enumerate(reversed(chain)):
- urls, nested_protocol, kw = ch
- if i == len(chain) - 1:
- inkwargs = dict(**kw, **inkwargs)
- continue
- inkwargs['target_options'] = dict(**kw, **inkwargs)
- inkwargs['target_protocol'] = nested_protocol
- inkwargs['fo'] = urls
- paths, protocol, _ = chain[0]
- fs = filesystem(protocol, **inkwargs)
+def _hf_fs_open(self, path, mode='rb', **kwargs):
+ """Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API.
- return fs
-
-
-def _resolve_pattern(
- pattern: str,
- base_path: str,
- allowed_extensions: Optional[List[str]] = None,
- download_config: Optional[DownloadConfig] = None,
-) -> List[str]:
+ The ModelScope tree API may report Size=0 for files. When HfFileSystem
+ caches this, AbstractBufferedFile treats the file as empty (0 bytes).
+ This wrapper detects size=0 for files opened in read mode and resolves
+ the actual size via a HEAD request before creating the file object.
"""
- Resolve the paths and URLs of the data files from the pattern passed by the user.
-
- You can use patterns to resolve multiple local files. Here are a few examples:
- - *.csv to match all the CSV files at the first level
- - **.csv to match all the CSV files at any level
- - data/* to match all the files inside "data"
- - data/** to match all the files inside "data" and its subdirectories
-
- The patterns are resolved using the fsspec glob.
-
- glob.glob, Path.glob, Path.match or fnmatch do not support ** with a prefix/suffix other than a forward slash /.
- For instance, this means **.json is the same as *.json. On the contrary, the fsspec glob has no limits regarding the ** prefix/suffix, # noqa: E501
- resulting in **.json being equivalent to **/*.json.
-
- More generally:
- - '*' matches any character except a forward-slash (to match just the file or directory name)
- - '**' matches any character including a forward-slash /
-
- Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested.
- The same applies to special directories that start with a double underscore like "__pycache__".
- You can still include one if the pattern explicitly mentions it:
- - to include a hidden file: "*/.hidden.txt" or "*/.*"
- - to include a hidden directory: ".hidden/*" or ".*/*"
- - to include a special directory: "__special__/*" or "__*/*"
-
- Example::
-
- >>> from datasets.data_files import resolve_pattern
- >>> base_path = "."
- >>> resolve_pattern("docs/**/*.py", base_path)
- [/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py']
-
- Args:
- pattern (str): Unix pattern or paths or URLs of the data files to resolve.
- The paths can be absolute or relative to base_path.
- Remote filesystems using fsspec are supported, e.g. with the hf:// protocol.
- base_path (str): Base path to use when resolving relative paths.
- allowed_extensions (Optional[list], optional): White-list of file extensions to use. Defaults to None (all extensions).
- For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"]
- Returns:
- List[str]: List of paths or URLs to the local or remote files that match the patterns.
- """
- if is_relative_path(pattern):
- pattern = xjoin(base_path, pattern)
- elif is_local_path(pattern):
- base_path = os.path.splitdrive(pattern)[0] + os.sep
- else:
- base_path = ''
- # storage_options: {'hf': {'token': None, 'endpoint': 'https://huggingface.co'}}
- pattern, storage_options = _prepare_path_and_storage_options(
- pattern, download_config=download_config)
- fs = get_fs_token_paths(pattern, storage_options=storage_options)
- fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker
- fs_pattern = pattern.split('::')[0].split('://')[-1]
- files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
- protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
- protocol_prefix = protocol + '://' if protocol != 'file' else ''
- glob_kwargs = {}
- if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'):
- # 10 times faster glob with detail=True (ignores costly info like lastCommit)
- glob_kwargs['expand_info'] = False
-
- try:
- tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
- except FileNotFoundError:
- raise DataFilesNotFoundError(f"Unable to find '{pattern}'")
-
- matched_paths = [
- filepath if filepath.startswith(protocol_prefix) else protocol_prefix
- + filepath for filepath, info in tmp_file_paths.items()
- if info['type'] == 'file' and (
- xbasename(filepath) not in files_to_ignore)
- and not _is_inside_unrequested_special_dir(
- os.path.relpath(filepath, fs_base_path),
- os.path.relpath(fs_pattern, fs_base_path)) and # noqa: W504
- not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir( # noqa: W504
- os.path.relpath(filepath, fs_base_path),
- os.path.relpath(fs_pattern, fs_base_path))
- ] # ignore .ipynb and __pycache__, but keep /../
- if allowed_extensions is not None:
- out = [
- filepath for filepath in matched_paths
- if any('.' + suffix in allowed_extensions
- for suffix in xbasename(filepath).split('.')[1:])
- ]
- if len(out) < len(matched_paths):
- invalid_matched_files = list(set(matched_paths) - set(out))
- logger.info(
- f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: "
- f'{invalid_matched_files}')
- else:
- out = matched_paths
- if not out:
- error_msg = f"Unable to find '{pattern}'"
- if allowed_extensions is not None:
- error_msg += f' with any supported extension {list(allowed_extensions)}'
- raise FileNotFoundError(error_msg)
- return out
-
-
-def _get_data_patterns(
- base_path: str,
- download_config: Optional[DownloadConfig] = None) -> Dict[str,
- List[str]]:
- """
- Get the default pattern from a directory testing all the supported patterns.
- The first patterns to return a non-empty list of data files is returned.
-
- Some examples of supported patterns:
-
- Input:
-
- my_dataset_repository/
- ├── README.md
- └── dataset.csv
-
- Output:
-
- {"train": ["**"]}
-
- Input:
-
- my_dataset_repository/
- ├── README.md
- ├── train.csv
- └── test.csv
-
- my_dataset_repository/
- ├── README.md
- └── data/
- ├── train.csv
- └── test.csv
-
- my_dataset_repository/
- ├── README.md
- ├── train_0.csv
- ├── train_1.csv
- ├── train_2.csv
- ├── train_3.csv
- ├── test_0.csv
- └── test_1.csv
-
- Output:
-
- {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
- 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
- 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
- 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
-
- Input:
-
- my_dataset_repository/
- ├── README.md
- └── data/
- ├── train/
- │ ├── shard_0.csv
- │ ├── shard_1.csv
- │ ├── shard_2.csv
- │ └── shard_3.csv
- └── test/
- ├── shard_0.csv
- └── shard_1.csv
-
- Output:
-
- {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**',
- 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'],
- 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**',
- 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]}
-
- Input:
-
- my_dataset_repository/
- ├── README.md
- └── data/
- ├── train-00000-of-00003.csv
- ├── train-00001-of-00003.csv
- ├── train-00002-of-00003.csv
- ├── test-00000-of-00001.csv
- ├── random-00000-of-00003.csv
- ├── random-00001-of-00003.csv
- └── random-00002-of-00003.csv
-
- Output:
-
- {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
- 'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'],
- 'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']}
-
- In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS.
- """
- resolver = partial(
- _resolve_pattern, base_path=base_path, download_config=download_config)
- try:
- return _get_data_files_patterns(resolver)
- except FileNotFoundError:
- raise EmptyDatasetError(
- f"The directory at {base_path} doesn't contain any data files"
- ) from None
-
-
-def get_module_without_script(self) -> DatasetModule:
-
- # hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info(
- # self.name,
- # revision=self.revision,
- # token=self.download_config.token,
- # timeout=100.0,
- # )
- # even if metadata_configs is not None (which means that we will resolve files for each config later)
- # we cannot skip resolving all files because we need to infer module name by files extensions
- # revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime
- revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
- base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
- '/')
-
- repo_id: str = self.name
- download_config = self.download_config.copy()
-
- dataset_readme_path = _download_repo_file(
- repo_id=repo_id,
- path_in_repo='README.md',
- download_config=download_config,
- revision=revision)
-
- dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data if dataset_readme_path else DatasetCardData()
- subset_name: str = download_config.storage_options.get('name', None)
-
- metadata_configs = MetadataConfigs.from_dataset_card_data(
- dataset_card_data)
- dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)
- # we need a set of data files to find which dataset builder to use
- # because we need to infer module name by files extensions
- if self.data_files is not None:
- patterns = sanitize_patterns(self.data_files)
- elif metadata_configs and 'data_files' in next(
- iter(metadata_configs.values())):
-
- if subset_name is not None:
- subset_data_files = metadata_configs[subset_name]['data_files']
- else:
- subset_data_files = next(iter(metadata_configs.values()))['data_files']
- patterns = sanitize_patterns(subset_data_files)
- else:
- patterns = _get_data_patterns(
- base_path, download_config=self.download_config)
-
- data_files = DataFilesDict.from_patterns(
- patterns,
- base_path=base_path,
- allowed_extensions=ALL_ALLOWED_EXTENSIONS,
- download_config=self.download_config,
- )
- module_name, default_builder_kwargs = infer_module_for_data_files(
- data_files=data_files,
- path=self.name,
- download_config=self.download_config,
- )
-
- if hasattr(data_files, 'filter'):
- data_files = data_files.filter(extensions=_MODULE_TO_EXTENSIONS[module_name])
- else:
- data_files = data_files.filter_extensions(_MODULE_TO_EXTENSIONS[module_name])
-
- module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]
-
- if metadata_configs:
-
- supports_metadata = module_name in {'imagefolder', 'audiofolder'}
- create_builder_signature = inspect.signature(create_builder_configs_from_metadata_configs)
- in_args = {
- 'module_path': module_path,
- 'metadata_configs': metadata_configs,
- 'base_path': base_path,
- 'default_builder_kwargs': default_builder_kwargs,
- 'download_config': self.download_config,
- }
- if 'supports_metadata' in create_builder_signature.parameters:
- in_args['supports_metadata'] = supports_metadata
-
- builder_configs, default_config_name = create_builder_configs_from_metadata_configs(**in_args)
- else:
- builder_configs: List[BuilderConfig] = [
- import_main_class(module_path).BUILDER_CONFIG_CLASS(
- data_files=data_files,
- **default_builder_kwargs,
- )
- ]
- default_config_name = None
- _api = HubApi()
- endpoint = _api.get_endpoint_for_read(
- repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
-
- builder_kwargs = {
- # "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"),
- 'base_path':
- HubApi().get_file_base_path(repo_id=repo_id, endpoint=endpoint),
- 'repo_id':
- self.name,
- 'dataset_name':
- camelcase_to_snakecase(Path(self.name).name),
- 'data_files': data_files,
- }
- download_config = self.download_config.copy()
- if download_config.download_desc is None:
- download_config.download_desc = 'Downloading metadata'
-
- # Note: `dataset_infos.json` is deprecated and can cause an error during loading if it exists
-
- if default_config_name is None and len(dataset_infos) == 1:
- default_config_name = next(iter(dataset_infos))
-
- hash = revision
- return DatasetModule(
- module_path,
- hash,
- builder_kwargs,
- dataset_infos=dataset_infos,
- builder_configs_parameters=BuilderConfigsParameters(
- metadata_configs=metadata_configs,
- builder_configs=builder_configs,
- default_config_name=default_config_name,
- ),
- )
-
-
-def _download_additional_modules(
- name: str,
- dataset_name: str,
- namespace: str,
- revision: str,
- imports: Tuple[str, str, str, str],
- download_config: Optional[DownloadConfig],
- trust_remote_code: Optional[bool] = False,
-) -> List[Tuple[str, str]]:
- """
- Download additional module for a module .py at URL (or local path) /.py
- The imports must have been parsed first using ``get_imports``.
-
- If some modules need to be installed with pip, an error is raised showing how to install them.
- This function return the list of downloaded modules as tuples (import_name, module_file_path).
-
- The downloaded modules can then be moved into an importable directory
- with ``_copy_script_and_other_resources_in_importable_dir``.
- """
- local_imports = []
- library_imports = []
-
- # Check if we need to execute remote code
- has_remote_code = any(
- import_type in ('internal', 'external')
- for import_type, _, _, _ in imports
- )
-
- if has_remote_code and not trust_remote_code:
- raise ValueError(
- f'Loading {name} requires executing code from the repository. '
- 'This is disabled by default for security reasons. '
- 'If you trust the authors of this dataset, you can enable it with '
- '`trust_remote_code=True`.'
- )
-
- download_config = download_config.copy()
- if download_config.download_desc is None:
- download_config.download_desc = 'Downloading extra modules'
- for import_type, import_name, import_path, sub_directory in imports:
- if import_type == 'library':
- library_imports.append((import_name, import_path)) # Import from a library
- continue
-
- if import_name == name:
- raise ValueError(
- f'Error in the {name} script, importing relative {import_name} module '
- f'but {import_name} is the name of the script. '
- f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
- f'comment pointing to the original relative import file path.'
- )
- if import_type == 'internal':
- _api = HubApi()
- # url_or_filename = url_or_path_join(base_path, import_path + ".py")
- file_name = import_path + '.py'
- url_or_filename = _api.get_dataset_file_url(file_name=file_name,
- dataset_name=dataset_name,
- namespace=namespace,
- revision=revision,)
- elif import_type == 'external':
- url_or_filename = import_path
- else:
- raise ValueError('Wrong import_type')
-
- local_import_path = cached_path(
- url_or_filename,
- download_config=download_config,
- )
- if sub_directory is not None:
- local_import_path = os.path.join(local_import_path, sub_directory)
- local_imports.append((import_name, local_import_path))
-
- # Check library imports
- needs_to_be_installed = {}
- for library_import_name, library_import_path in library_imports:
+ if mode == 'rb' and 'size' not in kwargs:
try:
- lib = importlib.import_module(library_import_name) # noqa F841
- except ImportError:
- if library_import_name not in needs_to_be_installed or library_import_path != library_import_name:
- needs_to_be_installed[library_import_name] = library_import_path
- if needs_to_be_installed:
- _dependencies_str = 'dependencies' if len(needs_to_be_installed) > 1 else 'dependency'
- _them_str = 'them' if len(needs_to_be_installed) > 1 else 'it'
- if 'sklearn' in needs_to_be_installed.keys():
- needs_to_be_installed['sklearn'] = 'scikit-learn'
- if 'Bio' in needs_to_be_installed.keys():
- needs_to_be_installed['Bio'] = 'biopython'
- raise ImportError(
- f'To be able to use {name}, you need to install the following {_dependencies_str}: '
- f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install "
- f"{' '.join(needs_to_be_installed.values())}' for instance."
- )
- return local_imports
+ resolved = self.resolve_path(path)
+ resolved_name = resolved.unresolve()
+ parent = self._parent(resolved_name)
+ cached_size = None
+ if parent in self.dircache:
+ for entry in self.dircache[parent]:
+ if entry['name'] == resolved_name and entry.get('type') == 'file':
+ cached_size = entry.get('size', -1)
+ break
+ if cached_size == 0:
+ url = hf_hub_url(
+ repo_id=resolved.repo_id,
+ revision=resolved.revision,
+ filename=resolved.path_in_repo,
+ repo_type=resolved.repo_type,
+ endpoint=self.endpoint,
+ )
+ headers = self._api._build_hf_headers()
+ resp = requests.head(url, headers=headers, allow_redirects=True, timeout=30)
+ if resp.status_code == 200:
+ cl = resp.headers.get('Content-Length')
+ if cl:
+ actual_size = int(cl)
+ kwargs['size'] = actual_size
+ for entry in self.dircache.get(parent, []):
+ if entry['name'] == resolved_name:
+ entry['size'] = actual_size
+ break
+ except Exception:
+ pass
+ return _hf_fs_open_original(self, path, mode=mode, **kwargs)
-def get_module_with_script(self) -> DatasetModule:
- if not _HAS_SCRIPT_LOADING:
- raise RuntimeError(
- 'Script-based dataset loading is not supported with datasets>=4.0. '
- 'Please convert the dataset to a script-free format (e.g. Parquet).')
-
- repo_id: str = self.name
- _namespace, _dataset_name = repo_id.split('/')
- revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION
-
- script_file_name = f'{_dataset_name}.py'
- local_script_path = _download_repo_file(
- repo_id=repo_id,
- path_in_repo=script_file_name,
- download_config=self.download_config,
- revision=revision,
- )
- if not local_script_path:
- raise FileNotFoundError(
- f'Cannot find {script_file_name} in {repo_id} at revision {revision}. '
- f'Please create {script_file_name} in the repo.'
- )
-
- dataset_infos_path = None
- # try:
- # dataset_infos_url: str = _api.get_dataset_file_url(
- # file_name='dataset_infos.json',
- # dataset_name=_dataset_name,
- # namespace=_namespace,
- # revision=self.revision,
- # extension_filter=False,
- # )
- # dataset_infos_path = cached_path(
- # url_or_filename=dataset_infos_url, download_config=self.download_config)
- # except Exception as e:
- # logger.info(f'Cannot find dataset_infos.json: {e}')
- # dataset_infos_path = None
-
- dataset_readme_path = _download_repo_file(
- repo_id=repo_id,
- path_in_repo='README.md',
- download_config=self.download_config,
- revision=revision
- )
-
- imports = get_imports(local_script_path)
- local_imports = _download_additional_modules(
- name=repo_id,
- dataset_name=_dataset_name,
- namespace=_namespace,
- revision=revision,
- imports=imports,
- download_config=self.download_config,
- trust_remote_code=self.trust_remote_code,
- )
- additional_files = []
- if dataset_infos_path:
- additional_files.append((config.DATASETDICT_INFOS_FILENAME, dataset_infos_path))
- if dataset_readme_path:
- additional_files.append((config.REPOCARD_FILENAME, dataset_readme_path))
- # copy the script and the files in an importable directory
- dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules()
- hash = files_to_hash([local_script_path] + [loc[1] for loc in local_imports])
- importable_file_path = _get_importable_file_path(
- dynamic_modules_path=dynamic_modules_path,
- module_namespace='datasets',
- subdirectory_name=hash,
- name=repo_id,
- )
- if not os.path.exists(importable_file_path):
- trust_remote_code = resolve_trust_remote_code(trust_remote_code=self.trust_remote_code, repo_id=self.name)
- if trust_remote_code:
- logger.warning(f'Use trust_remote_code=True. Will invoke codes from {repo_id}. Please make sure that '
- 'you can trust the external codes.')
- _create_importable_file(
- local_path=local_script_path,
- local_imports=local_imports,
- additional_files=additional_files,
- dynamic_modules_path=dynamic_modules_path,
- module_namespace='datasets',
- subdirectory_name=hash,
- name=repo_id,
- download_mode=self.download_mode,
- )
- else:
- raise ValueError(
- f'Loading {repo_id} requires you to execute the dataset script in that'
- ' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
- ' set the option `trust_remote_code=True` to remove this error.'
- )
- module_path, hash = _load_importable_file(
- dynamic_modules_path=dynamic_modules_path,
- module_namespace='datasets',
- subdirectory_name=hash,
- name=repo_id,
- )
- # make the new module to be noticed by the import system
- importlib.invalidate_caches()
- builder_kwargs = {
- # "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"),
- 'base_path': HubApi().get_file_base_path(repo_id=repo_id),
- 'repo_id': repo_id,
- }
- return DatasetModule(module_path, hash, builder_kwargs)
-
+# ===================================================================
+# DatasetsWrapperHF
+# ===================================================================
class DatasetsWrapperHF:
@@ -1093,8 +509,7 @@ class DatasetsWrapperHF:
raise ValueError(
f"Empty 'data_files': '{data_files}'. It should be either non-empty or None (default)."
)
- if Path(path, config.DATASET_STATE_JSON_FILENAME).exists(
- ):
+ if Path(path, config.DATASET_STATE_JSON_FILENAME).exists():
raise ValueError(
'You are trying to load a dataset that was saved using `save_to_disk`. '
'Please use `load_from_disk` instead.')
@@ -1112,14 +527,9 @@ class DatasetsWrapperHF:
) if not save_infos else VerificationMode.ALL_CHECKS)
if trust_remote_code:
- if not _HAS_SCRIPT_LOADING:
- logger.warning('trust_remote_code is ignored: script-based dataset loading '
- 'is no longer supported with datasets>=4.0.')
- else:
- logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
- 'that you can trust the external codes.')
+ logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
+ 'that you can trust the external codes.')
- # Create a dataset builder
builder_instance = DatasetsWrapperHF.load_dataset_builder(
path=path,
name=name,
@@ -1132,15 +542,13 @@ class DatasetsWrapperHF:
revision=revision,
token=token,
storage_options=storage_options,
- trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None,
+ trust_remote_code=trust_remote_code,
_require_default_config_name=name is None,
**config_kwargs,
)
- # Note: Only for preview mode
if dataset_info_only:
ret_dict = {}
- # Get dataset config info from python script
if isinstance(path, str) and path.endswith('.py') and os.path.exists(path):
from datasets import get_dataset_config_names
subset_list = get_dataset_config_names(path)
@@ -1161,26 +569,17 @@ class DatasetsWrapperHF:
ret_dict[tmp_config_name] = []
return ret_dict
- # Return iterable dataset in case of streaming
if streaming:
return builder_instance.as_streaming_dataset(split=split)
- # Some datasets are already processed on the HF google storage
- # Don't try downloading from Google storage for the packaged datasets as text, json, csv or pandas
- # try_from_hf_gcs = path not in _PACKAGED_DATASETS_MODULES
-
- # Download and prepare data
builder_instance.download_and_prepare(
download_config=download_config,
download_mode=download_mode,
verification_mode=verification_mode,
num_proc=num_proc,
storage_options=storage_options,
- # base_path=builder_instance.base_path,
- # file_format=builder_instance.name or 'arrow',
)
- # Build dataset for splits
keep_in_memory = (
keep_in_memory if keep_in_memory is not None else is_small_dataset(
builder_instance.info.dataset_size))
@@ -1188,9 +587,7 @@ class DatasetsWrapperHF:
split=split,
verification_mode=verification_mode,
in_memory=keep_in_memory)
- # Rename and cast features to match task schema
if task is not None:
- # To avoid issuing the same warning twice
with warnings.catch_warnings():
warnings.simplefilter('ignore', FutureWarning)
ds = ds.prepare_for_task(task)
@@ -1198,13 +595,12 @@ class DatasetsWrapperHF:
builder_instance._save_infos()
try:
- _api = HubApi()
-
+ api = _get_hub_api()
if is_relative_path(path) and path.count('/') == 1:
_namespace, _dataset_name = path.split('/')
- endpoint = _api.get_endpoint_for_read(
+ endpoint = api.get_endpoint_for_read(
repo_id=path, repo_type=REPO_TYPE_DATASET)
- _api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint)
+ api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint)
except Exception as e:
logger.warning(f'Could not record download statistics: {e}')
@@ -1249,14 +645,6 @@ class DatasetsWrapperHF:
) if download_config else DownloadConfig()
download_config.storage_options.update(storage_options)
- if trust_remote_code:
- if not _HAS_SCRIPT_LOADING:
- logger.warning('trust_remote_code is ignored: script-based dataset loading '
- 'is no longer supported with datasets>=4.0.')
- else:
- logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
- 'that you can trust the external codes.')
-
dataset_module = DatasetsWrapperHF.dataset_module_factory(
path,
revision=revision,
@@ -1265,12 +653,11 @@ class DatasetsWrapperHF:
data_dir=data_dir,
data_files=data_files,
cache_dir=cache_dir,
- trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None,
+ trust_remote_code=trust_remote_code,
_require_default_config_name=_require_default_config_name,
_require_custom_configs=bool(config_kwargs),
name=name,
)
- # Get dataset builder class from the processing script
builder_kwargs = dataset_module.builder_kwargs
data_dir = builder_kwargs.pop('data_dir', data_dir)
data_files = builder_kwargs.pop('data_files', data_files)
@@ -1307,7 +694,7 @@ class DatasetsWrapperHF:
features=features,
token=token,
storage_options=storage_options,
- **builder_kwargs, # contains base_path
+ **builder_kwargs,
**config_kwargs,
)
builder_instance._use_legacy_cache_dir_if_possible(dataset_module)
@@ -1353,28 +740,6 @@ class DatasetsWrapperHF:
filename = filename + '.py'
combined_path = os.path.join(path, filename)
- # We have several ways to get a dataset builder:
- #
- # - if path is the name of a packaged dataset module
- # -> use the packaged module (json, csv, etc.)
- #
- # - if os.path.join(path, name) is a local python file
- # -> use the module from the python file
- # - if path is a local directory (but no python file)
- # -> use a packaged module (csv, text etc.) based on content of the directory
- #
- # - if path has one "/" and is dataset repository on the HF hub with a python file
- # -> the module from the python file in the dataset repository
- # - if path has one "/" and is dataset repository on the HF hub without a python file
- # -> use a packaged module (csv, text etc.) based on content of the repository
- if trust_remote_code:
- if not _HAS_SCRIPT_LOADING:
- logger.warning('trust_remote_code is ignored: script-based dataset loading '
- 'is no longer supported with datasets>=4.0.')
- else:
- logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
- 'that you can trust the external codes.')
-
# Try packaged
if path in _PACKAGED_DATASETS_MODULES:
return PackagedDatasetModuleFactory(
@@ -1384,34 +749,40 @@ class DatasetsWrapperHF:
download_config=download_config,
download_mode=download_mode,
).get_module()
- # Try locally with script (requires datasets <4.0)
+ # Try locally with script
elif path.endswith(filename):
if os.path.isfile(path):
- if not _HAS_SCRIPT_LOADING:
- raise RuntimeError(
- f'Script-based dataset loading ({path}) is not supported with datasets>=4.0. '
- 'Please convert the dataset to a script-free format (e.g. Parquet).')
- return LocalDatasetModuleFactoryWithScript(
+ if _HAS_SCRIPT_LOADING:
+ return LocalDatasetModuleFactoryWithScript(
+ path,
+ download_mode=download_mode,
+ dynamic_modules_path=dynamic_modules_path,
+ trust_remote_code=trust_remote_code,
+ ).get_module()
+ return _compat_local_script_module(
path,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
- ).get_module()
+ )
else:
raise FileNotFoundError(
f"Couldn't find a dataset script at {relative_to_absolute_path(path)}"
)
elif os.path.isfile(combined_path):
- if not _HAS_SCRIPT_LOADING:
- raise RuntimeError(
- f'Script-based dataset loading ({combined_path}) is not supported with datasets>=4.0. '
- 'Please convert the dataset to a script-free format (e.g. Parquet).')
- return LocalDatasetModuleFactoryWithScript(
+ if _HAS_SCRIPT_LOADING:
+ return LocalDatasetModuleFactoryWithScript(
+ combined_path,
+ download_mode=download_mode,
+ dynamic_modules_path=dynamic_modules_path,
+ trust_remote_code=trust_remote_code,
+ ).get_module()
+ return _compat_local_script_module(
combined_path,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
- ).get_module()
+ )
elif os.path.isdir(path):
return LocalDatasetModuleFactoryWithoutScript(
path,
@@ -1430,15 +801,14 @@ class DatasetsWrapperHF:
token=download_config.token,
timeout=100.0,
)
- except Exception as e: # noqa catch any exception of hf_hub and consider that the dataset doesn't exist
+ except Exception as e: # noqa: broad exception from hf_hub
if isinstance(
e,
( # noqa: E131
- OfflineModeIsEnabled, # noqa: E131
- requests.exceptions.
- ConnectTimeout, # noqa: E131, E261
- requests.exceptions.ConnectionError, # noqa: E131
- ), # noqa: E131
+ OfflineModeIsEnabled,
+ requests.exceptions.ConnectTimeout,
+ requests.exceptions.ConnectionError,
+ ),
):
raise ConnectionError(
f"Couldn't reach '{path}' on the Hub ({type(e).__name__})"
@@ -1469,15 +839,9 @@ class DatasetsWrapperHF:
if filename in [
sibling.rfilename for sibling in dataset_info.siblings
- ]: # contains a dataset script
-
- # TODO
+ ]:
can_load_config_from_parquet_export = False
if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
- # If the parquet export is ready (parquet files + info available for the current sha),
- # we can use it instead
- # This fails when the dataset has multiple configs and a default config and
- # the user didn't specify a configuration name (_require_default_config_name=True).
try:
if has_attr_in_class(HubDatasetModuleFactoryWithParquetExport, 'revision'):
return HubDatasetModuleFactoryWithParquetExport(
@@ -1492,35 +856,35 @@ class DatasetsWrapperHF:
except Exception as e:
logger.error(e)
- # Otherwise we must use the dataset script if the user trusts it.
- # Script-based loading was removed in datasets 4.0.
- if not _HAS_SCRIPT_LOADING:
- raise RuntimeError(
- f"Dataset '{path}' contains a loading script but script-based dataset loading "
- 'is not supported with datasets>=4.0. Please convert the dataset to a '
- 'script-free format (e.g. Parquet).')
+ if _HAS_SCRIPT_LOADING:
+ if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
+ return HubDatasetModuleFactoryWithScript(
+ path,
+ revision=revision,
+ download_config=download_config,
+ download_mode=download_mode,
+ dynamic_modules_path=dynamic_modules_path,
+ trust_remote_code=trust_remote_code,
+ ).get_module()
- # To be adapted to the old version of datasets
- if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
return HubDatasetModuleFactoryWithScript(
path,
- revision=revision,
+ commit_hash=commit_hash,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
).get_module()
- return HubDatasetModuleFactoryWithScript(
+ return _compat_hub_script_module(
path,
- commit_hash=commit_hash,
+ revision=revision,
download_config=download_config,
download_mode=download_mode,
dynamic_modules_path=dynamic_modules_path,
trust_remote_code=trust_remote_code,
- ).get_module()
+ )
else:
- # To be adapted to the old version of datasets
if has_attr_in_class(HubDatasetModuleFactoryWithoutScript, 'revision'):
return HubDatasetModuleFactoryWithoutScript(
path,
@@ -1540,18 +904,15 @@ class DatasetsWrapperHF:
download_mode=download_mode,
).get_module()
except Exception as e1:
- # All the attempts failed, before raising the error we should check if the module is already cached
logger.error(f'>> Error loading {path}: {e1}')
try:
- # dynamic_modules_path was removed in datasets 4.0
_cached_factory_kwargs = {'cache_dir': cache_dir}
if _HAS_SCRIPT_LOADING:
_cached_factory_kwargs['dynamic_modules_path'] = dynamic_modules_path
return CachedDatasetModuleFactory(
path, **_cached_factory_kwargs).get_module()
except Exception:
- # If it's not in the cache, then it doesn't exist.
if isinstance(e1, OfflineModeIsEnabled):
raise ConnectionError(
f"Couldn't reach the Hugging Face Hub for dataset '{path}': {e1}"
@@ -1573,79 +934,38 @@ class DatasetsWrapperHF:
f'any data file in the same directory.')
-_hf_fs_open_original = None
-
-
-def _hf_fs_open(self, path, mode='rb', **kwargs):
- """Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API.
-
- The ModelScope tree API may report Size=0 for files. When HfFileSystem
- caches this, AbstractBufferedFile treats the file as empty (0 bytes).
- This wrapper detects size=0 for files opened in read mode and resolves
- the actual size via a HEAD request before creating the file object.
- """
- if mode == 'rb' and 'size' not in kwargs:
- try:
- resolved = self.resolve_path(path)
- resolved_name = resolved.unresolve()
- parent = self._parent(resolved_name)
- cached_size = None
- if parent in self.dircache:
- for entry in self.dircache[parent]:
- if entry['name'] == resolved_name and entry.get('type') == 'file':
- cached_size = entry.get('size', -1)
- break
- if cached_size == 0:
- url = hf_hub_url(
- repo_id=resolved.repo_id,
- revision=resolved.revision,
- filename=resolved.path_in_repo,
- repo_type=resolved.repo_type,
- endpoint=self.endpoint,
- )
- headers = self._api._build_hf_headers()
- resp = requests.head(url, headers=headers, allow_redirects=True, timeout=30)
- if resp.status_code == 200:
- cl = resp.headers.get('Content-Length')
- if cl:
- actual_size = int(cl)
- kwargs['size'] = actual_size
- for entry in self.dircache.get(parent, []):
- if entry['name'] == resolved_name:
- entry['size'] = actual_size
- break
- except Exception:
- pass
- return _hf_fs_open_original(self, path, mode=mode, **kwargs)
-
+# ===================================================================
+# Context manager – load_dataset_with_ctx
+# ===================================================================
@contextlib.contextmanager
def load_dataset_with_ctx(*args, **kwargs):
+ """Context manager that monkey-patches ``datasets`` to use ModelScope.
+
+ All monkey-patches are applied on entry and restored on exit (for
+ non-streaming mode) or kept alive (for streaming mode, where lazy
+ iteration needs the patches to remain active).
+ """
global _hf_fs_open_original
- # Keep the original functions
+ # Save originals
hf_endpoint_origin = config.HF_ENDPOINT
get_from_cache_origin = file_utils.get_from_cache
-
- # Compatible with datasets 2.18.0
_download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
else DownloadManager._download_single
-
dataset_info_origin = HfApi.dataset_info
list_repo_tree_origin = HfApi.list_repo_tree
get_paths_info_origin = HfApi.get_paths_info
resolve_pattern_origin = data_files.resolve_pattern
get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
- # Script-based loading was removed in datasets 4.0
get_module_with_script_origin = (
HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None)
generate_from_dict_origin = features.generate_from_dict
hf_fs_open_origin = HfFileSystem._open
- # Monkey patching with modelscope functions
+ # Apply patches
config.HF_ENDPOINT = get_endpoint()
file_utils.get_from_cache = get_from_cache_ms
- # Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
DownloadManager._download = _download_ms
else:
@@ -1678,7 +998,6 @@ def load_dataset_with_ctx(*args, **kwargs):
file_utils.get_from_cache = get_from_cache_origin
features.generate_from_dict = generate_from_dict_origin
- # Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
DownloadManager._download = _download_origin
else:
diff --git a/requirements/datasets.txt b/requirements/datasets.txt
index 21337757..611ec0d4 100644
--- a/requirements/datasets.txt
+++ b/requirements/datasets.txt
@@ -1,6 +1,6 @@
addict
attrs
-datasets>=4.0.0,<=4.6.1
+datasets>=4.0.0,<=4.8.4
einops
oss2
Pillow
diff --git a/requirements/framework.txt b/requirements/framework.txt
index b0c255dc..e619b305 100644
--- a/requirements/framework.txt
+++ b/requirements/framework.txt
@@ -1,6 +1,6 @@
addict
attrs
-datasets>=4.0.0,<=4.6.1
+datasets>=4.0.0,<=4.8.4
einops
Pillow
python-dateutil>=2.1