diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a8565f16..5e40ec55 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -20,7 +20,8 @@ repos: (?x)^( examples/| modelscope/utils/ast_index_file.py| - modelscope/fileio/format/jsonplus.py + modelscope/fileio/format/jsonplus.py| + modelscope/msdatasets/utils/_module_factories\.py )$ - repo: https://github.com/pre-commit/mirrors-yapf.git rev: v0.30.0 @@ -31,7 +32,8 @@ repos: thirdparty/| examples/| modelscope/utils/ast_index_file.py| - modelscope/fileio/format/jsonplus.py + modelscope/fileio/format/jsonplus.py| + modelscope/msdatasets/utils/_module_factories\.py )$ - repo: https://github.com/pre-commit/pre-commit-hooks.git rev: v3.1.0 diff --git a/modelscope/msdatasets/utils/_compat.py b/modelscope/msdatasets/utils/_compat.py new file mode 100644 index 00000000..05ec9e9b --- /dev/null +++ b/modelscope/msdatasets/utils/_compat.py @@ -0,0 +1,288 @@ +# isort: skip_file +# yapf: disable +# Copyright (c) Alibaba, Inc. and its affiliates. +"""Compatibility shims for datasets>=4.0 script-based dataset loading. + +Script-based dataset loading was removed in datasets 4.0. This module +provides minimal re-implementations of the necessary helpers so that +ModelScope can still load datasets that ship a custom builder .py script. + +When running with datasets<4.0 the real implementations are simply +re-exported from datasets.load / datasets.utils.py_utils. 
+""" +import importlib +import os +import sys +from pathlib import Path +from typing import List, Optional, Tuple + +from datasets import DownloadMode, config + +# --------------------------------------------------------------------------- +# Try importing script-loading APIs from datasets<4.0 +# --------------------------------------------------------------------------- +try: + from datasets.load import ( + HubDatasetModuleFactoryWithScript, + LocalDatasetModuleFactoryWithScript, + resolve_trust_remote_code, + _get_importable_file_path, + _create_importable_file, + _load_importable_file, + init_dynamic_modules, + files_to_hash, + ) + from datasets.utils.py_utils import get_imports + + _HAS_SCRIPT_LOADING = True +except ImportError: + _HAS_SCRIPT_LOADING = False + HubDatasetModuleFactoryWithScript = None # type: ignore[assignment,misc] + LocalDatasetModuleFactoryWithScript = None # type: ignore[assignment,misc] + +# --------------------------------------------------------------------------- +# Compat implementations (only defined when datasets>=4.0) +# --------------------------------------------------------------------------- +if not _HAS_SCRIPT_LOADING: + import filecmp + import hashlib # noqa: F811 – only imported in this branch + import json as _json + import re + import shutil + from urllib.parse import urlparse + + from datasets.packaged_modules import _hash_python_lines + from datasets.utils.file_utils import url_or_path_join + from datasets.utils.hub import hf_dataset_url # noqa: F401 + from filelock import FileLock + + def _compat_get_imports( + file_path: str) -> List[Tuple[str, str, str, Optional[str]]]: + """Parse a dataset script for import statements (ported from datasets<4.0).""" + with open(file_path, encoding='utf-8') as f: + lines = f.readlines() + + imports: List[Tuple[str, str, str, Optional[str]]] = [] + is_in_docstring = False + for line in lines: + docstr_start_match = re.findall(r'[\s\S]*?"""[\s\S]*?', line) + if len(docstr_start_match) == 
def _compat_get_imports(
        file_path: str) -> List[Tuple[str, str, str, Optional[str]]]:
    """Parse a dataset script for import statements (ported from datasets<4.0).

    Only lines that start with ``import`` or ``from`` are considered. Lines
    inside a multi-line, triple-double-quoted docstring are skipped.

    Args:
        file_path: Path of the UTF-8 encoded dataset script to scan.

    Returns:
        One ``(kind, name, path, sub_directory)`` tuple per detected import:

        * ``('library', name, path, None)`` for absolute imports. ``path``
          is the text from the first ``#`` onward (with a leading
          ``# From: `` marker stripped), or ``name`` when there is none.
        * ``('internal', name, name, None)`` for relative imports without
          such trailing text.
        * ``('external', name, url, sub_directory)`` for relative imports
          with trailing text, normalised by
          :func:`_compat_convert_github_url`.

        A relative import whose module name was already recorded is skipped.
    """
    # Compiled once per call instead of being re-resolved for every line.
    import_re = re.compile(
        r'^import\s+(\.?)([^\s\.]+)[^#\r\n]*(?:#\s+From:\s+)?([^\r\n]*)',
        flags=re.MULTILINE)
    from_import_re = re.compile(
        r'^from\s+(\.?)([^\s\.]+)(?:[^\s]*)\s+import\s+[^#\r\n]*(?:#\s+From:\s+)?([^\r\n]*)',
        flags=re.MULTILINE)

    with open(file_path, encoding='utf-8') as f:
        lines = f.readlines()

    imports: List[Tuple[str, str, str, Optional[str]]] = []
    is_in_docstring = False
    for line in lines:
        # A line with exactly one triple-quote marker opens or closes a
        # multi-line docstring; one-line docstrings (two markers) do not.
        if line.count('"""') == 1:
            is_in_docstring = not is_in_docstring
        if is_in_docstring:
            continue

        match = import_re.match(line) or from_import_re.match(line)
        if match is None:
            continue
        is_relative, module_name, source = match.groups()

        if is_relative:
            if any(imp[1] == module_name for imp in imports):
                continue
            if source:
                url_path, sub_directory = _compat_convert_github_url(source)
                imports.append(
                    ('external', module_name, url_path, sub_directory))
            else:
                imports.append(('internal', module_name, module_name, None))
        else:
            imports.append(
                ('library', module_name, source or module_name, None))
    return imports


def _compat_convert_github_url(url_path: str) -> Tuple[str, Optional[str]]:
    """Turn a github.com URL into a directly downloadable one.

    * A ``/blob/`` URL (must end in ``.py``) becomes the matching ``/raw/``
      URL; the returned sub-directory is ``None``.
    * Any other github.com URL is treated as ``owner/repo[/tree/branch]``
      and becomes the branch archive zip (branch defaults to ``master``);
      the returned sub-directory is ``'<repo>-<branch>'``.
    * URLs on other hosts or schemes are returned unchanged.

    Args:
        url_path: URL taken from a ``# From:`` comment.

    Returns:
        ``(download_url, sub_directory)``.

    Raises:
        ValueError: If a ``/blob/`` URL does not end in ``.py``, or the
            repository part is not exactly ``owner/repo``.
    """
    parsed = urlparse(url_path)
    sub_directory = None
    if parsed.scheme in ('http', 'https',
                         's3') and parsed.netloc == 'github.com':
        # Match the '/blob/' path segment rather than any 'blob' substring,
        # so owner/repo/file names containing "blob" are left intact.
        if '/blob/' in parsed.path:
            if not url_path.endswith('.py'):
                raise ValueError(
                    f'External import from github at {url_path} should point to a .py file'
                )
            # Only the first segment is the blob marker; later ones are
            # part of the file path.
            url_path = url_path.replace('/blob/', '/raw/', 1)
        else:
            github_path = parsed.path[1:]
            repo_info, branch = (
                github_path.split('/tree/') if '/tree/' in github_path else
                (github_path, 'master'))
            repo_owner, repo_name = repo_info.split('/')
            url_path = f'https://github.com/{repo_owner}/{repo_name}/archive/{branch}.zip'
            sub_directory = f'{repo_name}-{branch}'
    return url_path, sub_directory
def _compat_init_dynamic_modules(
    name: str = config.MODULE_NAME_FOR_DYNAMIC_MODULES,
    hf_modules_cache=None,
) -> str:
    """Ensure the dynamic-modules package exists and return its directory.

    The cache root (``hf_modules_cache`` or ``config.HF_MODULES_CACHE``) is
    appended to ``sys.path`` if absent, and both the root and the ``name``
    sub-directory are created with an empty ``__init__.py``.

    Args:
        name: Directory name of the dynamic-modules package.
        hf_modules_cache: Optional override of the cache root.

    Returns:
        Path of the dynamic-modules directory.
    """
    cache_root = str(hf_modules_cache or config.HF_MODULES_CACHE)
    if cache_root not in sys.path:
        sys.path.append(cache_root)

    os.makedirs(cache_root, exist_ok=True)
    root_init = os.path.join(cache_root, '__init__.py')
    if not os.path.exists(root_init):
        with open(root_init, 'w'):
            pass
        # A new package marker appeared on disk; refresh the import finders.
        importlib.invalidate_caches()

    modules_dir = os.path.join(cache_root, name)
    os.makedirs(modules_dir, exist_ok=True)
    modules_init = os.path.join(modules_dir, '__init__.py')
    if not os.path.exists(modules_init):
        with open(modules_init, 'w'):
            pass
    return modules_dir


def _compat_files_to_hash(file_paths) -> str:
    """Hash the source lines of the given files.

    Directories are expanded recursively to their ``*.py`` files
    (case-insensitive extension); other entries are read as given. All lines
    are concatenated in encounter order and passed to ``_hash_python_lines``.

    Args:
        file_paths: Iterable of file or directory paths.

    Returns:
        The value returned by ``_hash_python_lines``.
    """
    sources: list = []
    for entry in file_paths:
        if os.path.isdir(entry):
            sources.extend(Path(entry).rglob('*.[pP][yY]'))
        else:
            sources.append(entry)

    collected: list = []
    for source in sources:
        with open(source, encoding='utf-8') as handle:
            collected.extend(handle.readlines())
    return _hash_python_lines(collected)


# -- importable file management ---------------------------------------


def _compat_get_importable_file_path(
    dynamic_modules_path: str,
    module_namespace: str,
    subdirectory_name: str,
    name: str,
) -> str:
    """Return where the importable copy of a dataset script is stored.

    The layout is
    ``<dynamic_modules_path>/<module_namespace>/<name with '/' -> '--'>/``
    ``<subdirectory_name>/<last segment of name>.py``.
    """
    script_stem = name.split('/')[-1]
    return os.path.join(
        dynamic_modules_path,
        module_namespace,
        name.replace('/', '--'),
        subdirectory_name,
        script_stem + '.py',
    )
def _compat_copy_script_and_resources(
    name: str,
    importable_directory_path: str,
    subdirectory_name: str,
    original_local_path: str,
    local_imports: List[Tuple[str, str]],
    additional_files: List[Tuple[str, str]],
    download_mode,
) -> str:
    """Copy a dataset script and its resources into an importable package.

    All filesystem work happens while holding a ``FileLock`` on
    ``<importable_directory_path>.lock``.

    Args:
        name: Script name without extension; the copy is ``<name>.py``.
        importable_directory_path: Package directory for this dataset.
        subdirectory_name: Sub-package name inside that directory.
        original_local_path: Local path of the script to copy.
        local_imports: ``(module_name, path)`` pairs; a file is copied as
            ``<module_name>.py``, a directory is copied as a tree.
        additional_files: ``(file_name, path)`` pairs copied next to the
            script.
        download_mode: When equal to ``DownloadMode.FORCE_REDOWNLOAD`` the
            existing package directory is removed first.

    Returns:
        Path of the importable script file.

    Raises:
        ImportError: If a local import path is neither a file nor a directory.
    """
    importable_subdirectory = os.path.join(importable_directory_path,
                                           subdirectory_name)
    importable_file = os.path.join(importable_subdirectory, name + '.py')
    lock_path = importable_directory_path + '.lock'
    with FileLock(lock_path):
        # Forced re-download: start from an empty package directory.
        if download_mode == DownloadMode.FORCE_REDOWNLOAD and os.path.exists(
                importable_directory_path):
            shutil.rmtree(importable_directory_path)
        # Both levels get an (empty) __init__.py.
        os.makedirs(importable_directory_path, exist_ok=True)
        init_fp = os.path.join(importable_directory_path, '__init__.py')
        if not os.path.exists(init_fp):
            with open(init_fp, 'w'):
                pass
        os.makedirs(importable_subdirectory, exist_ok=True)
        init_fp2 = os.path.join(importable_subdirectory, '__init__.py')
        if not os.path.exists(init_fp2):
            with open(init_fp2, 'w'):
                pass
        # An existing script copy is kept as is.
        if not os.path.exists(importable_file):
            shutil.copyfile(original_local_path, importable_file)
        # Sidecar JSON recording where the script came from.
        meta_path = os.path.splitext(importable_file)[0] + '.json'
        if not os.path.exists(meta_path):
            meta = {
                'original file path': original_local_path,
                'local file path': importable_file
            }
            with open(meta_path, 'w', encoding='utf-8') as mf:
                _json.dump(meta, mf)
        # Local imports are copied only when the destination is missing.
        for imp_name, imp_path in local_imports:
            if os.path.isfile(imp_path):
                dest = os.path.join(importable_subdirectory,
                                    imp_name + '.py')
                if not os.path.exists(dest):
                    shutil.copyfile(imp_path, dest)
            elif os.path.isdir(imp_path):
                dest = os.path.join(importable_subdirectory, imp_name)
                if not os.path.exists(dest):
                    shutil.copytree(imp_path, dest)
            else:
                raise ImportError(f'Error with local import at {imp_path}')
        # Additional files are refreshed whenever their content differs
        # according to filecmp.cmp.
        for file_name, original_path in additional_files:
            dest_path = os.path.join(importable_subdirectory, file_name)
            if not os.path.exists(dest_path) or not filecmp.cmp(
                    original_path, dest_path):
                shutil.copyfile(original_path, dest_path)
        return importable_file
