[Feat] dataset module refactor (#1623)

This commit is contained in:
Xingjun.Wang
2026-03-09 17:35:31 +08:00
committed by GitHub
parent 2fedc9bb9e
commit 3c3b78f029
11 changed files with 342 additions and 161 deletions

View File

@@ -11,7 +11,6 @@ import platform
import re
import shutil
import tempfile
import time
import uuid
import warnings
from collections import defaultdict
@@ -1403,6 +1402,8 @@ class HubApi:
raise_for_http_status(r)
raise_on_error(r.json())
_dataset_id_type_cache: dict = {}
def get_dataset_id_and_type(self,
dataset_name: str,
namespace: str,
@@ -1411,6 +1412,10 @@ class HubApi:
""" Get the dataset id and type. """
if not endpoint:
endpoint = self.endpoint
cache_key = (namespace, dataset_name, endpoint)
cached = HubApi._dataset_id_type_cache.get(cache_key)
if cached is not None:
return cached
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
cookies = self.get_cookies(access_token=token)
r = self.session.get(datahub_url, cookies=cookies)
@@ -1418,6 +1423,7 @@ class HubApi:
datahub_raise_on_error(datahub_url, resp, r)
dataset_id = resp['Data']['Id']
dataset_type = resp['Data']['Type']
HubApi._dataset_id_type_cache[cache_key] = (dataset_id, dataset_type)
return dataset_id, dataset_type
def list_repo_tree(self,
@@ -1526,7 +1532,8 @@ class HubApi:
page_number: int = 1,
page_size: int = 100,
endpoint: Optional[str] = None,
token: Optional[str] = None):
token: Optional[str] = None,
dataset_hub_id: Optional[str] = None):
"""
Get the dataset files.
@@ -1539,19 +1546,23 @@ class HubApi:
page_size (int): The number of items per page. Defaults to 100.
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
token (Optional[str]): The access token.
dataset_hub_id (Optional[str]): Pre-fetched dataset hub id. When provided,
skips the internal ``get_dataset_id_and_type`` lookup. Useful in pagination
loops to avoid redundant API calls per page.
Returns:
List: The response containing the dataset repository tree information.
e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...]
"""
if is_relative_path(repo_id) and repo_id.count('/') == 1:
_owner, _dataset_name = repo_id.split('/')
else:
raise ValueError(f'Invalid repo_id: {repo_id} !')
if dataset_hub_id is None:
if is_relative_path(repo_id) and repo_id.count('/') == 1:
_owner, _dataset_name = repo_id.split('/')
else:
raise ValueError(f'Invalid repo_id: {repo_id} !')
dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
dataset_hub_id, _ = self.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
if not endpoint:
endpoint = self.endpoint
@@ -1569,7 +1580,10 @@ class HubApi:
resp = r.json()
datahub_raise_on_error(datahub_url, resp, r)
return resp['Data']['Files']
data = resp.get('Data')
if data is None:
return []
return data.get('Files') or []
def get_dataset(
self,
@@ -2100,11 +2114,9 @@ class HubApi:
repo_type: Optional[str] = REPO_TYPE_MODEL,
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
endpoint: Optional[str] = None,
max_retries: int = 3,
timeout: int = 180,
) -> CommitInfo:
"""
Create a commit on the ModelScope Hub with retry mechanism.
Create a commit on the ModelScope Hub.
Args:
repo_id (str): The repo id in the format of `owner_name/repo_name`.
@@ -2117,14 +2129,14 @@ class HubApi:
revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
endpoint (Optional[str]): The endpoint to use.
In the format of `https://www.modelscope.cn` or 'https://www.modelscope.ai'
max_retries (int): Number of max retry attempts (default: 3).
timeout (int): Timeout for each request in seconds (default: 180).
Returns:
CommitInfo: The commit info.
Raises:
requests.exceptions.RequestException: If all retry attempts fail.
ValueError: If the request fails with a 4xx client error.
requests.exceptions.RequestException: If a network-level error occurs.
"""
if not repo_id:
raise ValueError('Repo id cannot be empty!')
@@ -2147,66 +2159,29 @@ class HubApi:
commit_message=commit_message,
)
# POST with retry mechanism
last_exception = None
for attempt in range(max_retries):
response = self.session.post(
url,
headers=self.builder_headers(self.headers),
data=json.dumps(payload),
cookies=cookies,
)
if response.status_code != 200:
try:
if attempt > 0:
logger.info(f'Attempt {attempt + 1} to create commit for {repo_id}...')
response = requests.post(
url,
headers=self.builder_headers(self.headers),
data=json.dumps(payload),
cookies=cookies,
timeout=timeout,
)
error_detail = response.json()
except json.JSONDecodeError:
error_detail = response.text
error_msg = f'HTTP {response.status_code} error from {url}: {error_detail}'
raise ValueError(error_msg)
if response.status_code != 200:
try:
error_detail = response.json()
except json.JSONDecodeError:
error_detail = response.text
error_msg = (
f'HTTP {response.status_code} error from {url}: '
f'{error_detail}'
)
# If server error (5xx), we can retry, otherwise (4xx) raise immediately
if 500 <= response.status_code < 600:
logger.warning(
f'Server error on attempt {attempt + 1}: {error_msg}'
)
else:
raise ValueError(f'Client request failed: {error_msg}')
else:
resp = response.json()
oid = resp.get('Data', {}).get('oid', '')
logger.info(f'Commit succeeded: {url}')
return CommitInfo(
commit_url=url,
commit_message=commit_message,
commit_description=commit_description,
oid=oid,
)
except requests.exceptions.RequestException as e:
last_exception = e
logger.warning(f'Request failed on attempt {attempt + 1}: {str(e)}')
except Exception as e:
last_exception = e
logger.error(f'Unexpected error on attempt {attempt + 1}: {str(e)}')
if attempt == max_retries - 1:
raise
if attempt < max_retries - 1:
time.sleep(1)
# All retries exhausted
raise requests.exceptions.RequestException(
f'Failed to create commit after {max_retries} attempts. Last error: {last_exception}'
resp = response.json()
oid = resp.get('Data', {}).get('oid', '')
logger.info(f'Commit succeeded: {url}')
return CommitInfo(
commit_url=url,
commit_message=commit_message,
commit_description=commit_description,
oid=oid,
)
def upload_file(
@@ -2593,21 +2568,21 @@ class HubApi:
if isinstance(data, (str, Path)):
with open(data, 'rb') as f:
response = requests.put(
response = self.session.put(
upload_object['url'],
headers=headers,
data=read_in_chunks(f, pbar)
)
elif isinstance(data, bytes):
response = requests.put(
response = self.session.put(
upload_object['url'],
headers=headers,
data=read_in_chunks(io.BytesIO(data), pbar)
)
elif isinstance(data, io.BufferedIOBase):
response = requests.put(
response = self.session.put(
upload_object['url'],
headers=headers,
data=read_in_chunks(data, pbar)
@@ -2664,7 +2639,7 @@ class HubApi:
}
cookies = self.get_cookies(access_token=token, cookies_required=True)
response = requests.post(
response = self.session.post(
url,
headers=self.builder_headers(self.headers),
data=json.dumps(payload),
@@ -2907,6 +2882,9 @@ class HubApi:
file_paths = [f['Path'] for f in files]
elif repo_type == REPO_TYPE_DATASET:
file_paths = []
_owner, _dataset_name = repo_id.split('/')
_hub_id, _ = self.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
page_number = 1
page_size = 100
while True:
@@ -2919,6 +2897,7 @@ class HubApi:
page_size=page_size,
endpoint=endpoint,
token=token,
dataset_hub_id=_hub_id,
)
except Exception as e:
logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}')

View File

@@ -253,6 +253,11 @@ def _repo_file_download(
group_or_owner, name = model_id_to_group_owner_name(repo_id)
if not revision:
revision = DEFAULT_DATASET_REVISION
_hub_id, _ = _api.get_dataset_id_and_type(
dataset_name=name,
namespace=group_or_owner,
endpoint=endpoint,
token=token)
page_number = 1
page_size = 100
while True:
@@ -265,7 +270,8 @@ def _repo_file_download(
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
token=token)
token=token,
dataset_hub_id=_hub_id)
except Exception as e:
logger.error(
f'Get dataset: {repo_id} file list failed, error: {e}')

View File

@@ -428,6 +428,10 @@ def _snapshot_download(
def fetch_repo_files(_api, repo_id, revision, endpoint):
_owner, _dataset_name = repo_id.split('/')
_hub_id, _ = _api.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
page_number = 1
page_size = 150
repo_files = []
@@ -441,7 +445,8 @@ def fetch_repo_files(_api, repo_id, revision, endpoint):
recursive=True,
page_number=page_number,
page_size=page_size,
endpoint=endpoint)
endpoint=endpoint,
dataset_hub_id=_hub_id)
except Exception as e:
logger.error(f'Error fetching dataset files: {e}')
break

View File

@@ -148,7 +148,6 @@ class OssDownloader(BaseDownloader):
data_files=data_files,
cache_dir=cache_dir,
download_mode=download_mode.value,
trust_remote_code=trust_remote_code,
**input_kwargs)
else:
self.dataset = self.data_files_manager.fetch_data_files(

View File

@@ -87,7 +87,6 @@ class LocalDataLoaderManager(DataLoaderManager):
cache_dir=cache_root_dir,
download_mode=download_mode.value,
streaming=use_streaming,
trust_remote_code=trust_remote_code,
**input_config_kwargs)
raise f'Expected local data loader type: {LocalDataLoaderType.HF_DATA_LOADER.value}.'
@@ -130,7 +129,6 @@ class RemoteDataLoaderManager(DataLoaderManager):
data_files=data_files,
download_mode=download_mode_val,
streaming=use_streaming,
trust_remote_code=trust_remote_code,
token=token,
**input_config_kwargs)
# download statistics

View File

@@ -13,8 +13,8 @@ from datasets.filesystems import is_remote_filesystem
from datasets.info import DatasetInfo
from datasets.naming import camelcase_to_snakecase
from datasets.packaged_modules import csv
from datasets.utils.filelock import FileLock
from datasets.utils.py_utils import map_nested
from filelock import FileLock
from modelscope.hub.api import HubApi
from modelscope.msdatasets.context.dataset_context_config import \

View File

@@ -5,7 +5,7 @@ import shutil
from collections import defaultdict
import json
from datasets.utils.filelock import FileLock
from filelock import FileLock
from modelscope.hub.api import HubApi
from modelscope.msdatasets.context.dataset_context_config import \

View File

@@ -283,10 +283,15 @@ class MsDataset:
return load_dataset(
dataset_name,
name=subset_name,
data_dir=data_dir,
data_files=data_files,
split=split,
streaming=use_streaming,
cache_dir=cache_dir,
features=features,
download_mode=download_mode.value,
trust_remote_code=trust_remote_code,
revision=version,
token=token,
streaming=use_streaming,
**config_kwargs)
# Load from the modelscope hub

View File

@@ -17,7 +17,16 @@ import requests
from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
DownloadConfig, DownloadManager, DownloadMode, Features,
IterableDataset, IterableDatasetDict, Split,
VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf)
VerificationMode, Version, config, data_files, LargeList,
Sequence as SequenceHf)
# In datasets 4.0+, Sequence was replaced by List as a feature type.
# Use List as the base for ListMs when available, fall back to Sequence for <4.0.
try:
from datasets import List as DatasetList
except ImportError:
DatasetList = None
from datasets.features import features
from datasets.features.features import _FEATURE_TYPES
from datasets.data_files import (
@@ -29,36 +38,63 @@ from datasets.download.streaming_download_manager import (
from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError
from datasets.info import DatasetInfosDict
from datasets.load import (
ALL_ALLOWED_EXTENSIONS, BuilderConfigsParameters,
BuilderConfigsParameters,
CachedDatasetModuleFactory, DatasetModule,
HubDatasetModuleFactoryWithoutScript,
HubDatasetModuleFactoryWithParquetExport,
HubDatasetModuleFactoryWithScript, LocalDatasetModuleFactoryWithoutScript,
LocalDatasetModuleFactoryWithScript, PackagedDatasetModuleFactory,
PackagedDatasetModuleFactory,
create_builder_configs_from_metadata_configs, get_dataset_builder_class,
import_main_class, infer_module_for_data_files, files_to_hash,
_get_importable_file_path, resolve_trust_remote_code, _create_importable_file, _load_importable_file,
init_dynamic_modules)
import_main_class, infer_module_for_data_files)
# To compatible with datasets 4.0+
try:
from datasets.load import (
HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript,
LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript)
except ImportError:
from datasets.load import (
HubDatasetModuleFactoryWithoutScript,
LocalDatasetModuleFactoryWithoutScript)
# Script-based dataset loading was removed in datasets 4.0.
# These APIs are conditionally imported for backward compatibility with <4.0.
try:
from datasets.load import (
HubDatasetModuleFactoryWithScript,
LocalDatasetModuleFactoryWithScript,
resolve_trust_remote_code,
_get_importable_file_path, _create_importable_file,
_load_importable_file, init_dynamic_modules,
files_to_hash)
from datasets.utils.py_utils import get_imports
_HAS_SCRIPT_LOADING = True
except ImportError:
_HAS_SCRIPT_LOADING = False
from datasets.naming import camelcase_to_snakecase
from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
_MODULE_TO_EXTENSIONS,
_PACKAGED_DATASETS_MODULES)
# ALL_ALLOWED_EXTENSIONS moved to datasets.packaged_modules in datasets 4.0
try:
from datasets.packaged_modules import _ALL_ALLOWED_EXTENSIONS as ALL_ALLOWED_EXTENSIONS
except ImportError:
from datasets.load import ALL_ALLOWED_EXTENSIONS
from datasets.utils import file_utils
from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled,
cached_path, is_local_path,
relative_to_absolute_path)
from datasets.utils.info_utils import is_small_dataset
from datasets.utils.metadata import MetadataConfigs
from datasets.utils.py_utils import get_imports
from datasets.utils.track import tracked_str
from fsspec import filesystem
from fsspec.core import _un_chain
from fsspec.utils import stringify_path
from huggingface_hub import (DatasetCard, DatasetCardData)
from huggingface_hub import (DatasetCard, DatasetCardData, hf_hub_url)
from huggingface_hub.errors import OfflineModeIsEnabled
from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
from huggingface_hub.hf_file_system import HfFileSystem
from packaging import version
from modelscope import HubApi
@@ -94,8 +130,13 @@ ExpandDatasetProperty_T = Literal[
# Patch datasets features
# In datasets 4.0+, the List type is the native feature type;
# in datasets <4.0, Sequence (a dataclass) serves that role.
_ListBase = DatasetList if DatasetList is not None else SequenceHf
@dataclass(repr=False)
class ListMs(SequenceHf):
class ListMs(_ListBase):
"""Feature type for large list data composed of child feature data type.
It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
@@ -144,6 +185,15 @@ def generate_from_dict_ms(obj: Any):
return {key: generate_from_dict_ms(value) for key, value in obj.items()}
obj = dict(obj)
_type = obj.pop('_type')
# Handle legacy 'Sequence' type for backward compatibility.
# In datasets 4.0+, Sequence is a utility function (not a feature type),
# so it may not be registered in _FEATURE_TYPES.
if _type == 'Sequence':
feature = obj.pop('feature')
length = obj.get('length', -1)
return SequenceHf(feature=generate_from_dict_ms(feature), length=length)
class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
if class_type is None:
@@ -155,9 +205,6 @@ def generate_from_dict_ms(obj: Any):
if class_type == ListMs:
feature = obj.pop('feature')
return ListMs(generate_from_dict_ms(feature), **obj)
if class_type == SequenceHf: # backward compatibility, this translates to a List or a dict
feature = obj.pop('feature')
return SequenceHf(feature=generate_from_dict_ms(feature), **obj)
field_names = {f.name for f in fields(class_type)}
return class_type(**{k: v for k, v in obj.items() if k in field_names})
@@ -165,14 +212,12 @@ def generate_from_dict_ms(obj: Any):
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
# for temp val
revision = None
if url_or_filename.startswith('hf://'):
revision, url_or_filename = url_or_filename.split('@', 1)[-1].split('/', 1)
if is_relative_path(url_or_filename):
# append the relative path to the base_path
# url_or_filename = url_or_path_join(self._base_path, url_or_filename)
revision = revision or DEFAULT_DATASET_REVISION
# hf:// URLs are handled natively by cached_path via HfApi.hf_hub_download,
# which uses config.HF_ENDPOINT (already set to ModelScope endpoint).
pass
elif is_relative_path(url_or_filename):
revision = DEFAULT_DATASET_REVISION
# Note: make sure the FilePath is the last param
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
params: str = urlencode(params)
@@ -274,6 +319,37 @@ def _dataset_info(
return HfDatasetInfo(**data_info)
_repo_tree_cache: Dict[tuple, List[Union[RepoFile, RepoFolder]]] = {}
def _derive_from_recursive_cache(
repo_id: str,
revision: str,
path_in_repo: str,
recursive: bool,
) -> Optional[List[Union[RepoFile, RepoFolder]]]:
"""Try to derive results from a cached recursive root listing."""
root_key = (repo_id, revision, '/', True)
root_cached = _repo_tree_cache.get(root_key)
if root_cached is None:
return None
prefix = path_in_repo.strip('/') if path_in_repo and path_in_repo != '/' else ''
results = []
for item in root_cached:
item_path = item.path
if prefix:
if not item_path.startswith(prefix + '/') and item_path != prefix:
continue
rel_path = item_path[len(prefix) + 1:] if item_path.startswith(prefix + '/') else ''
else:
rel_path = item_path
if not recursive and '/' in rel_path:
continue
results.append(item)
return results
def _list_repo_tree(
self,
repo_id: str,
@@ -286,41 +362,72 @@ def _list_repo_tree(
token: Optional[Union[bool, str]] = None,
) -> Iterable[Union[RepoFile, RepoFolder]]:
revision = revision or DEFAULT_DATASET_REVISION
normalized_path = path_in_repo or '/'
cache_key = (repo_id, revision, normalized_path, recursive)
cached = _repo_tree_cache.get(cache_key)
if cached is not None:
yield from cached
return
derived = _derive_from_recursive_cache(repo_id, revision, normalized_path, recursive)
if derived is not None:
_repo_tree_cache[cache_key] = derived
yield from derived
return
_api = HubApi(timeout=3 * 60, max_retries=3)
endpoint = _api.get_endpoint_for_read(
repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
# List all files in the repo
_owner, _dataset_name = repo_id.split('/')
dataset_hub_id, _ = _api.get_dataset_id_and_type(
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
results: List[Union[RepoFile, RepoFolder]] = []
page_number = 1
page_size = 100
while True:
# Larger page_size reduces the number of HTTP round-trips for big datasets.
# Termination uses `not dataset_files` (empty page) so it is safe even if
# the server silently caps the actual page size to a smaller value.
page_size = 500
max_pages = 10000
while page_number <= max_pages:
try:
dataset_files = _api.get_dataset_files(
repo_id=repo_id,
revision=revision or DEFAULT_DATASET_REVISION,
root_path=path_in_repo or '/',
revision=revision,
root_path=normalized_path,
recursive=recursive,
page_number=page_number,
page_size=page_size,
endpoint=endpoint,
dataset_hub_id=dataset_hub_id,
)
except Exception as e:
logger.error(f'Get dataset: {repo_id} file list failed, message: {e}')
break
for file_info_d in dataset_files:
path_info = {}
path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file'
path_info['path'] = file_info_d['Path']
path_info['size'] = file_info_d['Size']
path_info['oid'] = file_info_d['Sha256']
if not dataset_files:
break
yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)
for file_info_d in dataset_files:
path_info = {
'type': 'directory' if file_info_d['Type'] == 'tree' else 'file',
'path': file_info_d['Path'],
'size': file_info_d['Size'],
'oid': file_info_d['Sha256'],
}
item = RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)
results.append(item)
yield item
if len(dataset_files) < page_size:
break
page_number += 1
_repo_tree_cache[cache_key] = results
def _get_paths_info(
self,
@@ -333,7 +440,19 @@ def _get_paths_info(
token: Optional[Union[bool, str]] = None,
) -> List[Union[RepoFile, RepoFolder]]:
# Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
revision = revision or DEFAULT_DATASET_REVISION
if isinstance(paths, str):
paths = [paths]
paths_set = set(paths)
# Search within any cached tree data (recursive root is the most comprehensive)
root_key = (repo_id, revision, '/', True)
root_cached = _repo_tree_cache.get(root_key)
if root_cached is not None:
matched = [item for item in root_cached if item.path in paths_set]
if matched:
return matched
repo_info_iter = self.list_repo_tree(
repo_id=repo_id,
recursive=False,
@@ -831,6 +950,10 @@ def _download_additional_modules(
def get_module_with_script(self) -> DatasetModule:
if not _HAS_SCRIPT_LOADING:
raise RuntimeError(
'Script-based dataset loading is not supported with datasets>=4.0. '
'Please convert the dataset to a script-free format (e.g. Parquet).')
repo_id: str = self.name
_namespace, _dataset_name = repo_id.split('/')
@@ -1000,9 +1123,12 @@ class DatasetsWrapperHF:
) if not save_infos else VerificationMode.ALL_CHECKS)
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.'
)
if not _HAS_SCRIPT_LOADING:
logger.warning('trust_remote_code is ignored: script-based dataset loading '
'is no longer supported with datasets>=4.0.')
else:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.')
# Create a dataset builder
builder_instance = DatasetsWrapperHF.load_dataset_builder(
@@ -1017,7 +1143,7 @@ class DatasetsWrapperHF:
revision=revision,
token=token,
storage_options=storage_options,
trust_remote_code=trust_remote_code,
trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None,
_require_default_config_name=name is None,
**config_kwargs,
)
@@ -1135,9 +1261,12 @@ class DatasetsWrapperHF:
download_config.storage_options.update(storage_options)
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.'
)
if not _HAS_SCRIPT_LOADING:
logger.warning('trust_remote_code is ignored: script-based dataset loading '
'is no longer supported with datasets>=4.0.')
else:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.')
dataset_module = DatasetsWrapperHF.dataset_module_factory(
path,
@@ -1147,7 +1276,7 @@ class DatasetsWrapperHF:
data_dir=data_dir,
data_files=data_files,
cache_dir=cache_dir,
trust_remote_code=trust_remote_code,
trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None,
_require_default_config_name=_require_default_config_name,
_require_custom_configs=bool(config_kwargs),
name=name,
@@ -1250,9 +1379,12 @@ class DatasetsWrapperHF:
# - if path has one "/" and is dataset repository on the HF hub without a python file
# -> use a packaged module (csv, text etc.) based on content of the repository
if trust_remote_code:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.'
)
if not _HAS_SCRIPT_LOADING:
logger.warning('trust_remote_code is ignored: script-based dataset loading '
'is no longer supported with datasets>=4.0.')
else:
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
'that you can trust the external codes.')
# Try packaged
if path in _PACKAGED_DATASETS_MODULES:
@@ -1263,9 +1395,13 @@ class DatasetsWrapperHF:
download_config=download_config,
download_mode=download_mode,
).get_module()
# Try locally
# Try locally with script (requires datasets <4.0)
elif path.endswith(filename):
if os.path.isfile(path):
if not _HAS_SCRIPT_LOADING:
raise RuntimeError(
f'Script-based dataset loading ({path}) is not supported with datasets>=4.0. '
'Please convert the dataset to a script-free format (e.g. Parquet).')
return LocalDatasetModuleFactoryWithScript(
path,
download_mode=download_mode,
@@ -1277,6 +1413,10 @@ class DatasetsWrapperHF:
f"Couldn't find a dataset script at {relative_to_absolute_path(path)}"
)
elif os.path.isfile(combined_path):
if not _HAS_SCRIPT_LOADING:
raise RuntimeError(
f'Script-based dataset loading ({combined_path}) is not supported with datasets>=4.0. '
'Please convert the dataset to a script-free format (e.g. Parquet).')
return LocalDatasetModuleFactoryWithScript(
combined_path,
download_mode=download_mode,
@@ -1342,24 +1482,8 @@ class DatasetsWrapperHF:
sibling.rfilename for sibling in dataset_info.siblings
]: # contains a dataset script
# fs = HfFileSystem(
# endpoint=config.HF_ENDPOINT,
# token=download_config.token)
# TODO
can_load_config_from_parquet_export = False
# if _require_custom_configs:
# can_load_config_from_parquet_export = False
# elif _require_default_config_name:
# with fs.open(
# f'datasets/{path}/{filename}',
# 'r',
# revision=revision,
# encoding='utf-8') as f:
# can_load_config_from_parquet_export = 'DEFAULT_CONFIG_NAME' not in f.read(
# )
# else:
# can_load_config_from_parquet_export = True
if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
# If the parquet export is ready (parquet files + info available for the current sha),
# we can use it instead
@@ -1379,7 +1503,14 @@ class DatasetsWrapperHF:
except Exception as e:
logger.error(e)
# Otherwise we must use the dataset script if the user trusts it
# Otherwise we must use the dataset script if the user trusts it.
# Script-based loading was removed in datasets 4.0.
if not _HAS_SCRIPT_LOADING:
raise RuntimeError(
f"Dataset '{path}' contains a loading script but script-based dataset loading "
'is not supported with datasets>=4.0. Please convert the dataset to a '
'script-free format (e.g. Parquet).')
# To be adapted to the old version of datasets
if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
return HubDatasetModuleFactoryWithScript(
@@ -1424,10 +1555,12 @@ class DatasetsWrapperHF:
logger.error(f'>> Error loading {path}: {e1}')
try:
# dynamic_modules_path was removed in datasets 4.0
_cached_factory_kwargs = {'cache_dir': cache_dir}
if _HAS_SCRIPT_LOADING:
_cached_factory_kwargs['dynamic_modules_path'] = dynamic_modules_path
return CachedDatasetModuleFactory(
path,
dynamic_modules_path=dynamic_modules_path,
cache_dir=cache_dir).get_module()
path, **_cached_factory_kwargs).get_module()
except Exception:
# If it's not in the cache, then it doesn't exist.
if isinstance(e1, OfflineModeIsEnabled):
@@ -1451,8 +1584,55 @@ class DatasetsWrapperHF:
f'any data file in the same directory.')
_hf_fs_open_original = None
def _hf_fs_open(self, path, mode='rb', **kwargs):
"""Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API.
The ModelScope tree API may report Size=0 for files. When HfFileSystem
caches this, AbstractBufferedFile treats the file as empty (0 bytes).
This wrapper detects size=0 for files opened in read mode and resolves
the actual size via a HEAD request before creating the file object.
"""
if mode == 'rb' and 'size' not in kwargs:
try:
resolved = self.resolve_path(path)
resolved_name = resolved.unresolve()
parent = self._parent(resolved_name)
cached_size = None
if parent in self.dircache:
for entry in self.dircache[parent]:
if entry['name'] == resolved_name and entry.get('type') == 'file':
cached_size = entry.get('size', -1)
break
if cached_size == 0:
url = hf_hub_url(
repo_id=resolved.repo_id,
revision=resolved.revision,
filename=resolved.path_in_repo,
repo_type=resolved.repo_type,
endpoint=self.endpoint,
)
headers = self._api._build_hf_headers()
resp = requests.head(url, headers=headers, allow_redirects=True, timeout=30)
if resp.status_code == 200:
cl = resp.headers.get('Content-Length')
if cl:
actual_size = int(cl)
kwargs['size'] = actual_size
for entry in self.dircache.get(parent, []):
if entry['name'] == resolved_name:
entry['size'] = actual_size
break
except Exception:
pass
return _hf_fs_open_original(self, path, mode=mode, **kwargs)
@contextlib.contextmanager
def load_dataset_with_ctx(*args, **kwargs):
global _hf_fs_open_original
# Keep the original functions
hf_endpoint_origin = config.HF_ENDPOINT
@@ -1467,8 +1647,11 @@ def load_dataset_with_ctx(*args, **kwargs):
get_paths_info_origin = HfApi.get_paths_info
resolve_pattern_origin = data_files.resolve_pattern
get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
# Script-based loading was removed in datasets 4.0
get_module_with_script_origin = (
HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None)
generate_from_dict_origin = features.generate_from_dict
hf_fs_open_origin = HfFileSystem._open
# Monkey patching with modelscope functions
config.HF_ENDPOINT = get_endpoint()
@@ -1483,8 +1666,11 @@ def load_dataset_with_ctx(*args, **kwargs):
HfApi.get_paths_info = _get_paths_info
data_files.resolve_pattern = _resolve_pattern
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
if _HAS_SCRIPT_LOADING:
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
features.generate_from_dict = generate_from_dict_ms
_hf_fs_open_original = hf_fs_open_origin
HfFileSystem._open = _hf_fs_open
streaming = kwargs.get('streaming', False)
@@ -1492,14 +1678,16 @@ def load_dataset_with_ctx(*args, **kwargs):
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
yield dataset_res
finally:
# Restore the original functions
config.HF_ENDPOINT = hf_endpoint_origin
file_utils.get_from_cache = get_from_cache_origin
features.generate_from_dict = generate_from_dict_origin
# Keep the context during the streaming iteration
_repo_tree_cache.clear()
HubApi._dataset_id_type_cache.clear()
HfFileSystem._open = hf_fs_open_origin
_hf_fs_open_original = None
if not streaming:
config.HF_ENDPOINT = hf_endpoint_origin
file_utils.get_from_cache = get_from_cache_origin
features.generate_from_dict = generate_from_dict_origin
# Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
@@ -1512,4 +1700,5 @@ def load_dataset_with_ctx(*args, **kwargs):
HfApi.get_paths_info = get_paths_info_origin
data_files.resolve_pattern = resolve_pattern_origin
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
if _HAS_SCRIPT_LOADING:
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin

View File

@@ -1,6 +1,6 @@
addict
attrs
datasets>=3.0.0,<=3.6.0
datasets>=4.0.0,<=4.6.1
einops
oss2
Pillow

View File

@@ -1,6 +1,6 @@
addict
attrs
datasets>=3.0.0,<=3.6.0
datasets>=4.0.0,<=4.6.1
einops
Pillow
python-dateutil>=2.1