From 3c3b78f02938f253c0e81270ac39b89b52dc6afa Mon Sep 17 00:00:00 2001 From: "Xingjun.Wang" Date: Mon, 9 Mar 2026 17:35:31 +0800 Subject: [PATCH] [Feat] dataset module refactor (#1623) --- modelscope/hub/api.py | 131 +++---- modelscope/hub/file_download.py | 8 +- modelscope/hub/snapshot_download.py | 7 +- .../msdatasets/data_loader/data_loader.py | 1 - .../data_loader/data_loader_manager.py | 2 - .../msdatasets/download/dataset_builder.py | 2 +- .../msdatasets/meta/data_meta_manager.py | 2 +- modelscope/msdatasets/ms_dataset.py | 9 +- .../msdatasets/utils/hf_datasets_util.py | 337 ++++++++++++++---- requirements/datasets.txt | 2 +- requirements/framework.txt | 2 +- 11 files changed, 342 insertions(+), 161 deletions(-) diff --git a/modelscope/hub/api.py b/modelscope/hub/api.py index 354167a6..4be160bf 100644 --- a/modelscope/hub/api.py +++ b/modelscope/hub/api.py @@ -11,7 +11,6 @@ import platform import re import shutil import tempfile -import time import uuid import warnings from collections import defaultdict @@ -1403,6 +1402,8 @@ class HubApi: raise_for_http_status(r) raise_on_error(r.json()) + _dataset_id_type_cache: dict = {} + def get_dataset_id_and_type(self, dataset_name: str, namespace: str, @@ -1411,6 +1412,10 @@ class HubApi: """ Get the dataset id and type. 
""" if not endpoint: endpoint = self.endpoint + cache_key = (namespace, dataset_name, endpoint) + cached = HubApi._dataset_id_type_cache.get(cache_key) + if cached is not None: + return cached datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}' cookies = self.get_cookies(access_token=token) r = self.session.get(datahub_url, cookies=cookies) @@ -1418,6 +1423,7 @@ class HubApi: datahub_raise_on_error(datahub_url, resp, r) dataset_id = resp['Data']['Id'] dataset_type = resp['Data']['Type'] + HubApi._dataset_id_type_cache[cache_key] = (dataset_id, dataset_type) return dataset_id, dataset_type def list_repo_tree(self, @@ -1526,7 +1532,8 @@ class HubApi: page_number: int = 1, page_size: int = 100, endpoint: Optional[str] = None, - token: Optional[str] = None): + token: Optional[str] = None, + dataset_hub_id: Optional[str] = None): """ Get the dataset files. @@ -1539,19 +1546,23 @@ class HubApi: page_size (int): The number of items per page. Defaults to 100. endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class. token (Optional[str]): The access token. + dataset_hub_id (Optional[str]): Pre-fetched dataset hub id. When provided, + skips the internal ``get_dataset_id_and_type`` lookup. Useful in pagination + loops to avoid redundant API calls per page. Returns: List: The response containing the dataset repository tree information. e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...] 
""" - if is_relative_path(repo_id) and repo_id.count('/') == 1: - _owner, _dataset_name = repo_id.split('/') - else: - raise ValueError(f'Invalid repo_id: {repo_id} !') + if dataset_hub_id is None: + if is_relative_path(repo_id) and repo_id.count('/') == 1: + _owner, _dataset_name = repo_id.split('/') + else: + raise ValueError(f'Invalid repo_id: {repo_id} !') - dataset_hub_id, dataset_type = self.get_dataset_id_and_type( - dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token) + dataset_hub_id, _ = self.get_dataset_id_and_type( + dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token) if not endpoint: endpoint = self.endpoint @@ -1569,7 +1580,10 @@ class HubApi: resp = r.json() datahub_raise_on_error(datahub_url, resp, r) - return resp['Data']['Files'] + data = resp.get('Data') + if data is None: + return [] + return data.get('Files') or [] def get_dataset( self, @@ -2100,11 +2114,9 @@ class HubApi: repo_type: Optional[str] = REPO_TYPE_MODEL, revision: Optional[str] = DEFAULT_REPOSITORY_REVISION, endpoint: Optional[str] = None, - max_retries: int = 3, - timeout: int = 180, ) -> CommitInfo: """ - Create a commit on the ModelScope Hub with retry mechanism. + Create a commit on the ModelScope Hub. Args: repo_id (str): The repo id in the format of `owner_name/repo_name`. @@ -2117,14 +2129,14 @@ class HubApi: revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`. endpoint (Optional[str]): The endpoint to use. In the format of `https://www.modelscope.cn` or 'https://www.modelscope.ai' - max_retries (int): Number of max retry attempts (default: 3). timeout (int): Timeout for each request in seconds (default: 180). Returns: CommitInfo: The commit info. Raises: - requests.exceptions.RequestException: If all retry attempts fail. + ValueError: If the request fails with a 4xx client error. + requests.exceptions.RequestException: If a network-level error occurs. 
""" if not repo_id: raise ValueError('Repo id cannot be empty!') @@ -2147,66 +2159,29 @@ class HubApi: commit_message=commit_message, ) - # POST with retry mechanism - last_exception = None - for attempt in range(max_retries): + response = self.session.post( + url, + headers=self.builder_headers(self.headers), + data=json.dumps(payload), + cookies=cookies, + ) + + if response.status_code != 200: try: - if attempt > 0: - logger.info(f'Attempt {attempt + 1} to create commit for {repo_id}...') - response = requests.post( - url, - headers=self.builder_headers(self.headers), - data=json.dumps(payload), - cookies=cookies, - timeout=timeout, - ) + error_detail = response.json() + except json.JSONDecodeError: + error_detail = response.text + error_msg = f'HTTP {response.status_code} error from {url}: {error_detail}' + raise ValueError(error_msg) - if response.status_code != 200: - try: - error_detail = response.json() - except json.JSONDecodeError: - error_detail = response.text - - error_msg = ( - f'HTTP {response.status_code} error from {url}: ' - f'{error_detail}' - ) - - # If server error (5xx), we can retry, otherwise (4xx) raise immediately - if 500 <= response.status_code < 600: - logger.warning( - f'Server error on attempt {attempt + 1}: {error_msg}' - ) - else: - raise ValueError(f'Client request failed: {error_msg}') - else: - resp = response.json() - - oid = resp.get('Data', {}).get('oid', '') - logger.info(f'Commit succeeded: {url}') - return CommitInfo( - commit_url=url, - commit_message=commit_message, - commit_description=commit_description, - oid=oid, - ) - - except requests.exceptions.RequestException as e: - last_exception = e - logger.warning(f'Request failed on attempt {attempt + 1}: {str(e)}') - - except Exception as e: - last_exception = e - logger.error(f'Unexpected error on attempt {attempt + 1}: {str(e)}') - if attempt == max_retries - 1: - raise - - if attempt < max_retries - 1: - time.sleep(1) - - # All retries exhausted - raise 
requests.exceptions.RequestException( - f'Failed to create commit after {max_retries} attempts. Last error: {last_exception}' + resp = response.json() + oid = resp.get('Data', {}).get('oid', '') + logger.info(f'Commit succeeded: {url}') + return CommitInfo( + commit_url=url, + commit_message=commit_message, + commit_description=commit_description, + oid=oid, ) def upload_file( @@ -2593,21 +2568,21 @@ class HubApi: if isinstance(data, (str, Path)): with open(data, 'rb') as f: - response = requests.put( + response = self.session.put( upload_object['url'], headers=headers, data=read_in_chunks(f, pbar) ) elif isinstance(data, bytes): - response = requests.put( + response = self.session.put( upload_object['url'], headers=headers, data=read_in_chunks(io.BytesIO(data), pbar) ) elif isinstance(data, io.BufferedIOBase): - response = requests.put( + response = self.session.put( upload_object['url'], headers=headers, data=read_in_chunks(data, pbar) @@ -2664,7 +2639,7 @@ class HubApi: } cookies = self.get_cookies(access_token=token, cookies_required=True) - response = requests.post( + response = self.session.post( url, headers=self.builder_headers(self.headers), data=json.dumps(payload), @@ -2907,6 +2882,9 @@ class HubApi: file_paths = [f['Path'] for f in files] elif repo_type == REPO_TYPE_DATASET: file_paths = [] + _owner, _dataset_name = repo_id.split('/') + _hub_id, _ = self.get_dataset_id_and_type( + dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token) page_number = 1 page_size = 100 while True: @@ -2919,6 +2897,7 @@ class HubApi: page_size=page_size, endpoint=endpoint, token=token, + dataset_hub_id=_hub_id, ) except Exception as e: logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}') diff --git a/modelscope/hub/file_download.py b/modelscope/hub/file_download.py index 33120b9b..320d1221 100644 --- a/modelscope/hub/file_download.py +++ b/modelscope/hub/file_download.py @@ -253,6 +253,11 @@ def _repo_file_download( 
group_or_owner, name = model_id_to_group_owner_name(repo_id) if not revision: revision = DEFAULT_DATASET_REVISION + _hub_id, _ = _api.get_dataset_id_and_type( + dataset_name=name, + namespace=group_or_owner, + endpoint=endpoint, + token=token) page_number = 1 page_size = 100 while True: @@ -265,7 +270,8 @@ def _repo_file_download( page_number=page_number, page_size=page_size, endpoint=endpoint, - token=token) + token=token, + dataset_hub_id=_hub_id) except Exception as e: logger.error( f'Get dataset: {repo_id} file list failed, error: {e}') diff --git a/modelscope/hub/snapshot_download.py b/modelscope/hub/snapshot_download.py index 0e21de4f..c29addd3 100644 --- a/modelscope/hub/snapshot_download.py +++ b/modelscope/hub/snapshot_download.py @@ -428,6 +428,10 @@ def _snapshot_download( def fetch_repo_files(_api, repo_id, revision, endpoint): + _owner, _dataset_name = repo_id.split('/') + _hub_id, _ = _api.get_dataset_id_and_type( + dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint) + page_number = 1 page_size = 150 repo_files = [] @@ -441,7 +445,8 @@ def fetch_repo_files(_api, repo_id, revision, endpoint): recursive=True, page_number=page_number, page_size=page_size, - endpoint=endpoint) + endpoint=endpoint, + dataset_hub_id=_hub_id) except Exception as e: logger.error(f'Error fetching dataset files: {e}') break diff --git a/modelscope/msdatasets/data_loader/data_loader.py b/modelscope/msdatasets/data_loader/data_loader.py index fa1c2ac9..fd6b1d59 100644 --- a/modelscope/msdatasets/data_loader/data_loader.py +++ b/modelscope/msdatasets/data_loader/data_loader.py @@ -148,7 +148,6 @@ class OssDownloader(BaseDownloader): data_files=data_files, cache_dir=cache_dir, download_mode=download_mode.value, - trust_remote_code=trust_remote_code, **input_kwargs) else: self.dataset = self.data_files_manager.fetch_data_files( diff --git a/modelscope/msdatasets/data_loader/data_loader_manager.py b/modelscope/msdatasets/data_loader/data_loader_manager.py index 
011b6e47..71688711 100644 --- a/modelscope/msdatasets/data_loader/data_loader_manager.py +++ b/modelscope/msdatasets/data_loader/data_loader_manager.py @@ -87,7 +87,6 @@ class LocalDataLoaderManager(DataLoaderManager): cache_dir=cache_root_dir, download_mode=download_mode.value, streaming=use_streaming, - trust_remote_code=trust_remote_code, **input_config_kwargs) raise f'Expected local data loader type: {LocalDataLoaderType.HF_DATA_LOADER.value}.' @@ -130,7 +129,6 @@ class RemoteDataLoaderManager(DataLoaderManager): data_files=data_files, download_mode=download_mode_val, streaming=use_streaming, - trust_remote_code=trust_remote_code, token=token, **input_config_kwargs) # download statistics diff --git a/modelscope/msdatasets/download/dataset_builder.py b/modelscope/msdatasets/download/dataset_builder.py index 84563668..534a14f6 100644 --- a/modelscope/msdatasets/download/dataset_builder.py +++ b/modelscope/msdatasets/download/dataset_builder.py @@ -13,8 +13,8 @@ from datasets.filesystems import is_remote_filesystem from datasets.info import DatasetInfo from datasets.naming import camelcase_to_snakecase from datasets.packaged_modules import csv -from datasets.utils.filelock import FileLock from datasets.utils.py_utils import map_nested +from filelock import FileLock from modelscope.hub.api import HubApi from modelscope.msdatasets.context.dataset_context_config import \ diff --git a/modelscope/msdatasets/meta/data_meta_manager.py b/modelscope/msdatasets/meta/data_meta_manager.py index afef97b0..8fecf3ef 100644 --- a/modelscope/msdatasets/meta/data_meta_manager.py +++ b/modelscope/msdatasets/meta/data_meta_manager.py @@ -5,7 +5,7 @@ import shutil from collections import defaultdict import json -from datasets.utils.filelock import FileLock +from filelock import FileLock from modelscope.hub.api import HubApi from modelscope.msdatasets.context.dataset_context_config import \ diff --git a/modelscope/msdatasets/ms_dataset.py b/modelscope/msdatasets/ms_dataset.py index 
a6d223f3..d6c31946 100644 --- a/modelscope/msdatasets/ms_dataset.py +++ b/modelscope/msdatasets/ms_dataset.py @@ -283,10 +283,15 @@ class MsDataset: return load_dataset( dataset_name, name=subset_name, + data_dir=data_dir, + data_files=data_files, split=split, - streaming=use_streaming, + cache_dir=cache_dir, + features=features, download_mode=download_mode.value, - trust_remote_code=trust_remote_code, + revision=version, + token=token, + streaming=use_streaming, **config_kwargs) # Load from the modelscope hub diff --git a/modelscope/msdatasets/utils/hf_datasets_util.py b/modelscope/msdatasets/utils/hf_datasets_util.py index 993714fe..ac02443d 100644 --- a/modelscope/msdatasets/utils/hf_datasets_util.py +++ b/modelscope/msdatasets/utils/hf_datasets_util.py @@ -17,7 +17,16 @@ import requests from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict, DownloadConfig, DownloadManager, DownloadMode, Features, IterableDataset, IterableDatasetDict, Split, - VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf) + VerificationMode, Version, config, data_files, LargeList, + Sequence as SequenceHf) + +# In datasets 4.0+, Sequence was replaced by List as a feature type. +# Use List as the base for ListMs when available, fall back to Sequence for <4.0. 
+try: + from datasets import List as DatasetList +except ImportError: + DatasetList = None + from datasets.features import features from datasets.features.features import _FEATURE_TYPES from datasets.data_files import ( @@ -29,36 +38,63 @@ from datasets.download.streaming_download_manager import ( from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError from datasets.info import DatasetInfosDict from datasets.load import ( - ALL_ALLOWED_EXTENSIONS, BuilderConfigsParameters, + BuilderConfigsParameters, CachedDatasetModuleFactory, DatasetModule, - HubDatasetModuleFactoryWithoutScript, HubDatasetModuleFactoryWithParquetExport, - HubDatasetModuleFactoryWithScript, LocalDatasetModuleFactoryWithoutScript, - LocalDatasetModuleFactoryWithScript, PackagedDatasetModuleFactory, + PackagedDatasetModuleFactory, create_builder_configs_from_metadata_configs, get_dataset_builder_class, - import_main_class, infer_module_for_data_files, files_to_hash, - _get_importable_file_path, resolve_trust_remote_code, _create_importable_file, _load_importable_file, - init_dynamic_modules) + import_main_class, infer_module_for_data_files) + +# For compatibility with datasets 4.0+ +try: + from datasets.load import ( + HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript, + LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript) +except ImportError: + from datasets.load import ( + HubDatasetModuleFactoryWithoutScript, + LocalDatasetModuleFactoryWithoutScript) + +# Script-based dataset loading was removed in datasets 4.0. +# These APIs are conditionally imported for backward compatibility with <4.0. 
+try: + from datasets.load import ( + HubDatasetModuleFactoryWithScript, + LocalDatasetModuleFactoryWithScript, + resolve_trust_remote_code, + _get_importable_file_path, _create_importable_file, + _load_importable_file, init_dynamic_modules, + files_to_hash) + from datasets.utils.py_utils import get_imports + _HAS_SCRIPT_LOADING = True +except ImportError: + _HAS_SCRIPT_LOADING = False + from datasets.naming import camelcase_to_snakecase from datasets.packaged_modules import (_EXTENSION_TO_MODULE, _MODULE_TO_EXTENSIONS, _PACKAGED_DATASETS_MODULES) +# ALL_ALLOWED_EXTENSIONS moved to datasets.packaged_modules in datasets 4.0 +try: + from datasets.packaged_modules import _ALL_ALLOWED_EXTENSIONS as ALL_ALLOWED_EXTENSIONS +except ImportError: + from datasets.load import ALL_ALLOWED_EXTENSIONS from datasets.utils import file_utils from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled, cached_path, is_local_path, relative_to_absolute_path) from datasets.utils.info_utils import is_small_dataset from datasets.utils.metadata import MetadataConfigs -from datasets.utils.py_utils import get_imports from datasets.utils.track import tracked_str from fsspec import filesystem from fsspec.core import _un_chain from fsspec.utils import stringify_path -from huggingface_hub import (DatasetCard, DatasetCardData) +from huggingface_hub import (DatasetCard, DatasetCardData, hf_hub_url) from huggingface_hub.errors import OfflineModeIsEnabled from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder +from huggingface_hub.hf_file_system import HfFileSystem from packaging import version from modelscope import HubApi @@ -94,8 +130,13 @@ ExpandDatasetProperty_T = Literal[ # Patch datasets features +# In datasets 4.0+, the List type is the native feature type; +# in datasets <4.0, Sequence (a dataclass) serves that role. 
+_ListBase = DatasetList if DatasetList is not None else SequenceHf + + @dataclass(repr=False) -class ListMs(SequenceHf): +class ListMs(_ListBase): """Feature type for large list data composed of child feature data type. It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length. @@ -144,6 +185,15 @@ def generate_from_dict_ms(obj: Any): return {key: generate_from_dict_ms(value) for key, value in obj.items()} obj = dict(obj) _type = obj.pop('_type') + + # Handle legacy 'Sequence' type for backward compatibility. + # In datasets 4.0+, Sequence is a utility function (not a feature type), + # so it may not be registered in _FEATURE_TYPES. + if _type == 'Sequence': + feature = obj.pop('feature') + length = obj.get('length', -1) + return SequenceHf(feature=generate_from_dict_ms(feature), length=length) + class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None) if class_type is None: @@ -155,9 +205,6 @@ def generate_from_dict_ms(obj: Any): if class_type == ListMs: feature = obj.pop('feature') return ListMs(generate_from_dict_ms(feature), **obj) - if class_type == SequenceHf: # backward compatibility, this translates to a List or a dict - feature = obj.pop('feature') - return SequenceHf(feature=generate_from_dict_ms(feature), **obj) field_names = {f.name for f in fields(class_type)} return class_type(**{k: v for k, v in obj.items() if k in field_names}) @@ -165,14 +212,12 @@ def generate_from_dict_ms(obj: Any): def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str: url_or_filename = str(url_or_filename) - # for temp val - revision = None if url_or_filename.startswith('hf://'): - revision, url_or_filename = url_or_filename.split('@', 1)[-1].split('/', 1) - if is_relative_path(url_or_filename): - # append the relative path to the base_path - # url_or_filename = url_or_path_join(self._base_path, url_or_filename) - revision = revision or DEFAULT_DATASET_REVISION + # hf:// URLs are handled natively by 
cached_path via HfApi.hf_hub_download, + # which uses config.HF_ENDPOINT (already set to ModelScope endpoint). + pass + elif is_relative_path(url_or_filename): + revision = DEFAULT_DATASET_REVISION # Note: make sure the FilePath is the last param params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename} params: str = urlencode(params) @@ -274,6 +319,37 @@ def _dataset_info( return HfDatasetInfo(**data_info) +_repo_tree_cache: Dict[tuple, List[Union[RepoFile, RepoFolder]]] = {} + + +def _derive_from_recursive_cache( + repo_id: str, + revision: str, + path_in_repo: str, + recursive: bool, +) -> Optional[List[Union[RepoFile, RepoFolder]]]: + """Try to derive results from a cached recursive root listing.""" + root_key = (repo_id, revision, '/', True) + root_cached = _repo_tree_cache.get(root_key) + if root_cached is None: + return None + + prefix = path_in_repo.strip('/') if path_in_repo and path_in_repo != '/' else '' + results = [] + for item in root_cached: + item_path = item.path + if prefix: + if not item_path.startswith(prefix + '/') and item_path != prefix: + continue + rel_path = item_path[len(prefix) + 1:] if item_path.startswith(prefix + '/') else '' + else: + rel_path = item_path + if not recursive and '/' in rel_path: + continue + results.append(item) + return results + + def _list_repo_tree( self, repo_id: str, @@ -286,41 +362,72 @@ def _list_repo_tree( token: Optional[Union[bool, str]] = None, ) -> Iterable[Union[RepoFile, RepoFolder]]: + revision = revision or DEFAULT_DATASET_REVISION + normalized_path = path_in_repo or '/' + cache_key = (repo_id, revision, normalized_path, recursive) + + cached = _repo_tree_cache.get(cache_key) + if cached is not None: + yield from cached + return + + derived = _derive_from_recursive_cache(repo_id, revision, normalized_path, recursive) + if derived is not None: + _repo_tree_cache[cache_key] = derived + yield from derived + return + _api = HubApi(timeout=3 * 60, max_retries=3) endpoint = 
_api.get_endpoint_for_read( repo_id=repo_id, repo_type=REPO_TYPE_DATASET) - # List all files in the repo + _owner, _dataset_name = repo_id.split('/') + dataset_hub_id, _ = _api.get_dataset_id_and_type( + dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint) + + results: List[Union[RepoFile, RepoFolder]] = [] page_number = 1 - page_size = 100 - while True: + # Larger page_size reduces the number of HTTP round-trips for big datasets. + # Termination uses `not dataset_files` (empty page) so it is safe even if + # the server silently caps the actual page size to a smaller value. + page_size = 500 + max_pages = 10000 + while page_number <= max_pages: try: dataset_files = _api.get_dataset_files( repo_id=repo_id, - revision=revision or DEFAULT_DATASET_REVISION, - root_path=path_in_repo or '/', + revision=revision, + root_path=normalized_path, recursive=recursive, page_number=page_number, page_size=page_size, endpoint=endpoint, + dataset_hub_id=dataset_hub_id, ) except Exception as e: logger.error(f'Get dataset: {repo_id} file list failed, message: {e}') break - for file_info_d in dataset_files: - path_info = {} - path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file' - path_info['path'] = file_info_d['Path'] - path_info['size'] = file_info_d['Size'] - path_info['oid'] = file_info_d['Sha256'] + if not dataset_files: + break - yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info) + for file_info_d in dataset_files: + path_info = { + 'type': 'directory' if file_info_d['Type'] == 'tree' else 'file', + 'path': file_info_d['Path'], + 'size': file_info_d['Size'], + 'oid': file_info_d['Sha256'], + } + item = RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info) + results.append(item) + yield item if len(dataset_files) < page_size: break page_number += 1 + _repo_tree_cache[cache_key] = results + def _get_paths_info( self, @@ -333,7 +440,19 @@ def _get_paths_info( token: 
Optional[Union[bool, str]] = None, ) -> List[Union[RepoFile, RepoFolder]]: - # Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree` + revision = revision or DEFAULT_DATASET_REVISION + if isinstance(paths, str): + paths = [paths] + paths_set = set(paths) + + # Search within any cached tree data (recursive root is the most comprehensive) + root_key = (repo_id, revision, '/', True) + root_cached = _repo_tree_cache.get(root_key) + if root_cached is not None: + matched = [item for item in root_cached if item.path in paths_set] + if matched: + return matched + repo_info_iter = self.list_repo_tree( repo_id=repo_id, recursive=False, @@ -831,6 +950,10 @@ def _download_additional_modules( def get_module_with_script(self) -> DatasetModule: + if not _HAS_SCRIPT_LOADING: + raise RuntimeError( + 'Script-based dataset loading is not supported with datasets>=4.0. ' + 'Please convert the dataset to a script-free format (e.g. Parquet).') repo_id: str = self.name _namespace, _dataset_name = repo_id.split('/') @@ -1000,9 +1123,12 @@ class DatasetsWrapperHF: ) if not save_infos else VerificationMode.ALL_CHECKS) if trust_remote_code: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' - 'that you can trust the external codes.' - ) + if not _HAS_SCRIPT_LOADING: + logger.warning('trust_remote_code is ignored: script-based dataset loading ' + 'is no longer supported with datasets>=4.0.') + else: + logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. 
Please make sure ' + 'that you can trust the external codes.') # Create a dataset builder builder_instance = DatasetsWrapperHF.load_dataset_builder( @@ -1017,7 +1143,7 @@ class DatasetsWrapperHF: revision=revision, token=token, storage_options=storage_options, - trust_remote_code=trust_remote_code, + trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None, _require_default_config_name=name is None, **config_kwargs, ) @@ -1135,9 +1261,12 @@ class DatasetsWrapperHF: download_config.storage_options.update(storage_options) if trust_remote_code: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' - 'that you can trust the external codes.' - ) + if not _HAS_SCRIPT_LOADING: + logger.warning('trust_remote_code is ignored: script-based dataset loading ' + 'is no longer supported with datasets>=4.0.') + else: + logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' + 'that you can trust the external codes.') dataset_module = DatasetsWrapperHF.dataset_module_factory( path, @@ -1147,7 +1276,7 @@ class DatasetsWrapperHF: data_dir=data_dir, data_files=data_files, cache_dir=cache_dir, - trust_remote_code=trust_remote_code, + trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None, _require_default_config_name=_require_default_config_name, _require_custom_configs=bool(config_kwargs), name=name, @@ -1250,9 +1379,12 @@ class DatasetsWrapperHF: # - if path has one "/" and is dataset repository on the HF hub without a python file # -> use a packaged module (csv, text etc.) based on content of the repository if trust_remote_code: - logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' - 'that you can trust the external codes.' 
- ) + if not _HAS_SCRIPT_LOADING: + logger.warning('trust_remote_code is ignored: script-based dataset loading ' + 'is no longer supported with datasets>=4.0.') + else: + logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure ' + 'that you can trust the external codes.') # Try packaged if path in _PACKAGED_DATASETS_MODULES: @@ -1263,9 +1395,13 @@ class DatasetsWrapperHF: download_config=download_config, download_mode=download_mode, ).get_module() - # Try locally + # Try locally with script (requires datasets <4.0) elif path.endswith(filename): if os.path.isfile(path): + if not _HAS_SCRIPT_LOADING: + raise RuntimeError( + f'Script-based dataset loading ({path}) is not supported with datasets>=4.0. ' + 'Please convert the dataset to a script-free format (e.g. Parquet).') return LocalDatasetModuleFactoryWithScript( path, download_mode=download_mode, @@ -1277,6 +1413,10 @@ class DatasetsWrapperHF: f"Couldn't find a dataset script at {relative_to_absolute_path(path)}" ) elif os.path.isfile(combined_path): + if not _HAS_SCRIPT_LOADING: + raise RuntimeError( + f'Script-based dataset loading ({combined_path}) is not supported with datasets>=4.0. ' + 'Please convert the dataset to a script-free format (e.g. 
Parquet).') return LocalDatasetModuleFactoryWithScript( combined_path, download_mode=download_mode, @@ -1342,24 +1482,8 @@ class DatasetsWrapperHF: sibling.rfilename for sibling in dataset_info.siblings ]: # contains a dataset script - # fs = HfFileSystem( - # endpoint=config.HF_ENDPOINT, - # token=download_config.token) - # TODO can_load_config_from_parquet_export = False - # if _require_custom_configs: - # can_load_config_from_parquet_export = False - # elif _require_default_config_name: - # with fs.open( - # f'datasets/{path}/{filename}', - # 'r', - # revision=revision, - # encoding='utf-8') as f: - # can_load_config_from_parquet_export = 'DEFAULT_CONFIG_NAME' not in f.read( - # ) - # else: - # can_load_config_from_parquet_export = True if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export: # If the parquet export is ready (parquet files + info available for the current sha), # we can use it instead @@ -1379,7 +1503,14 @@ class DatasetsWrapperHF: except Exception as e: logger.error(e) - # Otherwise we must use the dataset script if the user trusts it + # Otherwise we must use the dataset script if the user trusts it. + # Script-based loading was removed in datasets 4.0. + if not _HAS_SCRIPT_LOADING: + raise RuntimeError( + f"Dataset '{path}' contains a loading script but script-based dataset loading " + 'is not supported with datasets>=4.0. Please convert the dataset to a ' + 'script-free format (e.g. 
Parquet).') + # To be adapted to the old version of datasets if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'): return HubDatasetModuleFactoryWithScript( @@ -1424,10 +1555,12 @@ class DatasetsWrapperHF: logger.error(f'>> Error loading {path}: {e1}') try: + # dynamic_modules_path was removed in datasets 4.0 + _cached_factory_kwargs = {'cache_dir': cache_dir} + if _HAS_SCRIPT_LOADING: + _cached_factory_kwargs['dynamic_modules_path'] = dynamic_modules_path return CachedDatasetModuleFactory( - path, - dynamic_modules_path=dynamic_modules_path, - cache_dir=cache_dir).get_module() + path, **_cached_factory_kwargs).get_module() except Exception: # If it's not in the cache, then it doesn't exist. if isinstance(e1, OfflineModeIsEnabled): @@ -1451,8 +1584,55 @@ class DatasetsWrapperHF: f'any data file in the same directory.') +_hf_fs_open_original = None + + +def _hf_fs_open(self, path, mode='rb', **kwargs): + """Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API. + + The ModelScope tree API may report Size=0 for files. When HfFileSystem + caches this, AbstractBufferedFile treats the file as empty (0 bytes). + This wrapper detects size=0 for files opened in read mode and resolves + the actual size via a HEAD request before creating the file object. 
+ """ + if mode == 'rb' and 'size' not in kwargs: + try: + resolved = self.resolve_path(path) + resolved_name = resolved.unresolve() + parent = self._parent(resolved_name) + cached_size = None + if parent in self.dircache: + for entry in self.dircache[parent]: + if entry['name'] == resolved_name and entry.get('type') == 'file': + cached_size = entry.get('size', -1) + break + if cached_size == 0: + url = hf_hub_url( + repo_id=resolved.repo_id, + revision=resolved.revision, + filename=resolved.path_in_repo, + repo_type=resolved.repo_type, + endpoint=self.endpoint, + ) + headers = self._api._build_hf_headers() + resp = requests.head(url, headers=headers, allow_redirects=True, timeout=30) + if resp.status_code == 200: + cl = resp.headers.get('Content-Length') + if cl: + actual_size = int(cl) + kwargs['size'] = actual_size + for entry in self.dircache.get(parent, []): + if entry['name'] == resolved_name: + entry['size'] = actual_size + break + except Exception: + pass + return _hf_fs_open_original(self, path, mode=mode, **kwargs) + + @contextlib.contextmanager def load_dataset_with_ctx(*args, **kwargs): + global _hf_fs_open_original # Keep the original functions hf_endpoint_origin = config.HF_ENDPOINT @@ -1467,8 +1647,11 @@ def load_dataset_with_ctx(*args, **kwargs): get_paths_info_origin = HfApi.get_paths_info resolve_pattern_origin = data_files.resolve_pattern get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module - get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module + # Script-based loading was removed in datasets 4.0 + get_module_with_script_origin = ( + HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None) generate_from_dict_origin = features.generate_from_dict + hf_fs_open_origin = HfFileSystem._open # Monkey patching with modelscope functions config.HF_ENDPOINT = get_endpoint() @@ -1483,8 +1666,11 @@ def load_dataset_with_ctx(*args, **kwargs): HfApi.get_paths_info = _get_paths_info 
data_files.resolve_pattern = _resolve_pattern HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script - HubDatasetModuleFactoryWithScript.get_module = get_module_with_script + if _HAS_SCRIPT_LOADING: + HubDatasetModuleFactoryWithScript.get_module = get_module_with_script features.generate_from_dict = generate_from_dict_ms + _hf_fs_open_original = hf_fs_open_origin + HfFileSystem._open = _hf_fs_open streaming = kwargs.get('streaming', False) @@ -1492,14 +1678,16 @@ def load_dataset_with_ctx(*args, **kwargs): dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs) yield dataset_res finally: - # Restore the original functions - config.HF_ENDPOINT = hf_endpoint_origin - file_utils.get_from_cache = get_from_cache_origin - features.generate_from_dict = generate_from_dict_origin - # Keep the context during the streaming iteration + _repo_tree_cache.clear() + HubApi._dataset_id_type_cache.clear() + + HfFileSystem._open = hf_fs_open_origin + _hf_fs_open_original = None + if not streaming: config.HF_ENDPOINT = hf_endpoint_origin file_utils.get_from_cache = get_from_cache_origin + features.generate_from_dict = generate_from_dict_origin # Compatible with datasets 2.18.0 if hasattr(DownloadManager, '_download'): @@ -1512,4 +1700,5 @@ def load_dataset_with_ctx(*args, **kwargs): HfApi.get_paths_info = get_paths_info_origin data_files.resolve_pattern = resolve_pattern_origin HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin - HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin + if _HAS_SCRIPT_LOADING: + HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin diff --git a/requirements/datasets.txt b/requirements/datasets.txt index d225d0de..21337757 100644 --- a/requirements/datasets.txt +++ b/requirements/datasets.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=3.0.0,<=3.6.0 +datasets>=4.0.0,<=4.6.1 einops oss2 Pillow diff --git a/requirements/framework.txt 
b/requirements/framework.txt index a6f3e65d..b0c255dc 100644 --- a/requirements/framework.txt +++ b/requirements/framework.txt @@ -1,6 +1,6 @@ addict attrs -datasets>=3.0.0,<=3.6.0 +datasets>=4.0.0,<=4.6.1 einops Pillow python-dateutil>=2.1