def _compat_create_importable_file(
    local_path: str,
    local_imports: List[Tuple[str, str]],
    additional_files: List[Tuple[str, str]],
    dynamic_modules_path: str,
    module_namespace: str,
    subdirectory_name: str,
    name: str,
    download_mode,
) -> None:
    """Materialise a dataset script as an importable module on disk.

    Creates ``<dynamic_modules_path>/<module_namespace>/<name with '/' ->
    '--'>``, makes sure its parent has an ``__init__.py``, then delegates the
    copying to ``_compat_copy_script_and_resources`` using the last segment
    of ``name`` as the script name.
    """
    package_dir = os.path.join(dynamic_modules_path, module_namespace,
                               name.replace('/', '--'))
    package = Path(package_dir)
    package.mkdir(parents=True, exist_ok=True)
    package.parent.joinpath('__init__.py').touch(exist_ok=True)

    script_name = name.split('/')[-1]
    _compat_copy_script_and_resources(
        name=script_name,
        importable_directory_path=package_dir,
        subdirectory_name=subdirectory_name,
        original_local_path=local_path,
        local_imports=local_imports,
        additional_files=additional_files,
        download_mode=download_mode,
    )


def _compat_load_importable_file(
    dynamic_modules_path: str,
    module_namespace: str,
    subdirectory_name: str,
    name: str,
) -> Tuple[str, str]:
    """Build the dotted module path of an importable dataset script.

    Returns:
        ``(module_path, subdirectory_name)`` where ``module_path`` joins, with
        dots, the basename of ``dynamic_modules_path``, the namespace, ``name``
        with ``'/'`` replaced by ``'--'``, the sub-directory and the last
        segment of ``name``.
    """
    components = (
        os.path.basename(dynamic_modules_path),
        module_namespace,
        name.replace('/', '--'),
        subdirectory_name,
        name.split('/')[-1],
    )
    return '.'.join(components), subdirectory_name


# -- trust handling ---------------------------------------------------


def _compat_resolve_trust_remote_code(trust_remote_code, repo_id: str):
    """Return ``trust_remote_code`` unchanged, rejecting an unset value.

    Raises:
        ValueError: If ``trust_remote_code`` is ``None``.
    """
    if trust_remote_code is not None:
        return trust_remote_code
    raise ValueError(
        f'The repository for {repo_id} contains custom code which must be '
        f'executed to correctly load the dataset. You can inspect the repository '
        f'content at the Hub.\nPlease pass the argument `trust_remote_code=True` '
        f'to allow custom code to be run.')
+"""Dataset module factory functions and data file resolution for ModelScope. + +This module provides ModelScope-specific implementations of dataset module +loading (both script-based and script-free) and data file pattern resolution. +These functions are monkey-patched onto the ``datasets`` library internals +by :func:`~hf_datasets_util.load_dataset_with_ctx`. +""" +import importlib +import inspect +import os +from functools import partial +from pathlib import Path +from typing import Dict, List, Optional, Sequence, Tuple, Union + +from datasets import (BuilderConfig, DownloadConfig, DownloadMode, Features, + Version, config, data_files) +from datasets.data_files import ( + FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError, + _get_data_files_patterns, _is_inside_unrequested_special_dir, + _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, + sanitize_patterns) +from datasets.download.streaming_download_manager import ( + _prepare_path_and_storage_options, xbasename, xjoin) +from datasets.exceptions import DataFilesNotFoundError +from datasets.info import DatasetInfosDict +from datasets.load import (BuilderConfigsParameters, DatasetModule, + create_builder_configs_from_metadata_configs, + get_dataset_builder_class, import_main_class, + infer_module_for_data_files) +from datasets.naming import camelcase_to_snakecase +from datasets.packaged_modules import (_MODULE_TO_EXTENSIONS, + _PACKAGED_DATASETS_MODULES) +from datasets.utils.file_utils import (cached_path, is_local_path, + relative_to_absolute_path) +from datasets.utils.metadata import MetadataConfigs +from datasets.utils.track import tracked_str +from fsspec import filesystem +from fsspec.core import _un_chain +from fsspec.utils import stringify_path +from huggingface_hub import DatasetCard, DatasetCardData +from packaging import version + +from modelscope import HubApi +from modelscope.msdatasets.utils._compat import ( + _HAS_SCRIPT_LOADING, _create_importable_file, _get_importable_file_path, + 
# ---------------------------------------------------------------------------
# Shared HubApi instance (avoids creating a new requests.Session per call)
# ---------------------------------------------------------------------------
_hub_api: Optional[HubApi] = None


def _get_hub_api() -> HubApi:
    """Return the module-wide ``HubApi``, creating it on first use.

    The instance is built with a 180 second timeout and 3 retries.
    """
    global _hub_api
    if _hub_api is None:
        _hub_api = HubApi(timeout=3 * 60, max_retries=3)
    return _hub_api


# ===================================================================
# Data file resolution
# ===================================================================


def get_fs_token_paths(
    urlpath,
    storage_options=None,
    protocol=None,
):
    """Build the fsspec filesystem that serves ``urlpath``.

    Despite its name, only the filesystem instance is returned.

    Args:
        urlpath: A path/URL, or a non-empty list/tuple/set of them (only the
            first element is used).
        storage_options: Optional filesystem options. The mapping is copied,
            never modified.
        protocol: Optional protocol, passed to ``_un_chain`` via the
            ``'protocol'`` option.

    Returns:
        The filesystem created by ``fsspec.filesystem``.

    Raises:
        ValueError: If ``urlpath`` is an empty sequence.
    """
    if isinstance(urlpath, (list, tuple, set)):
        if not urlpath:
            raise ValueError('empty urlpath sequence')
        urlpath0 = stringify_path(list(urlpath)[0])
    else:
        urlpath0 = stringify_path(urlpath)

    # Work on a copy so injecting 'protocol' does not leak into the caller's
    # dict.
    options = dict(storage_options) if storage_options else {}
    if protocol:
        options['protocol'] = protocol
    chain = _un_chain(urlpath0, options)

    # Fold the chain from its last link to its first: every link except the
    # first is recorded as the target (options, protocol, url) of the one
    # before it; the first link's kwargs become the top-level kwargs.
    inkwargs = {}
    last_index = len(chain) - 1
    for i, (urls, nested_protocol, kw) in enumerate(reversed(chain)):
        if i == last_index:
            inkwargs = dict(**kw, **inkwargs)
            continue
        inkwargs['target_options'] = dict(**kw, **inkwargs)
        inkwargs['target_protocol'] = nested_protocol
        inkwargs['fo'] = urls

    _, outer_protocol, _ = chain[0]
    return filesystem(outer_protocol, **inkwargs)
def _resolve_pattern(
    pattern: str,
    base_path: str,
    allowed_extensions: Optional[List[str]] = None,
    download_config: Optional[DownloadConfig] = None,
) -> List[str]:
    """Resolve data file paths/URLs from a user-supplied pattern.

    The pattern is globbed on the fsspec filesystem returned by
    ``get_fs_token_paths``. Only entries of type ``'file'`` are kept, minus
    names in ``FILES_TO_IGNORE`` (unless the pattern names one explicitly)
    and minus paths rejected by the datasets helpers
    ``_is_inside_unrequested_special_dir`` and
    ``_is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir``.

    Args:
        pattern: Relative pattern (joined to ``base_path``), local absolute
            path, or remote URL.
        base_path: Base used for relative patterns and for computing the
            relative paths handed to the filtering helpers.
        allowed_extensions: If given, a file is kept only when one of its
            dot-separated suffixes (with leading dot) is in this collection.
        download_config: Forwarded to ``_prepare_path_and_storage_options``.

    Returns:
        Matching paths, prefixed with ``<protocol>://`` for non-local
        filesystems.

    Raises:
        DataFilesNotFoundError: If the glob itself raises
            ``FileNotFoundError``.
        FileNotFoundError: If nothing is left after filtering.
    """
    if is_relative_path(pattern):
        pattern = xjoin(base_path, pattern)
    elif is_local_path(pattern):
        # Absolute local path: the base becomes the drive root.
        base_path = os.path.splitdrive(pattern)[0] + os.sep
    else:
        base_path = ''
    pattern, storage_options = _prepare_path_and_storage_options(
        pattern, download_config=download_config)
    fs = get_fs_token_paths(pattern, storage_options=storage_options)
    # Strip chaining ('::') and protocol ('://') to get filesystem-level paths.
    fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker
    fs_pattern = pattern.split('::')[0].split('://')[-1]
    # A file named explicitly by the pattern is never ignored.
    files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)}
    protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0]
    protocol_prefix = protocol + '://' if protocol != 'file' else ''
    glob_kwargs = {}
    if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'):
        glob_kwargs['expand_info'] = False

    try:
        tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs)
    except FileNotFoundError:
        raise DataFilesNotFoundError(f"Unable to find '{pattern}'")

    matched_paths = [
        filepath if filepath.startswith(protocol_prefix) else protocol_prefix
        + filepath for filepath, info in tmp_file_paths.items()
        if info['type'] == 'file' and (
            xbasename(filepath) not in files_to_ignore)
        and not _is_inside_unrequested_special_dir(
            os.path.relpath(filepath, fs_base_path),
            os.path.relpath(fs_pattern, fs_base_path)) and  # noqa: W504
        not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir(  # noqa: W504
            os.path.relpath(filepath, fs_base_path),
            os.path.relpath(fs_pattern, fs_base_path))
    ]
    if allowed_extensions is not None:
        # Every suffix after the first dot is tried, so 'a.tar.gz' matches
        # either '.tar' or '.gz'.
        out = [
            filepath for filepath in matched_paths
            if any('.' + suffix in allowed_extensions
                   for suffix in xbasename(filepath).split('.')[1:])
        ]
        if len(out) < len(matched_paths):
            invalid_matched_files = list(set(matched_paths) - set(out))
            logger.info(
                f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: "
                f'{invalid_matched_files}')
    else:
        out = matched_paths
    if not out:
        error_msg = f"Unable to find '{pattern}'"
        if allowed_extensions is not None:
            error_msg += f' with any supported extension {list(allowed_extensions)}'
        raise FileNotFoundError(error_msg)
    return out


def _get_data_patterns(
    base_path: str,
    download_config: Optional[DownloadConfig] = None
) -> Dict[str, List[str]]:
    """Get data file patterns for a dataset directory.

    Binds ``base_path`` and ``download_config`` into ``_resolve_pattern`` and
    hands that resolver to datasets' ``_get_data_files_patterns``, which
    decides which patterns are tried and in what order.

    Raises:
        EmptyDatasetError: If pattern discovery raises ``FileNotFoundError``.
    """
    resolver = partial(
        _resolve_pattern, base_path=base_path, download_config=download_config)
    try:
        return _get_data_files_patterns(resolver)
    except FileNotFoundError:
        # `from None` hides the low-level lookup failure from the traceback.
        raise EmptyDatasetError(
            f"The directory at {base_path} doesn't contain any data files"
        ) from None
# ===================================================================
# Repository file download helper
# ===================================================================


def _download_repo_file(
    repo_id: str,
    path_in_repo: str,
    download_config: DownloadConfig,
    revision: str,
) -> str:
    """Download a single file from a ModelScope dataset repository.

    Args:
        repo_id: Repository id of the form ``namespace/dataset_name``.
        path_in_repo: File path inside the repository.
        download_config: Download settings; it is not modified. When it has
            no ``download_desc``, a per-file label is set on a private copy.
        revision: Repository revision to read from.

    Returns:
        Local path of the cached file, or ``''`` if the download raised
        ``FileNotFoundError`` (the error is logged).

    Raises:
        ValueError: If ``repo_id`` does not split into exactly two parts.
    """
    api = _get_hub_api()
    _namespace, _dataset_name = repo_id.split('/')
    endpoint = api.get_endpoint_for_read(
        repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
    # Label a copy: writing the description onto the caller's object would
    # make the first file's label stick for every later download that
    # reuses the same config.
    if download_config is not None and download_config.download_desc is None:
        download_config = download_config.copy()
        download_config.download_desc = f'Downloading [{path_in_repo}]'
    try:
        url_or_filename = api.get_dataset_file_url(
            file_name=path_in_repo,
            dataset_name=_dataset_name,
            namespace=_namespace,
            revision=revision,
            extension_filter=False,
            endpoint=endpoint,
        )
        repo_file_path = cached_path(
            url_or_filename=url_or_filename, download_config=download_config)
    except FileNotFoundError as e:
        # Best effort: callers treat an empty string as "file not present".
        repo_file_path = ''
        logger.error(e)
    return repo_file_path
# ===================================================================
# Additional modules download (for script-based datasets)
# ===================================================================


def _download_additional_modules(
    name: str,
    dataset_name: str,
    namespace: str,
    revision: str,
    imports: Sequence[Tuple[str, str, str, Optional[str]]],
    download_config: Optional['DownloadConfig'],
    trust_remote_code: Optional[bool] = False,
) -> List[Tuple[str, str]]:
    """Download additional modules referenced by a dataset builder script.

    Args:
        name: Name of the dataset script; used in messages and to detect a
            script importing a module with its own name.
        dataset_name: Dataset name used to build URLs of internal imports.
        namespace: Namespace used to build URLs of internal imports.
        revision: Repository revision of internal imports.
        imports: ``(kind, name, path, sub_directory)`` tuples as produced by
            ``get_imports``; ``kind`` is ``'library'``, ``'internal'`` or
            ``'external'``.
        download_config: Download settings; copied before use. ``None``
            means a default ``DownloadConfig``.
        trust_remote_code: Must be truthy when any internal or external
            import is present.

    Returns:
        ``(import_name, local_path)`` for every internal/external import.

    Raises:
        ValueError: If remote code is referenced without
            ``trust_remote_code``, if an import is named like the script, or
            if an import kind is unknown.
        ImportError: If a library import cannot be imported.
    """
    # Materialise once: the list is walked several times below.
    imports = list(imports)
    remote_imports = [imp for imp in imports if imp[0] != 'library']
    library_imports: List[Tuple[str, str]] = [
        (imp_name, imp_path) for imp_type, imp_name, imp_path, _ in imports
        if imp_type == 'library'
    ]

    has_remote_code = any(
        imp[0] in ('internal', 'external') for imp in remote_imports)
    if has_remote_code and not trust_remote_code:
        raise ValueError(
            f'Loading {name} requires executing code from the repository. '
            'This is disabled by default for security reasons. '
            'If you trust the authors of this dataset, you can enable it with '
            '`trust_remote_code=True`.')

    local_imports: List[Tuple[str, str]] = []
    if remote_imports:
        api = _get_hub_api()
        # The parameter is Optional: fall back to defaults instead of
        # failing on `None.copy()`.
        download_config = (
            download_config.copy()
            if download_config is not None else DownloadConfig())
        if download_config.download_desc is None:
            download_config.download_desc = 'Downloading extra modules'

        for import_type, import_name, import_path, sub_directory in remote_imports:
            if import_name == name:
                raise ValueError(
                    f'Error in the {name} script, importing relative {import_name} module '
                    f'but {import_name} is the name of the script. '
                    f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' "
                    f'comment pointing to the original relative import file path.')
            if import_type == 'internal':
                url_or_filename = api.get_dataset_file_url(
                    file_name=import_path + '.py',
                    dataset_name=dataset_name,
                    namespace=namespace,
                    revision=revision,
                )
            elif import_type == 'external':
                url_or_filename = import_path
            else:
                raise ValueError('Wrong import_type')

            local_import_path = cached_path(
                url_or_filename, download_config=download_config)
            if sub_directory is not None:
                local_import_path = os.path.join(local_import_path,
                                                 sub_directory)
            local_imports.append((import_name, local_import_path))

    # Validate library imports
    needs_to_be_installed = {}
    for library_import_name, library_import_path in library_imports:
        try:
            importlib.import_module(library_import_name)
        except ImportError:
            if library_import_name not in needs_to_be_installed or library_import_path != library_import_name:
                needs_to_be_installed[
                    library_import_name] = library_import_path
    if needs_to_be_installed:
        several = len(needs_to_be_installed) > 1
        _dependencies_str = 'dependencies' if several else 'dependency'
        _them_str = 'them' if several else 'it'
        # Import names that differ from their pip distribution name.
        if 'sklearn' in needs_to_be_installed:
            needs_to_be_installed['sklearn'] = 'scikit-learn'
        if 'Bio' in needs_to_be_installed:
            needs_to_be_installed['Bio'] = 'biopython'
        raise ImportError(
            f'To be able to use {name}, you need to install the following {_dependencies_str}: '
            f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install "
            f"{' '.join(needs_to_be_installed.values())}' for instance.")
    return local_imports
# ===================================================================
# Module factory: script-based (Hub)
# ===================================================================


def _load_script_module(
    repo_id: str,
    revision: str,
    download_config: DownloadConfig,
    download_mode=None,
    dynamic_modules_path: Optional[str] = None,
    trust_remote_code: Optional[bool] = None,
) -> DatasetModule:
    """Shared implementation for loading a dataset module from a Hub .py script.

    Used by both ``get_module_with_script`` (monkey-patch for datasets<4.0) and
    ``_compat_hub_script_module`` (compat shim for datasets>=4.0).

    Args:
        repo_id: ``namespace/dataset_name``; the script is expected at
            ``<dataset_name>.py`` in the repository root.
        revision: Repository revision to download from.
        download_config: Forwarded to every download helper.
        download_mode: Forwarded to ``_create_importable_file``.
        dynamic_modules_path: Directory for importable copies; defaults to
            ``init_dynamic_modules()``.
        trust_remote_code: Consent to run the repository's code.

    Returns:
        A ``DatasetModule`` built from the importable module path, the
        content hash, and builder kwargs (``base_path`` and ``repo_id``).

    Raises:
        FileNotFoundError: If the script cannot be downloaded.
        ValueError: If the importable copy does not exist yet and
            ``resolve_trust_remote_code`` returns a falsy value.
    """
    _namespace, _dataset_name = repo_id.split('/')
    script_file_name = f'{_dataset_name}.py'

    local_script_path = _download_repo_file(
        repo_id=repo_id,
        path_in_repo=script_file_name,
        download_config=download_config,
        revision=revision,
    )
    # _download_repo_file returns '' when the file is missing.
    if not local_script_path:
        raise FileNotFoundError(
            f'Cannot find {script_file_name} in {repo_id} at revision {revision}.'
        )

    # The README is optional: an empty path simply means "no dataset card".
    dataset_readme_path = _download_repo_file(
        repo_id=repo_id,
        path_in_repo='README.md',
        download_config=download_config,
        revision=revision,
    )

    imports = get_imports(local_script_path)
    local_imports = _download_additional_modules(
        name=repo_id,
        dataset_name=_dataset_name,
        namespace=_namespace,
        revision=revision,
        imports=imports,
        download_config=download_config,
        trust_remote_code=trust_remote_code,
    )

    additional_files = []
    if dataset_readme_path:
        additional_files.append(
            (config.REPOCARD_FILENAME, dataset_readme_path))

    dynamic_modules_path = dynamic_modules_path or init_dynamic_modules()
    # The hash of the script plus its downloaded imports names the
    # sub-directory of the importable copy.
    hash_val = files_to_hash([local_script_path]
                             + [loc[1] for loc in local_imports])
    importable_file_path = _get_importable_file_path(
        dynamic_modules_path=dynamic_modules_path,
        module_namespace='datasets',
        subdirectory_name=hash_val,
        name=repo_id,
    )
    # Trust is only resolved when the importable copy has to be created; an
    # existing copy is reused as is.
    if not os.path.exists(importable_file_path):
        trust = resolve_trust_remote_code(
            trust_remote_code=trust_remote_code, repo_id=repo_id)
        if trust:
            logger.warning(
                f'Use trust_remote_code=True. Will invoke codes from {repo_id}. '
                'Please make sure that you can trust the external codes.')
            _create_importable_file(
                local_path=local_script_path,
                local_imports=local_imports,
                additional_files=additional_files,
                dynamic_modules_path=dynamic_modules_path,
                module_namespace='datasets',
                subdirectory_name=hash_val,
                name=repo_id,
                download_mode=download_mode,
            )
        else:
            raise ValueError(
                f'Loading {repo_id} requires executing the dataset script in that'
                ' repo on your local machine. Make sure you have read the code there to avoid malicious use, then'
                ' set the option `trust_remote_code=True` to remove this error.'
            )
    module_path, hash_val = _load_importable_file(
        dynamic_modules_path=dynamic_modules_path,
        module_namespace='datasets',
        subdirectory_name=hash_val,
        name=repo_id,
    )
    # Files may have just been written; refresh the import system's caches.
    importlib.invalidate_caches()

    api = _get_hub_api()
    # NOTE(review): get_module_without_script passes an explicit read
    # endpoint to get_file_base_path; confirm the default is right here.
    builder_kwargs = {
        'base_path': api.get_file_base_path(repo_id=repo_id),
        'repo_id': repo_id,
    }
    return DatasetModule(module_path, hash_val, builder_kwargs)
def get_module_with_script(self) -> DatasetModule:
    """Monkey-patch target for ``HubDatasetModuleFactoryWithScript.get_module`` (datasets<4.0).

    Reads the repository id, download settings and trust flag from the
    factory instance and delegates to ``_load_script_module``. The revision
    comes from ``download_config.storage_options['revision']``, falling back
    to ``DEFAULT_DATASET_REVISION`` when missing or falsy.
    """
    storage_options = self.download_config.storage_options
    requested_revision = storage_options.get('revision', None)
    return _load_script_module(
        repo_id=self.name,
        revision=requested_revision or DEFAULT_DATASET_REVISION,
        download_config=self.download_config,
        download_mode=self.download_mode,
        # A falsy path is normalised to None so the callee picks its default.
        dynamic_modules_path=self.dynamic_modules_path or None,
        trust_remote_code=self.trust_remote_code,
    )


def _compat_hub_script_module(
    path,
    revision=None,
    download_config=None,
    download_mode=None,
    dynamic_modules_path=None,
    trust_remote_code=None,
) -> DatasetModule:
    """Load a dataset module from a Hub repo .py script (compat for datasets>=4.0).

    ``path`` is the repository id. A falsy ``revision`` becomes
    ``DEFAULT_DATASET_REVISION`` and a falsy ``download_config`` becomes a
    fresh ``DownloadConfig()``; everything else is forwarded unchanged to
    ``_load_script_module``.
    """
    effective_revision = revision or DEFAULT_DATASET_REVISION
    effective_config = download_config or DownloadConfig()
    return _load_script_module(
        repo_id=path,
        revision=effective_revision,
        download_config=effective_config,
        download_mode=download_mode,
        dynamic_modules_path=dynamic_modules_path,
        trust_remote_code=trust_remote_code,
    )
# ===================================================================
# Module factory: script-based (local)
# ===================================================================


def _compat_local_script_module(
    path,
    download_mode=None,
    dynamic_modules_path=None,
    trust_remote_code=None,
) -> DatasetModule:
    """Load a dataset module from a local .py script (compat for datasets>=4.0).

    Args:
        path: Path of the local dataset script.
        download_mode: Forwarded to ``_create_importable_file``.
        dynamic_modules_path: Directory for importable copies; defaults to
            ``init_dynamic_modules()``.
        trust_remote_code: Consent to run the script; only consulted when
            the importable copy does not exist yet.

    Returns:
        A ``DatasetModule`` whose builder kwargs carry ``base_path`` (the
        resolved parent directory of ``path``).

    Raises:
        ValueError: If the importable copy must be created and
            ``resolve_trust_remote_code`` returns a falsy value.
    """
    local_path = path
    name = Path(path).stem
    script_dir = os.path.dirname(local_path)

    local_imports: List[Tuple[str, str]] = []
    for import_type, import_name, import_path, sub_directory in get_imports(
            local_path):
        if import_type == 'library':
            continue
        if import_type == 'internal':
            # A relative import is either a sibling module or a sibling
            # package directory.
            module_file = os.path.join(script_dir, import_path + '.py')
            package_dir = os.path.join(script_dir, import_path)
            if os.path.isfile(module_file):
                local_imports.append((import_name, module_file))
            elif os.path.isdir(package_dir):
                local_imports.append((import_name, package_dir))
            else:
                # Previously dropped without a trace; report it so a later
                # import failure of the script can be understood.
                logger.warning(
                    f'Relative import {import_name} of {local_path} was not found '
                    f'(looked for {module_file} and {package_dir}); it will not be copied.'
                )
        elif import_type == 'external':
            dl_config = DownloadConfig()
            dl_config.download_desc = 'Downloading extra modules'
            local_import_path = cached_path(
                import_path, download_config=dl_config)
            if sub_directory is not None:
                local_import_path = os.path.join(local_import_path,
                                                 sub_directory)
            local_imports.append((import_name, local_import_path))

    dynamic_modules_path = dynamic_modules_path or init_dynamic_modules()
    hash_val = files_to_hash([local_path] + [loc[1] for loc in local_imports])
    importable_file_path = _get_importable_file_path(
        dynamic_modules_path=dynamic_modules_path,
        module_namespace='datasets',
        subdirectory_name=hash_val,
        name=name,
    )
    # An existing importable copy is reused without re-checking trust.
    if not os.path.exists(importable_file_path):
        trust = resolve_trust_remote_code(trust_remote_code, name)
        if trust:
            _create_importable_file(
                local_path=local_path,
                local_imports=local_imports,
                additional_files=[],
                dynamic_modules_path=dynamic_modules_path,
                module_namespace='datasets',
                subdirectory_name=hash_val,
                name=name,
                download_mode=download_mode,
            )
        else:
            raise ValueError(
                f'Loading {name} requires executing the dataset script. '
                'Set `trust_remote_code=True` to allow this.')
    module_path, hash_val = _load_importable_file(
        dynamic_modules_path=dynamic_modules_path,
        module_namespace='datasets',
        subdirectory_name=hash_val,
        name=name,
    )
    importlib.invalidate_caches()
    builder_kwargs = {
        'base_path': str(Path(path).resolve().parent),
    }
    return DatasetModule(module_path, hash_val, builder_kwargs)
# ===================================================================
# Module factory: without script (Hub)
# ===================================================================


def get_module_without_script(self) -> DatasetModule:
    """Monkey-patch target for ``HubDatasetModuleFactoryWithoutScript.get_module``.

    Builds a ``DatasetModule`` for a repository without a loading script:

    1. Download ``README.md`` and parse its dataset card (empty card data
       when the README is unavailable).
    2. Pick data file patterns from, in order: ``self.data_files``, the
       card's metadata configs, or ``_get_data_patterns``.
    3. Resolve the patterns, infer the packaged builder module, and keep
       only files with that module's extensions.
    4. Create builder configs from the metadata configs, or a single config
       of the module's ``BUILDER_CONFIG_CLASS``.

    Returns:
        A ``DatasetModule`` whose hash is the revision string.
    """
    revision = self.download_config.storage_options.get(
        'revision', None) or DEFAULT_DATASET_REVISION
    base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip(
        '/')

    repo_id: str = self.name
    download_config = self.download_config.copy()

    dataset_readme_path = _download_repo_file(
        repo_id=repo_id,
        path_in_repo='README.md',
        download_config=download_config,
        revision=revision,
    )

    dataset_card_data = DatasetCard.load(
        Path(dataset_readme_path
             )).data if dataset_readme_path else DatasetCardData()
    subset_name: Optional[str] = download_config.storage_options.get(
        'name', None)

    metadata_configs = MetadataConfigs.from_dataset_card_data(
        dataset_card_data)
    dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data)

    if self.data_files is not None:
        patterns = sanitize_patterns(self.data_files)
    elif metadata_configs and 'data_files' in next(
            iter(metadata_configs.values())):
        # NOTE(review): only the first config is checked for 'data_files';
        # an unknown subset name, or a subset without 'data_files', raises
        # KeyError below. Confirm whether that is intended.
        if subset_name is not None:
            subset_data_files = metadata_configs[subset_name]['data_files']
        else:
            subset_data_files = next(iter(
                metadata_configs.values()))['data_files']
        patterns = sanitize_patterns(subset_data_files)
    else:
        patterns = _get_data_patterns(
            base_path, download_config=self.download_config)

    data_files_dict = DataFilesDict.from_patterns(
        patterns,
        base_path=base_path,
        allowed_extensions=ALL_ALLOWED_EXTENSIONS,
        download_config=self.download_config,
    )
    module_name, default_builder_kwargs = infer_module_for_data_files(
        data_files=data_files_dict,
        path=self.name,
        download_config=self.download_config,
    )

    # The filtering method is looked up by name; which datasets release
    # provides which one is not visible from this file.
    if hasattr(data_files_dict, 'filter'):
        data_files_dict = data_files_dict.filter(
            extensions=_MODULE_TO_EXTENSIONS[module_name])
    else:
        data_files_dict = data_files_dict.filter_extensions(
            _MODULE_TO_EXTENSIONS[module_name])

    module_path, _ = _PACKAGED_DATASETS_MODULES[module_name]

    if metadata_configs:
        supports_metadata = module_name in {'imagefolder', 'audiofolder'}
        # `supports_metadata` is passed only if the installed datasets
        # version accepts it.
        create_builder_signature = inspect.signature(
            create_builder_configs_from_metadata_configs)
        in_args = {
            'module_path': module_path,
            'metadata_configs': metadata_configs,
            'base_path': base_path,
            'default_builder_kwargs': default_builder_kwargs,
            'download_config': self.download_config,
        }
        if 'supports_metadata' in create_builder_signature.parameters:
            in_args['supports_metadata'] = supports_metadata
        builder_configs, default_config_name = create_builder_configs_from_metadata_configs(
            **in_args)
    else:
        builder_configs: List[BuilderConfig] = [
            import_main_class(module_path).BUILDER_CONFIG_CLASS(
                data_files=data_files_dict,
                **default_builder_kwargs,
            )
        ]
        default_config_name = None

    api = _get_hub_api()
    endpoint = api.get_endpoint_for_read(
        repo_id=repo_id, repo_type=REPO_TYPE_DATASET)

    builder_kwargs = {
        'base_path':
        api.get_file_base_path(repo_id=repo_id, endpoint=endpoint),
        'repo_id': self.name,
        'dataset_name': camelcase_to_snakecase(Path(self.name).name),
        'data_files': data_files_dict,
    }

    # With a single dataset info entry and no default from the metadata
    # configs, that entry's name becomes the default config.
    if default_config_name is None and len(dataset_infos) == 1:
        default_config_name = next(iter(dataset_infos))

    return DatasetModule(
        module_path,
        revision,
        builder_kwargs,
        dataset_infos=dataset_infos,
        builder_configs_parameters=BuilderConfigsParameters(
            metadata_configs=metadata_configs,
            builder_configs=builder_configs,
            default_config_name=default_config_name,
        ),
    )
+ +Sub-modules: + _compat – backward-compat shims for datasets>=4.0 script loading + _module_factories – dataset module factory functions & data-file resolution +""" import contextlib -import inspect import os import warnings -from dataclasses import dataclass, field, fields -from functools import partial +from dataclasses import fields from pathlib import Path -from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal, Any, ClassVar +from typing import Any, Dict, Iterable, List, Literal, Mapping, Optional, Sequence, Tuple, Union from urllib.parse import urlencode import requests -from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict, +from datasets import (Dataset, DatasetBuilder, DatasetDict, DownloadConfig, DownloadManager, DownloadMode, Features, IterableDataset, IterableDatasetDict, Split, VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf) -# In datasets 4.0+, Sequence was replaced by List as a feature type. -# Use List as the base for ListMs when available, fall back to Sequence for <4.0. 
try: from datasets import List as DatasetList except ImportError: @@ -29,73 +34,25 @@ except ImportError: from datasets.features import features from datasets.features.features import _FEATURE_TYPES -from datasets.data_files import ( - FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError, - _get_data_files_patterns, _is_inside_unrequested_special_dir, - _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir, sanitize_patterns) -from datasets.download.streaming_download_manager import ( - _prepare_path_and_storage_options, xbasename, xjoin) +from datasets.data_files import DataFilesDict, EmptyDatasetError from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError -from datasets.info import DatasetInfosDict from datasets.load import ( - BuilderConfigsParameters, CachedDatasetModuleFactory, DatasetModule, HubDatasetModuleFactoryWithParquetExport, PackagedDatasetModuleFactory, - create_builder_configs_from_metadata_configs, get_dataset_builder_class, - import_main_class, infer_module_for_data_files) - -# To compatible with datasets 4.0+ -try: - from datasets.load import ( - HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript, - LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript) -except ImportError: - from datasets.load import ( - HubDatasetModuleFactoryWithoutScript, - LocalDatasetModuleFactoryWithoutScript) - -# Script-based dataset loading was removed in datasets 4.0. -# These APIs are conditionally imported for backward compatibility with <4.0. 
-try: - from datasets.load import ( - HubDatasetModuleFactoryWithScript, - LocalDatasetModuleFactoryWithScript, - resolve_trust_remote_code, - _get_importable_file_path, _create_importable_file, - _load_importable_file, init_dynamic_modules, - files_to_hash) - from datasets.utils.py_utils import get_imports - _HAS_SCRIPT_LOADING = True -except ImportError: - _HAS_SCRIPT_LOADING = False - -from datasets.naming import camelcase_to_snakecase -from datasets.packaged_modules import (_EXTENSION_TO_MODULE, - _MODULE_TO_EXTENSIONS, - _PACKAGED_DATASETS_MODULES) -# ALL_ALLOWED_EXTENSIONS moved to datasets.packaged_modules in datasets 4.0 -try: - from datasets.packaged_modules import _ALL_ALLOWED_EXTENSIONS as ALL_ALLOWED_EXTENSIONS -except ImportError: - from datasets.load import ALL_ALLOWED_EXTENSIONS + get_dataset_builder_class) +from datasets.packaged_modules import (_EXTENSION_TO_MODULE, _PACKAGED_DATASETS_MODULES) from datasets.utils import file_utils from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled, - cached_path, is_local_path, - relative_to_absolute_path) + cached_path, relative_to_absolute_path) from datasets.utils.info_utils import is_small_dataset -from datasets.utils.metadata import MetadataConfigs from datasets.utils.track import tracked_str -from fsspec import filesystem -from fsspec.core import _un_chain -from fsspec.utils import stringify_path -from huggingface_hub import (DatasetCard, DatasetCardData, hf_hub_url) +from huggingface_hub import hf_hub_url from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder from huggingface_hub.hf_file_system import HfFileSystem -from packaging import version from modelscope import HubApi from modelscope.hub.utils.utils import get_endpoint @@ -106,9 +63,41 @@ from modelscope.utils.import_utils import has_attr_in_class from modelscope.utils.file_utils import is_relative_path 
from modelscope.utils.logger import get_logger +# -- Compat layer ----------------------------------------------------------- +from modelscope.msdatasets.utils._compat import ( + _HAS_SCRIPT_LOADING, + HubDatasetModuleFactoryWithScript, + LocalDatasetModuleFactoryWithScript, +) + +# -- Module factories -------------------------------------------------------- +from modelscope.msdatasets.utils._module_factories import ( + _resolve_pattern, + _download_repo_file, + get_module_without_script, + get_module_with_script, + _compat_local_script_module, + _compat_hub_script_module, + _get_hub_api, +) + +# Compatible with datasets 4.0+ (class name changed) +try: + from datasets.load import ( + HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript, + LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript) +except ImportError: + from datasets.load import ( + HubDatasetModuleFactoryWithoutScript, + LocalDatasetModuleFactoryWithoutScript) + logger = get_logger() +# =================================================================== +# Type definitions +# =================================================================== + ExpandDatasetProperty_T = Literal[ 'author', 'cardData', @@ -129,33 +118,27 @@ ExpandDatasetProperty_T = Literal[ ] -# Patch datasets features +# =================================================================== +# Feature patching (generate_from_dict_ms) +# =================================================================== + _NativeList = DatasetList if DatasetList is not None else SequenceHf def generate_from_dict_ms(obj: Any): """Regenerate the nested feature object from a deserialized dict. - We use the '_type' fields to get the dataclass name to load. - generate_from_dict is the recursive helper for Features.from_dict, and allows for a convenient constructor syntax - to define features from deserialized JSON dictionaries. 
This function is used in particular when deserializing - a :class:`DatasetInfo` that was dumped to a JSON object. This acts as an analogue to - :meth:`Features.from_arrow_schema` and handles the recursive field-by-field instantiation, but doesn't require any - mapping to/from pyarrow, except for the fact that it takes advantage of the mapping of pyarrow primitive dtypes - that :class:`Value` automatically performs. + This is a ModelScope-patched version of ``features.generate_from_dict`` + that handles backward compatibility for legacy ``Sequence`` types in + datasets 4.0+ where ``Sequence`` is no longer a registered feature type. """ - # Nested structures: we allow dict, list/tuples, sequences if isinstance(obj, list): return [generate_from_dict_ms(value) for value in obj] - # Otherwise we have a dict or a dataclass if '_type' not in obj or isinstance(obj['_type'], dict): return {key: generate_from_dict_ms(value) for key, value in obj.items()} obj = dict(obj) _type = obj.pop('_type') - # Handle legacy 'Sequence' type for backward compatibility. - # In datasets 4.0+, Sequence is a utility function (not a feature type), - # so it may not be registered in _FEATURE_TYPES. 
if _type == 'Sequence': feature = obj.pop('feature') length = obj.get('length', -1) @@ -169,7 +152,6 @@ def generate_from_dict_ms(obj: Any): if class_type == LargeList: feature = obj.pop('feature') return LargeList(generate_from_dict_ms(feature), **obj) - # Handle the native List type (datasets 4.0+) as well as Sequence-based if _NativeList is not None and (class_type is _NativeList or issubclass(class_type, _NativeList)): feature = obj.pop('feature') return _NativeList(generate_from_dict_ms(feature), **obj) @@ -178,17 +160,22 @@ def generate_from_dict_ms(obj: Any): return class_type(**{k: v for k, v in obj.items() if k in field_names}) +# =================================================================== +# Download monkey-patch (_download_ms) +# =================================================================== + def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str: + """ModelScope replacement for ``DownloadManager._download``. + + Rewrites relative paths and ``hf://`` URLs to ModelScope API endpoints. + """ url_or_filename = str(url_or_filename) if url_or_filename.startswith('hf://'): - # hf:// URLs (e.g. hf://datasets/{owner}/{name}@{revision}/{file_path}) hf_path = url_or_filename[len('hf://'):] - # Strip leading resource type prefix (e.g. 
"datasets/") for _prefix in ('datasets/', 'models/'): if hf_path.startswith(_prefix): hf_path = hf_path[len(_prefix):] break - # Extract revision and file_path from "{owner}/{name}@{revision}/{file_path}" if '@' in hf_path: at_idx = hf_path.index('@') after_at = hf_path[at_idx + 1:] @@ -203,14 +190,12 @@ def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> parts = hf_path.split('/', 2) revision = DEFAULT_DATASET_REVISION file_path = parts[2] if len(parts) > 2 else '' - params = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_path}) - url_or_filename = self._base_path + params + params_str = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_path}) + url_or_filename = self._base_path + params_str elif is_relative_path(url_or_filename): revision = DEFAULT_DATASET_REVISION - # Note: make sure the FilePath is the last param - params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename} - params: str = urlencode(params) - url_or_filename = self._base_path + params + params_str = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}) + url_or_filename = self._base_path + params_str out = cached_path(url_or_filename, download_config=download_config) out = tracked_str(out) @@ -218,6 +203,10 @@ def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> return out +# =================================================================== +# HfApi monkey-patches (dataset_info, list_repo_tree, get_paths_info) +# =================================================================== + def _dataset_info( self, repo_id: str, @@ -228,45 +217,7 @@ def _dataset_info( token: Optional[Union[bool, str]] = None, expand: Optional[List[ExpandDatasetProperty_T]] = None, ) -> HfDatasetInfo: - """ - Get info on one specific dataset on huggingface.co. - - Dataset can be private if you pass an acceptable token. 
- - Args: - repo_id (`str`): - A namespace (user or an organization) and a repo name separated - by a `/`. - revision (`str`, *optional*): - The revision of the dataset repository from which to get the - information. - timeout (`float`, *optional*): - Whether to set a timeout for the request to the Hub. - files_metadata (`bool`, *optional*): - Whether or not to retrieve metadata for files in the repository - (size, LFS metadata, etc). Defaults to `False`. - token (`bool` or `str`, *optional*): - A valid authentication token (see https://huggingface.co/settings/token). - If `None` or `True` and machine is logged in (through `huggingface-cli login` - or [`~huggingface_hub.login`]), token will be retrieved from the cache. - If `False`, token is not sent in the request header. - - Returns: - [`hf_api.DatasetInfo`]: The dataset repository information. - - - - Raises the following errors: - - - [`~utils.RepositoryNotFoundError`] - If the repository to download from cannot be found. This may be because it doesn't exist, - or because it is set to `private` and you do not have access. - - [`~utils.RevisionNotFoundError`] - If the revision to download from cannot be found. 
- - - """ - # Note: refer to `_list_repo_tree()`, for patching `HfApi.list_repo_tree` + """ModelScope replacement for ``HfApi.dataset_info``.""" repo_info_iter = self.list_repo_tree( repo_id=repo_id, path_in_repo='/', @@ -277,22 +228,21 @@ def _dataset_info( repo_type=REPO_TYPE_DATASET, ) - # Update data_info - data_info = dict({}) - data_info['id'] = repo_id - data_info['private'] = False - data_info['author'] = repo_id.split('/')[0] if repo_id else None - data_info['sha'] = revision - data_info['lastModified'] = None - data_info['gated'] = False - data_info['disabled'] = False - data_info['downloads'] = 0 - data_info['likes'] = 0 - data_info['tags'] = [] - data_info['cardData'] = [] - data_info['createdAt'] = None + data_info = { + 'id': repo_id, + 'private': False, + 'author': repo_id.split('/')[0] if repo_id else None, + 'sha': revision, + 'lastModified': None, + 'gated': False, + 'disabled': False, + 'downloads': 0, + 'likes': 0, + 'tags': [], + 'cardData': [], + 'createdAt': None, + } - # e.g. 
{'rfilename': 'xxx', 'blobId': 'xxx', 'size': 0, 'lfs': {'size': 0, 'sha256': 'xxx', 'pointerSize': 0}} data_siblings = [] for info_item in repo_info_iter: if isinstance(info_item, RepoFile): @@ -308,6 +258,8 @@ def _dataset_info( return HfDatasetInfo(**data_info) +# -- Repo tree cache --------------------------------------------------------- + _repo_tree_cache: Dict[tuple, List[Union[RepoFile, RepoFolder]]] = {} @@ -350,7 +302,7 @@ def _list_repo_tree( repo_type: Optional[str] = None, token: Optional[Union[bool, str]] = None, ) -> Iterable[Union[RepoFile, RepoFolder]]: - + """ModelScope replacement for ``HfApi.list_repo_tree``.""" revision = revision or DEFAULT_DATASET_REVISION normalized_path = path_in_repo or '/' cache_key = (repo_id, revision, normalized_path, recursive) @@ -366,24 +318,21 @@ def _list_repo_tree( yield from derived return - _api = HubApi(timeout=3 * 60, max_retries=3) - endpoint = _api.get_endpoint_for_read( + api = _get_hub_api() + endpoint = api.get_endpoint_for_read( repo_id=repo_id, repo_type=REPO_TYPE_DATASET) _owner, _dataset_name = repo_id.split('/') - dataset_hub_id, _ = _api.get_dataset_id_and_type( + dataset_hub_id, _ = api.get_dataset_id_and_type( dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint) results: List[Union[RepoFile, RepoFolder]] = [] page_number = 1 - # Larger page_size reduces the number of HTTP round-trips for big datasets. - # Termination uses `not dataset_files` (empty page) so it is safe even if - # the server silently caps the actual page size to a smaller value. 
page_size = 500 max_pages = 10000 while page_number <= max_pages: try: - dataset_files = _api.get_dataset_files( + dataset_files = api.get_dataset_files( repo_id=repo_id, revision=revision, root_path=normalized_path, @@ -428,19 +377,22 @@ def _get_paths_info( repo_type: Optional[str] = None, token: Optional[Union[bool, str]] = None, ) -> List[Union[RepoFile, RepoFolder]]: - + """ModelScope replacement for ``HfApi.get_paths_info``.""" revision = revision or DEFAULT_DATASET_REVISION if isinstance(paths, str): paths = [paths] paths_set = set(paths) - # Search within any cached tree data (recursive root is the most comprehensive) - root_key = (repo_id, revision, '/', True) - root_cached = _repo_tree_cache.get(root_key) - if root_cached is not None: - matched = [item for item in root_cached if item.path in paths_set] + # Check all available caches for matching paths + for cache_key, cached_items in _repo_tree_cache.items(): + if cache_key[0] != repo_id or cache_key[1] != revision: + continue + matched = [item for item in cached_items if item.path in paths_set] if matched: return matched + # Recursive root cache is authoritative – if paths not found, they don't exist + if cache_key == (repo_id, revision, '/', True): + return [] repo_info_iter = self.list_repo_tree( repo_id=repo_id, @@ -451,598 +403,62 @@ def _get_paths_info( token=token, ) - return [item_info for item_info in repo_info_iter] + return [item for item in repo_info_iter if item.path in paths_set] -def _download_repo_file(repo_id: str, path_in_repo: str, download_config: DownloadConfig, revision: str): - _api = HubApi() - _namespace, _dataset_name = repo_id.split('/') - endpoint = _api.get_endpoint_for_read( - repo_id=repo_id, repo_type=REPO_TYPE_DATASET) - if download_config and download_config.download_desc is None: - download_config.download_desc = f'Downloading [{path_in_repo}]' - try: - url_or_filename = _api.get_dataset_file_url( - file_name=path_in_repo, - dataset_name=_dataset_name, - 
namespace=_namespace, - revision=revision, - extension_filter=False, - endpoint=endpoint - ) - repo_file_path = cached_path( - url_or_filename=url_or_filename, download_config=download_config) - except FileNotFoundError as e: - repo_file_path = '' - logger.error(e) +# =================================================================== +# HfFileSystem patch (_hf_fs_open) +# =================================================================== - return repo_file_path +_hf_fs_open_original = None -def get_fs_token_paths( - urlpath, - storage_options=None, - protocol=None, -): - if isinstance(urlpath, (list, tuple, set)): - if not urlpath: - raise ValueError('empty urlpath sequence') - urlpath0 = stringify_path(list(urlpath)[0]) - else: - urlpath0 = stringify_path(urlpath) - storage_options = storage_options or {} - if protocol: - storage_options['protocol'] = protocol - chain = _un_chain(urlpath0, storage_options or {}) - inkwargs = {} - # Reverse iterate the chain, creating a nested target_* structure - for i, ch in enumerate(reversed(chain)): - urls, nested_protocol, kw = ch - if i == len(chain) - 1: - inkwargs = dict(**kw, **inkwargs) - continue - inkwargs['target_options'] = dict(**kw, **inkwargs) - inkwargs['target_protocol'] = nested_protocol - inkwargs['fo'] = urls - paths, protocol, _ = chain[0] - fs = filesystem(protocol, **inkwargs) +def _hf_fs_open(self, path, mode='rb', **kwargs): + """Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API. - return fs - - -def _resolve_pattern( - pattern: str, - base_path: str, - allowed_extensions: Optional[List[str]] = None, - download_config: Optional[DownloadConfig] = None, -) -> List[str]: + The ModelScope tree API may report Size=0 for files. When HfFileSystem + caches this, AbstractBufferedFile treats the file as empty (0 bytes). + This wrapper detects size=0 for files opened in read mode and resolves + the actual size via a HEAD request before creating the file object. 
""" - Resolve the paths and URLs of the data files from the pattern passed by the user. - - You can use patterns to resolve multiple local files. Here are a few examples: - - *.csv to match all the CSV files at the first level - - **.csv to match all the CSV files at any level - - data/* to match all the files inside "data" - - data/** to match all the files inside "data" and its subdirectories - - The patterns are resolved using the fsspec glob. - - glob.glob, Path.glob, Path.match or fnmatch do not support ** with a prefix/suffix other than a forward slash /. - For instance, this means **.json is the same as *.json. On the contrary, the fsspec glob has no limits regarding the ** prefix/suffix, # noqa: E501 - resulting in **.json being equivalent to **/*.json. - - More generally: - - '*' matches any character except a forward-slash (to match just the file or directory name) - - '**' matches any character including a forward-slash / - - Hidden files and directories (i.e. whose names start with a dot) are ignored, unless they are explicitly requested. - The same applies to special directories that start with a double underscore like "__pycache__". - You can still include one if the pattern explicitly mentions it: - - to include a hidden file: "*/.hidden.txt" or "*/.*" - - to include a hidden directory: ".hidden/*" or ".*/*" - - to include a special directory: "__special__/*" or "__*/*" - - Example:: - - >>> from datasets.data_files import resolve_pattern - >>> base_path = "." - >>> resolve_pattern("docs/**/*.py", base_path) - [/Users/mariosasko/Desktop/projects/datasets/docs/source/_config.py'] - - Args: - pattern (str): Unix pattern or paths or URLs of the data files to resolve. - The paths can be absolute or relative to base_path. - Remote filesystems using fsspec are supported, e.g. with the hf:// protocol. - base_path (str): Base path to use when resolving relative paths. - allowed_extensions (Optional[list], optional): White-list of file extensions to use. 
Defaults to None (all extensions). - For example: allowed_extensions=[".csv", ".json", ".txt", ".parquet"] - Returns: - List[str]: List of paths or URLs to the local or remote files that match the patterns. - """ - if is_relative_path(pattern): - pattern = xjoin(base_path, pattern) - elif is_local_path(pattern): - base_path = os.path.splitdrive(pattern)[0] + os.sep - else: - base_path = '' - # storage_options: {'hf': {'token': None, 'endpoint': 'https://huggingface.co'}} - pattern, storage_options = _prepare_path_and_storage_options( - pattern, download_config=download_config) - fs = get_fs_token_paths(pattern, storage_options=storage_options) - fs_base_path = base_path.split('::')[0].split('://')[-1] or fs.root_marker - fs_pattern = pattern.split('::')[0].split('://')[-1] - files_to_ignore = set(FILES_TO_IGNORE) - {xbasename(pattern)} - protocol = fs.protocol if isinstance(fs.protocol, str) else fs.protocol[0] - protocol_prefix = protocol + '://' if protocol != 'file' else '' - glob_kwargs = {} - if protocol == 'hf' and config.HF_HUB_VERSION >= version.parse('0.20.0'): - # 10 times faster glob with detail=True (ignores costly info like lastCommit) - glob_kwargs['expand_info'] = False - - try: - tmp_file_paths = fs.glob(pattern, detail=True, **glob_kwargs) - except FileNotFoundError: - raise DataFilesNotFoundError(f"Unable to find '{pattern}'") - - matched_paths = [ - filepath if filepath.startswith(protocol_prefix) else protocol_prefix - + filepath for filepath, info in tmp_file_paths.items() - if info['type'] == 'file' and ( - xbasename(filepath) not in files_to_ignore) - and not _is_inside_unrequested_special_dir( - os.path.relpath(filepath, fs_base_path), - os.path.relpath(fs_pattern, fs_base_path)) and # noqa: W504 - not _is_unrequested_hidden_file_or_is_inside_unrequested_hidden_dir( # noqa: W504 - os.path.relpath(filepath, fs_base_path), - os.path.relpath(fs_pattern, fs_base_path)) - ] # ignore .ipynb and __pycache__, but keep /../ - if allowed_extensions is 
not None: - out = [ - filepath for filepath in matched_paths - if any('.' + suffix in allowed_extensions - for suffix in xbasename(filepath).split('.')[1:]) - ] - if len(out) < len(matched_paths): - invalid_matched_files = list(set(matched_paths) - set(out)) - logger.info( - f"Some files matched the pattern '{pattern}' but don't have valid data file extensions: " - f'{invalid_matched_files}') - else: - out = matched_paths - if not out: - error_msg = f"Unable to find '{pattern}'" - if allowed_extensions is not None: - error_msg += f' with any supported extension {list(allowed_extensions)}' - raise FileNotFoundError(error_msg) - return out - - -def _get_data_patterns( - base_path: str, - download_config: Optional[DownloadConfig] = None) -> Dict[str, - List[str]]: - """ - Get the default pattern from a directory testing all the supported patterns. - The first patterns to return a non-empty list of data files is returned. - - Some examples of supported patterns: - - Input: - - my_dataset_repository/ - ├── README.md - └── dataset.csv - - Output: - - {"train": ["**"]} - - Input: - - my_dataset_repository/ - ├── README.md - ├── train.csv - └── test.csv - - my_dataset_repository/ - ├── README.md - └── data/ - ├── train.csv - └── test.csv - - my_dataset_repository/ - ├── README.md - ├── train_0.csv - ├── train_1.csv - ├── train_2.csv - ├── train_3.csv - ├── test_0.csv - └── test_1.csv - - Output: - - {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**', - 'training[-._ 0-9/]**', '**/*[-._ 0-9/]training[-._ 0-9/]**'], - 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**', - 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]} - - Input: - - my_dataset_repository/ - ├── README.md - └── data/ - ├── train/ - │ ├── shard_0.csv - │ ├── shard_1.csv - │ ├── shard_2.csv - │ └── shard_3.csv - └── test/ - ├── shard_0.csv - └── shard_1.csv - - Output: - - {'train': ['train[-._ 0-9/]**', '**/*[-._ 0-9/]train[-._ 0-9/]**', - 'training[-._ 0-9/]**', 
'**/*[-._ 0-9/]training[-._ 0-9/]**'], - 'test': ['test[-._ 0-9/]**', '**/*[-._ 0-9/]test[-._ 0-9/]**', - 'testing[-._ 0-9/]**', '**/*[-._ 0-9/]testing[-._ 0-9/]**', ...]} - - Input: - - my_dataset_repository/ - ├── README.md - └── data/ - ├── train-00000-of-00003.csv - ├── train-00001-of-00003.csv - ├── train-00002-of-00003.csv - ├── test-00000-of-00001.csv - ├── random-00000-of-00003.csv - ├── random-00001-of-00003.csv - └── random-00002-of-00003.csv - - Output: - - {'train': ['data/train-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'], - 'test': ['data/test-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*'], - 'random': ['data/random-[0-9][0-9][0-9][0-9][0-9]-of-[0-9][0-9][0-9][0-9][0-9]*.*']} - - In order, it first tests if SPLIT_PATTERN_SHARDED works, otherwise it tests the patterns in ALL_DEFAULT_PATTERNS. - """ - resolver = partial( - _resolve_pattern, base_path=base_path, download_config=download_config) - try: - return _get_data_files_patterns(resolver) - except FileNotFoundError: - raise EmptyDatasetError( - f"The directory at {base_path} doesn't contain any data files" - ) from None - - -def get_module_without_script(self) -> DatasetModule: - - # hfh_dataset_info = HfApi(config.HF_ENDPOINT).dataset_info( - # self.name, - # revision=self.revision, - # token=self.download_config.token, - # timeout=100.0, - # ) - # even if metadata_configs is not None (which means that we will resolve files for each config later) - # we cannot skip resolving all files because we need to infer module name by files extensions - # revision = hfh_dataset_info.sha # fix the revision in case there are new commits in the meantime - revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION - base_path = f"hf://datasets/{self.name}@{revision}/{self.data_dir or ''}".rstrip( - '/') - - repo_id: str = self.name - download_config = self.download_config.copy() - - dataset_readme_path = _download_repo_file( - repo_id=repo_id, - 
path_in_repo='README.md', - download_config=download_config, - revision=revision) - - dataset_card_data = DatasetCard.load(Path(dataset_readme_path)).data if dataset_readme_path else DatasetCardData() - subset_name: str = download_config.storage_options.get('name', None) - - metadata_configs = MetadataConfigs.from_dataset_card_data( - dataset_card_data) - dataset_infos = DatasetInfosDict.from_dataset_card_data(dataset_card_data) - # we need a set of data files to find which dataset builder to use - # because we need to infer module name by files extensions - if self.data_files is not None: - patterns = sanitize_patterns(self.data_files) - elif metadata_configs and 'data_files' in next( - iter(metadata_configs.values())): - - if subset_name is not None: - subset_data_files = metadata_configs[subset_name]['data_files'] - else: - subset_data_files = next(iter(metadata_configs.values()))['data_files'] - patterns = sanitize_patterns(subset_data_files) - else: - patterns = _get_data_patterns( - base_path, download_config=self.download_config) - - data_files = DataFilesDict.from_patterns( - patterns, - base_path=base_path, - allowed_extensions=ALL_ALLOWED_EXTENSIONS, - download_config=self.download_config, - ) - module_name, default_builder_kwargs = infer_module_for_data_files( - data_files=data_files, - path=self.name, - download_config=self.download_config, - ) - - if hasattr(data_files, 'filter'): - data_files = data_files.filter(extensions=_MODULE_TO_EXTENSIONS[module_name]) - else: - data_files = data_files.filter_extensions(_MODULE_TO_EXTENSIONS[module_name]) - - module_path, _ = _PACKAGED_DATASETS_MODULES[module_name] - - if metadata_configs: - - supports_metadata = module_name in {'imagefolder', 'audiofolder'} - create_builder_signature = inspect.signature(create_builder_configs_from_metadata_configs) - in_args = { - 'module_path': module_path, - 'metadata_configs': metadata_configs, - 'base_path': base_path, - 'default_builder_kwargs': default_builder_kwargs, - 
'download_config': self.download_config, - } - if 'supports_metadata' in create_builder_signature.parameters: - in_args['supports_metadata'] = supports_metadata - - builder_configs, default_config_name = create_builder_configs_from_metadata_configs(**in_args) - else: - builder_configs: List[BuilderConfig] = [ - import_main_class(module_path).BUILDER_CONFIG_CLASS( - data_files=data_files, - **default_builder_kwargs, - ) - ] - default_config_name = None - _api = HubApi() - endpoint = _api.get_endpoint_for_read( - repo_id=repo_id, repo_type=REPO_TYPE_DATASET) - - builder_kwargs = { - # "base_path": hf_hub_url(self.name, "", revision=revision).rstrip("/"), - 'base_path': - HubApi().get_file_base_path(repo_id=repo_id, endpoint=endpoint), - 'repo_id': - self.name, - 'dataset_name': - camelcase_to_snakecase(Path(self.name).name), - 'data_files': data_files, - } - download_config = self.download_config.copy() - if download_config.download_desc is None: - download_config.download_desc = 'Downloading metadata' - - # Note: `dataset_infos.json` is deprecated and can cause an error during loading if it exists - - if default_config_name is None and len(dataset_infos) == 1: - default_config_name = next(iter(dataset_infos)) - - hash = revision - return DatasetModule( - module_path, - hash, - builder_kwargs, - dataset_infos=dataset_infos, - builder_configs_parameters=BuilderConfigsParameters( - metadata_configs=metadata_configs, - builder_configs=builder_configs, - default_config_name=default_config_name, - ), - ) - - -def _download_additional_modules( - name: str, - dataset_name: str, - namespace: str, - revision: str, - imports: Tuple[str, str, str, str], - download_config: Optional[DownloadConfig], - trust_remote_code: Optional[bool] = False, -) -> List[Tuple[str, str]]: - """ - Download additional module for a module .py at URL (or local path) /.py - The imports must have been parsed first using ``get_imports``. 
- - If some modules need to be installed with pip, an error is raised showing how to install them. - This function return the list of downloaded modules as tuples (import_name, module_file_path). - - The downloaded modules can then be moved into an importable directory - with ``_copy_script_and_other_resources_in_importable_dir``. - """ - local_imports = [] - library_imports = [] - - # Check if we need to execute remote code - has_remote_code = any( - import_type in ('internal', 'external') - for import_type, _, _, _ in imports - ) - - if has_remote_code and not trust_remote_code: - raise ValueError( - f'Loading {name} requires executing code from the repository. ' - 'This is disabled by default for security reasons. ' - 'If you trust the authors of this dataset, you can enable it with ' - '`trust_remote_code=True`.' - ) - - download_config = download_config.copy() - if download_config.download_desc is None: - download_config.download_desc = 'Downloading extra modules' - for import_type, import_name, import_path, sub_directory in imports: - if import_type == 'library': - library_imports.append((import_name, import_path)) # Import from a library - continue - - if import_name == name: - raise ValueError( - f'Error in the {name} script, importing relative {import_name} module ' - f'but {import_name} is the name of the script. ' - f"Please change relative import {import_name} to another name and add a '# From: URL_OR_PATH' " - f'comment pointing to the original relative import file path.' 
- ) - if import_type == 'internal': - _api = HubApi() - # url_or_filename = url_or_path_join(base_path, import_path + ".py") - file_name = import_path + '.py' - url_or_filename = _api.get_dataset_file_url(file_name=file_name, - dataset_name=dataset_name, - namespace=namespace, - revision=revision,) - elif import_type == 'external': - url_or_filename = import_path - else: - raise ValueError('Wrong import_type') - - local_import_path = cached_path( - url_or_filename, - download_config=download_config, - ) - if sub_directory is not None: - local_import_path = os.path.join(local_import_path, sub_directory) - local_imports.append((import_name, local_import_path)) - - # Check library imports - needs_to_be_installed = {} - for library_import_name, library_import_path in library_imports: + if mode == 'rb' and 'size' not in kwargs: try: - lib = importlib.import_module(library_import_name) # noqa F841 - except ImportError: - if library_import_name not in needs_to_be_installed or library_import_path != library_import_name: - needs_to_be_installed[library_import_name] = library_import_path - if needs_to_be_installed: - _dependencies_str = 'dependencies' if len(needs_to_be_installed) > 1 else 'dependency' - _them_str = 'them' if len(needs_to_be_installed) > 1 else 'it' - if 'sklearn' in needs_to_be_installed.keys(): - needs_to_be_installed['sklearn'] = 'scikit-learn' - if 'Bio' in needs_to_be_installed.keys(): - needs_to_be_installed['Bio'] = 'biopython' - raise ImportError( - f'To be able to use {name}, you need to install the following {_dependencies_str}: ' - f"{', '.join(needs_to_be_installed)}.\nPlease install {_them_str} using 'pip install " - f"{' '.join(needs_to_be_installed.values())}' for instance." 
- ) - return local_imports + resolved = self.resolve_path(path) + resolved_name = resolved.unresolve() + parent = self._parent(resolved_name) + cached_size = None + if parent in self.dircache: + for entry in self.dircache[parent]: + if entry['name'] == resolved_name and entry.get('type') == 'file': + cached_size = entry.get('size', -1) + break + if cached_size == 0: + url = hf_hub_url( + repo_id=resolved.repo_id, + revision=resolved.revision, + filename=resolved.path_in_repo, + repo_type=resolved.repo_type, + endpoint=self.endpoint, + ) + headers = self._api._build_hf_headers() + resp = requests.head(url, headers=headers, allow_redirects=True, timeout=30) + if resp.status_code == 200: + cl = resp.headers.get('Content-Length') + if cl: + actual_size = int(cl) + kwargs['size'] = actual_size + for entry in self.dircache.get(parent, []): + if entry['name'] == resolved_name: + entry['size'] = actual_size + break + except Exception: + pass + return _hf_fs_open_original(self, path, mode=mode, **kwargs) -def get_module_with_script(self) -> DatasetModule: - if not _HAS_SCRIPT_LOADING: - raise RuntimeError( - 'Script-based dataset loading is not supported with datasets>=4.0. ' - 'Please convert the dataset to a script-free format (e.g. Parquet).') - - repo_id: str = self.name - _namespace, _dataset_name = repo_id.split('/') - revision = self.download_config.storage_options.get('revision', None) or DEFAULT_DATASET_REVISION - - script_file_name = f'{_dataset_name}.py' - local_script_path = _download_repo_file( - repo_id=repo_id, - path_in_repo=script_file_name, - download_config=self.download_config, - revision=revision, - ) - if not local_script_path: - raise FileNotFoundError( - f'Cannot find {script_file_name} in {repo_id} at revision {revision}. ' - f'Please create {script_file_name} in the repo.' 
- ) - - dataset_infos_path = None - # try: - # dataset_infos_url: str = _api.get_dataset_file_url( - # file_name='dataset_infos.json', - # dataset_name=_dataset_name, - # namespace=_namespace, - # revision=self.revision, - # extension_filter=False, - # ) - # dataset_infos_path = cached_path( - # url_or_filename=dataset_infos_url, download_config=self.download_config) - # except Exception as e: - # logger.info(f'Cannot find dataset_infos.json: {e}') - # dataset_infos_path = None - - dataset_readme_path = _download_repo_file( - repo_id=repo_id, - path_in_repo='README.md', - download_config=self.download_config, - revision=revision - ) - - imports = get_imports(local_script_path) - local_imports = _download_additional_modules( - name=repo_id, - dataset_name=_dataset_name, - namespace=_namespace, - revision=revision, - imports=imports, - download_config=self.download_config, - trust_remote_code=self.trust_remote_code, - ) - additional_files = [] - if dataset_infos_path: - additional_files.append((config.DATASETDICT_INFOS_FILENAME, dataset_infos_path)) - if dataset_readme_path: - additional_files.append((config.REPOCARD_FILENAME, dataset_readme_path)) - # copy the script and the files in an importable directory - dynamic_modules_path = self.dynamic_modules_path if self.dynamic_modules_path else init_dynamic_modules() - hash = files_to_hash([local_script_path] + [loc[1] for loc in local_imports]) - importable_file_path = _get_importable_file_path( - dynamic_modules_path=dynamic_modules_path, - module_namespace='datasets', - subdirectory_name=hash, - name=repo_id, - ) - if not os.path.exists(importable_file_path): - trust_remote_code = resolve_trust_remote_code(trust_remote_code=self.trust_remote_code, repo_id=self.name) - if trust_remote_code: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {repo_id}. 
Please make sure that ' - 'you can trust the external codes.') - _create_importable_file( - local_path=local_script_path, - local_imports=local_imports, - additional_files=additional_files, - dynamic_modules_path=dynamic_modules_path, - module_namespace='datasets', - subdirectory_name=hash, - name=repo_id, - download_mode=self.download_mode, - ) - else: - raise ValueError( - f'Loading {repo_id} requires you to execute the dataset script in that' - ' repo on your local machine. Make sure you have read the code there to avoid malicious use, then' - ' set the option `trust_remote_code=True` to remove this error.' - ) - module_path, hash = _load_importable_file( - dynamic_modules_path=dynamic_modules_path, - module_namespace='datasets', - subdirectory_name=hash, - name=repo_id, - ) - # make the new module to be noticed by the import system - importlib.invalidate_caches() - builder_kwargs = { - # "base_path": hf_hub_url(self.name, "", revision=self.revision).rstrip("/"), - 'base_path': HubApi().get_file_base_path(repo_id=repo_id), - 'repo_id': repo_id, - } - return DatasetModule(module_path, hash, builder_kwargs) - +# =================================================================== +# DatasetsWrapperHF +# =================================================================== class DatasetsWrapperHF: @@ -1093,8 +509,7 @@ class DatasetsWrapperHF: raise ValueError( f"Empty 'data_files': '{data_files}'. It should be either non-empty or None (default)." ) - if Path(path, config.DATASET_STATE_JSON_FILENAME).exists( - ): + if Path(path, config.DATASET_STATE_JSON_FILENAME).exists(): raise ValueError( 'You are trying to load a dataset that was saved using `save_to_disk`. 
' 'Please use `load_from_disk` instead.') @@ -1112,14 +527,9 @@ class DatasetsWrapperHF: ) if not save_infos else VerificationMode.ALL_CHECKS) if trust_remote_code: - if not _HAS_SCRIPT_LOADING: - logger.warning('trust_remote_code is ignored: script-based dataset loading ' - 'is no longer supported with datasets>=4.0.') - else: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' - 'that you can trust the external codes.') + logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' + 'that you can trust the external codes.') - # Create a dataset builder builder_instance = DatasetsWrapperHF.load_dataset_builder( path=path, name=name, @@ -1132,15 +542,13 @@ class DatasetsWrapperHF: revision=revision, token=token, storage_options=storage_options, - trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None, + trust_remote_code=trust_remote_code, _require_default_config_name=name is None, **config_kwargs, ) - # Note: Only for preview mode if dataset_info_only: ret_dict = {} - # Get dataset config info from python script if isinstance(path, str) and path.endswith('.py') and os.path.exists(path): from datasets import get_dataset_config_names subset_list = get_dataset_config_names(path) @@ -1161,26 +569,17 @@ class DatasetsWrapperHF: ret_dict[tmp_config_name] = [] return ret_dict - # Return iterable dataset in case of streaming if streaming: return builder_instance.as_streaming_dataset(split=split) - # Some datasets are already processed on the HF google storage - # Don't try downloading from Google storage for the packaged datasets as text, json, csv or pandas - # try_from_hf_gcs = path not in _PACKAGED_DATASETS_MODULES - - # Download and prepare data builder_instance.download_and_prepare( download_config=download_config, download_mode=download_mode, verification_mode=verification_mode, num_proc=num_proc, storage_options=storage_options, - # base_path=builder_instance.base_path, - 
# file_format=builder_instance.name or 'arrow', ) - # Build dataset for splits keep_in_memory = ( keep_in_memory if keep_in_memory is not None else is_small_dataset( builder_instance.info.dataset_size)) @@ -1188,9 +587,7 @@ class DatasetsWrapperHF: split=split, verification_mode=verification_mode, in_memory=keep_in_memory) - # Rename and cast features to match task schema if task is not None: - # To avoid issuing the same warning twice with warnings.catch_warnings(): warnings.simplefilter('ignore', FutureWarning) ds = ds.prepare_for_task(task) @@ -1198,13 +595,12 @@ class DatasetsWrapperHF: builder_instance._save_infos() try: - _api = HubApi() - + api = _get_hub_api() if is_relative_path(path) and path.count('/') == 1: _namespace, _dataset_name = path.split('/') - endpoint = _api.get_endpoint_for_read( + endpoint = api.get_endpoint_for_read( repo_id=path, repo_type=REPO_TYPE_DATASET) - _api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint) + api.dataset_download_statistics(dataset_name=_dataset_name, namespace=_namespace, endpoint=endpoint) except Exception as e: logger.warning(f'Could not record download statistics: {e}') @@ -1249,14 +645,6 @@ class DatasetsWrapperHF: ) if download_config else DownloadConfig() download_config.storage_options.update(storage_options) - if trust_remote_code: - if not _HAS_SCRIPT_LOADING: - logger.warning('trust_remote_code is ignored: script-based dataset loading ' - 'is no longer supported with datasets>=4.0.') - else: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. 
Please make sure ' - 'that you can trust the external codes.') - dataset_module = DatasetsWrapperHF.dataset_module_factory( path, revision=revision, @@ -1265,12 +653,11 @@ class DatasetsWrapperHF: data_dir=data_dir, data_files=data_files, cache_dir=cache_dir, - trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None, + trust_remote_code=trust_remote_code, _require_default_config_name=_require_default_config_name, _require_custom_configs=bool(config_kwargs), name=name, ) - # Get dataset builder class from the processing script builder_kwargs = dataset_module.builder_kwargs data_dir = builder_kwargs.pop('data_dir', data_dir) data_files = builder_kwargs.pop('data_files', data_files) @@ -1307,7 +694,7 @@ class DatasetsWrapperHF: features=features, token=token, storage_options=storage_options, - **builder_kwargs, # contains base_path + **builder_kwargs, **config_kwargs, ) builder_instance._use_legacy_cache_dir_if_possible(dataset_module) @@ -1353,28 +740,6 @@ class DatasetsWrapperHF: filename = filename + '.py' combined_path = os.path.join(path, filename) - # We have several ways to get a dataset builder: - # - # - if path is the name of a packaged dataset module - # -> use the packaged module (json, csv, etc.) - # - # - if os.path.join(path, name) is a local python file - # -> use the module from the python file - # - if path is a local directory (but no python file) - # -> use a packaged module (csv, text etc.) based on content of the directory - # - # - if path has one "/" and is dataset repository on the HF hub with a python file - # -> the module from the python file in the dataset repository - # - if path has one "/" and is dataset repository on the HF hub without a python file - # -> use a packaged module (csv, text etc.) 
based on content of the repository - if trust_remote_code: - if not _HAS_SCRIPT_LOADING: - logger.warning('trust_remote_code is ignored: script-based dataset loading ' - 'is no longer supported with datasets>=4.0.') - else: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' - 'that you can trust the external codes.') - # Try packaged if path in _PACKAGED_DATASETS_MODULES: return PackagedDatasetModuleFactory( @@ -1384,34 +749,40 @@ class DatasetsWrapperHF: download_config=download_config, download_mode=download_mode, ).get_module() - # Try locally with script (requires datasets <4.0) + # Try locally with script elif path.endswith(filename): if os.path.isfile(path): - if not _HAS_SCRIPT_LOADING: - raise RuntimeError( - f'Script-based dataset loading ({path}) is not supported with datasets>=4.0. ' - 'Please convert the dataset to a script-free format (e.g. Parquet).') - return LocalDatasetModuleFactoryWithScript( + if _HAS_SCRIPT_LOADING: + return LocalDatasetModuleFactoryWithScript( + path, + download_mode=download_mode, + dynamic_modules_path=dynamic_modules_path, + trust_remote_code=trust_remote_code, + ).get_module() + return _compat_local_script_module( path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path, trust_remote_code=trust_remote_code, - ).get_module() + ) else: raise FileNotFoundError( f"Couldn't find a dataset script at {relative_to_absolute_path(path)}" ) elif os.path.isfile(combined_path): - if not _HAS_SCRIPT_LOADING: - raise RuntimeError( - f'Script-based dataset loading ({combined_path}) is not supported with datasets>=4.0. ' - 'Please convert the dataset to a script-free format (e.g. 
Parquet).') - return LocalDatasetModuleFactoryWithScript( + if _HAS_SCRIPT_LOADING: + return LocalDatasetModuleFactoryWithScript( + combined_path, + download_mode=download_mode, + dynamic_modules_path=dynamic_modules_path, + trust_remote_code=trust_remote_code, + ).get_module() + return _compat_local_script_module( combined_path, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path, trust_remote_code=trust_remote_code, - ).get_module() + ) elif os.path.isdir(path): return LocalDatasetModuleFactoryWithoutScript( path, @@ -1430,15 +801,14 @@ class DatasetsWrapperHF: token=download_config.token, timeout=100.0, ) - except Exception as e: # noqa catch any exception of hf_hub and consider that the dataset doesn't exist + except Exception as e: # noqa: broad exception from hf_hub if isinstance( e, ( # noqa: E131 - OfflineModeIsEnabled, # noqa: E131 - requests.exceptions. - ConnectTimeout, # noqa: E131, E261 - requests.exceptions.ConnectionError, # noqa: E131 - ), # noqa: E131 + OfflineModeIsEnabled, + requests.exceptions.ConnectTimeout, + requests.exceptions.ConnectionError, + ), ): raise ConnectionError( f"Couldn't reach '{path}' on the Hub ({type(e).__name__})" @@ -1469,15 +839,9 @@ class DatasetsWrapperHF: if filename in [ sibling.rfilename for sibling in dataset_info.siblings - ]: # contains a dataset script - - # TODO + ]: can_load_config_from_parquet_export = False if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export: - # If the parquet export is ready (parquet files + info available for the current sha), - # we can use it instead - # This fails when the dataset has multiple configs and a default config and - # the user didn't specify a configuration name (_require_default_config_name=True). 
try: if has_attr_in_class(HubDatasetModuleFactoryWithParquetExport, 'revision'): return HubDatasetModuleFactoryWithParquetExport( @@ -1492,35 +856,35 @@ class DatasetsWrapperHF: except Exception as e: logger.error(e) - # Otherwise we must use the dataset script if the user trusts it. - # Script-based loading was removed in datasets 4.0. - if not _HAS_SCRIPT_LOADING: - raise RuntimeError( - f"Dataset '{path}' contains a loading script but script-based dataset loading " - 'is not supported with datasets>=4.0. Please convert the dataset to a ' - 'script-free format (e.g. Parquet).') + if _HAS_SCRIPT_LOADING: + if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'): + return HubDatasetModuleFactoryWithScript( + path, + revision=revision, + download_config=download_config, + download_mode=download_mode, + dynamic_modules_path=dynamic_modules_path, + trust_remote_code=trust_remote_code, + ).get_module() - # To be adapted to the old version of datasets - if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'): return HubDatasetModuleFactoryWithScript( path, - revision=revision, + commit_hash=commit_hash, download_config=download_config, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path, trust_remote_code=trust_remote_code, ).get_module() - return HubDatasetModuleFactoryWithScript( + return _compat_hub_script_module( path, - commit_hash=commit_hash, + revision=revision, download_config=download_config, download_mode=download_mode, dynamic_modules_path=dynamic_modules_path, trust_remote_code=trust_remote_code, - ).get_module() + ) else: - # To be adapted to the old version of datasets if has_attr_in_class(HubDatasetModuleFactoryWithoutScript, 'revision'): return HubDatasetModuleFactoryWithoutScript( path, @@ -1540,18 +904,15 @@ class DatasetsWrapperHF: download_mode=download_mode, ).get_module() except Exception as e1: - # All the attempts failed, before raising the error we should check if the module is already cached 
logger.error(f'>> Error loading {path}: {e1}') try: - # dynamic_modules_path was removed in datasets 4.0 _cached_factory_kwargs = {'cache_dir': cache_dir} if _HAS_SCRIPT_LOADING: _cached_factory_kwargs['dynamic_modules_path'] = dynamic_modules_path return CachedDatasetModuleFactory( path, **_cached_factory_kwargs).get_module() except Exception: - # If it's not in the cache, then it doesn't exist. if isinstance(e1, OfflineModeIsEnabled): raise ConnectionError( f"Couldn't reach the Hugging Face Hub for dataset '{path}': {e1}" @@ -1573,79 +934,38 @@ class DatasetsWrapperHF: f'any data file in the same directory.') -_hf_fs_open_original = None - - -def _hf_fs_open(self, path, mode='rb', **kwargs): - """Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API. - - The ModelScope tree API may report Size=0 for files. When HfFileSystem - caches this, AbstractBufferedFile treats the file as empty (0 bytes). - This wrapper detects size=0 for files opened in read mode and resolves - the actual size via a HEAD request before creating the file object. 
- """ - if mode == 'rb' and 'size' not in kwargs: - try: - resolved = self.resolve_path(path) - resolved_name = resolved.unresolve() - parent = self._parent(resolved_name) - cached_size = None - if parent in self.dircache: - for entry in self.dircache[parent]: - if entry['name'] == resolved_name and entry.get('type') == 'file': - cached_size = entry.get('size', -1) - break - if cached_size == 0: - url = hf_hub_url( - repo_id=resolved.repo_id, - revision=resolved.revision, - filename=resolved.path_in_repo, - repo_type=resolved.repo_type, - endpoint=self.endpoint, - ) - headers = self._api._build_hf_headers() - resp = requests.head(url, headers=headers, allow_redirects=True, timeout=30) - if resp.status_code == 200: - cl = resp.headers.get('Content-Length') - if cl: - actual_size = int(cl) - kwargs['size'] = actual_size - for entry in self.dircache.get(parent, []): - if entry['name'] == resolved_name: - entry['size'] = actual_size - break - except Exception: - pass - return _hf_fs_open_original(self, path, mode=mode, **kwargs) - +# =================================================================== +# Context manager – load_dataset_with_ctx +# =================================================================== @contextlib.contextmanager def load_dataset_with_ctx(*args, **kwargs): + """Context manager that monkey-patches ``datasets`` to use ModelScope. + + All monkey-patches are applied on entry and restored on exit (for + non-streaming mode) or kept alive (for streaming mode, where lazy + iteration needs the patches to remain active). 
+ """ global _hf_fs_open_original - # Keep the original functions + # Save originals hf_endpoint_origin = config.HF_ENDPOINT get_from_cache_origin = file_utils.get_from_cache - - # Compatible with datasets 2.18.0 _download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \ else DownloadManager._download_single - dataset_info_origin = HfApi.dataset_info list_repo_tree_origin = HfApi.list_repo_tree get_paths_info_origin = HfApi.get_paths_info resolve_pattern_origin = data_files.resolve_pattern get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module - # Script-based loading was removed in datasets 4.0 get_module_with_script_origin = ( HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None) generate_from_dict_origin = features.generate_from_dict hf_fs_open_origin = HfFileSystem._open - # Monkey patching with modelscope functions + # Apply patches config.HF_ENDPOINT = get_endpoint() file_utils.get_from_cache = get_from_cache_ms - # Compatible with datasets 2.18.0 if hasattr(DownloadManager, '_download'): DownloadManager._download = _download_ms else: @@ -1678,7 +998,6 @@ def load_dataset_with_ctx(*args, **kwargs): file_utils.get_from_cache = get_from_cache_origin features.generate_from_dict = generate_from_dict_origin - # Compatible with datasets 2.18.0 if hasattr(DownloadManager, '_download'): DownloadManager._download = _download_origin else: diff --git a/requirements/datasets.txt b/requirements/datasets.txt index 21337757..611ec0d4 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=4.0.0,<=4.6.1 +datasets>=4.0.0,<=4.8.4 einops oss2 Pillow diff --git a/requirements/framework.txt b/requirements/framework.txt index b0c255dc..e619b305 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=4.0.0,<=4.6.1 +datasets>=4.0.0,<=4.8.4 einops Pillow python-dateutil>=2.1