mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[Feat] dataset module refactor (#1623)
This commit is contained in:
@@ -11,7 +11,6 @@ import platform
|
||||
import re
|
||||
import shutil
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
import warnings
|
||||
from collections import defaultdict
|
||||
@@ -1403,6 +1402,8 @@ class HubApi:
|
||||
raise_for_http_status(r)
|
||||
raise_on_error(r.json())
|
||||
|
||||
_dataset_id_type_cache: dict = {}
|
||||
|
||||
def get_dataset_id_and_type(self,
|
||||
dataset_name: str,
|
||||
namespace: str,
|
||||
@@ -1411,6 +1412,10 @@ class HubApi:
|
||||
""" Get the dataset id and type. """
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
cache_key = (namespace, dataset_name, endpoint)
|
||||
cached = HubApi._dataset_id_type_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
datahub_url = f'{endpoint}/api/v1/datasets/{namespace}/{dataset_name}'
|
||||
cookies = self.get_cookies(access_token=token)
|
||||
r = self.session.get(datahub_url, cookies=cookies)
|
||||
@@ -1418,6 +1423,7 @@ class HubApi:
|
||||
datahub_raise_on_error(datahub_url, resp, r)
|
||||
dataset_id = resp['Data']['Id']
|
||||
dataset_type = resp['Data']['Type']
|
||||
HubApi._dataset_id_type_cache[cache_key] = (dataset_id, dataset_type)
|
||||
return dataset_id, dataset_type
|
||||
|
||||
def list_repo_tree(self,
|
||||
@@ -1526,7 +1532,8 @@ class HubApi:
|
||||
page_number: int = 1,
|
||||
page_size: int = 100,
|
||||
endpoint: Optional[str] = None,
|
||||
token: Optional[str] = None):
|
||||
token: Optional[str] = None,
|
||||
dataset_hub_id: Optional[str] = None):
|
||||
"""
|
||||
Get the dataset files.
|
||||
|
||||
@@ -1539,19 +1546,23 @@ class HubApi:
|
||||
page_size (int): The number of items per page. Defaults to 100.
|
||||
endpoint (Optional[str]): The endpoint to use, defaults to None to use the endpoint specified in the class.
|
||||
token (Optional[str]): The access token.
|
||||
dataset_hub_id (Optional[str]): Pre-fetched dataset hub id. When provided,
|
||||
skips the internal ``get_dataset_id_and_type`` lookup. Useful in pagination
|
||||
loops to avoid redundant API calls per page.
|
||||
|
||||
Returns:
|
||||
List: The response containing the dataset repository tree information.
|
||||
e.g. [{'CommitId': None, 'CommitMessage': '...', 'Size': 0, 'Type': 'tree'}, ...]
|
||||
"""
|
||||
|
||||
if is_relative_path(repo_id) and repo_id.count('/') == 1:
|
||||
_owner, _dataset_name = repo_id.split('/')
|
||||
else:
|
||||
raise ValueError(f'Invalid repo_id: {repo_id} !')
|
||||
if dataset_hub_id is None:
|
||||
if is_relative_path(repo_id) and repo_id.count('/') == 1:
|
||||
_owner, _dataset_name = repo_id.split('/')
|
||||
else:
|
||||
raise ValueError(f'Invalid repo_id: {repo_id} !')
|
||||
|
||||
dataset_hub_id, dataset_type = self.get_dataset_id_and_type(
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
|
||||
dataset_hub_id, _ = self.get_dataset_id_and_type(
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
|
||||
|
||||
if not endpoint:
|
||||
endpoint = self.endpoint
|
||||
@@ -1569,7 +1580,10 @@ class HubApi:
|
||||
resp = r.json()
|
||||
datahub_raise_on_error(datahub_url, resp, r)
|
||||
|
||||
return resp['Data']['Files']
|
||||
data = resp.get('Data')
|
||||
if data is None:
|
||||
return []
|
||||
return data.get('Files') or []
|
||||
|
||||
def get_dataset(
|
||||
self,
|
||||
@@ -2100,11 +2114,9 @@ class HubApi:
|
||||
repo_type: Optional[str] = REPO_TYPE_MODEL,
|
||||
revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
|
||||
endpoint: Optional[str] = None,
|
||||
max_retries: int = 3,
|
||||
timeout: int = 180,
|
||||
) -> CommitInfo:
|
||||
"""
|
||||
Create a commit on the ModelScope Hub with retry mechanism.
|
||||
Create a commit on the ModelScope Hub.
|
||||
|
||||
Args:
|
||||
repo_id (str): The repo id in the format of `owner_name/repo_name`.
|
||||
@@ -2117,14 +2129,14 @@ class HubApi:
|
||||
revision (Optional[str]): The branch or tag name. Defaults to `DEFAULT_REPOSITORY_REVISION`.
|
||||
endpoint (Optional[str]): The endpoint to use.
|
||||
In the format of `https://www.modelscope.cn` or 'https://www.modelscope.ai'
|
||||
max_retries (int): Number of max retry attempts (default: 3).
|
||||
timeout (int): Timeout for each request in seconds (default: 180).
|
||||
|
||||
Returns:
|
||||
CommitInfo: The commit info.
|
||||
|
||||
Raises:
|
||||
requests.exceptions.RequestException: If all retry attempts fail.
|
||||
ValueError: If the request fails with a 4xx client error.
|
||||
requests.exceptions.RequestException: If a network-level error occurs.
|
||||
"""
|
||||
if not repo_id:
|
||||
raise ValueError('Repo id cannot be empty!')
|
||||
@@ -2147,66 +2159,29 @@ class HubApi:
|
||||
commit_message=commit_message,
|
||||
)
|
||||
|
||||
# POST with retry mechanism
|
||||
last_exception = None
|
||||
for attempt in range(max_retries):
|
||||
response = self.session.post(
|
||||
url,
|
||||
headers=self.builder_headers(self.headers),
|
||||
data=json.dumps(payload),
|
||||
cookies=cookies,
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
if attempt > 0:
|
||||
logger.info(f'Attempt {attempt + 1} to create commit for {repo_id}...')
|
||||
response = requests.post(
|
||||
url,
|
||||
headers=self.builder_headers(self.headers),
|
||||
data=json.dumps(payload),
|
||||
cookies=cookies,
|
||||
timeout=timeout,
|
||||
)
|
||||
error_detail = response.json()
|
||||
except json.JSONDecodeError:
|
||||
error_detail = response.text
|
||||
error_msg = f'HTTP {response.status_code} error from {url}: {error_detail}'
|
||||
raise ValueError(error_msg)
|
||||
|
||||
if response.status_code != 200:
|
||||
try:
|
||||
error_detail = response.json()
|
||||
except json.JSONDecodeError:
|
||||
error_detail = response.text
|
||||
|
||||
error_msg = (
|
||||
f'HTTP {response.status_code} error from {url}: '
|
||||
f'{error_detail}'
|
||||
)
|
||||
|
||||
# If server error (5xx), we can retry, otherwise (4xx) raise immediately
|
||||
if 500 <= response.status_code < 600:
|
||||
logger.warning(
|
||||
f'Server error on attempt {attempt + 1}: {error_msg}'
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'Client request failed: {error_msg}')
|
||||
else:
|
||||
resp = response.json()
|
||||
|
||||
oid = resp.get('Data', {}).get('oid', '')
|
||||
logger.info(f'Commit succeeded: {url}')
|
||||
return CommitInfo(
|
||||
commit_url=url,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
oid=oid,
|
||||
)
|
||||
|
||||
except requests.exceptions.RequestException as e:
|
||||
last_exception = e
|
||||
logger.warning(f'Request failed on attempt {attempt + 1}: {str(e)}')
|
||||
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
logger.error(f'Unexpected error on attempt {attempt + 1}: {str(e)}')
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(1)
|
||||
|
||||
# All retries exhausted
|
||||
raise requests.exceptions.RequestException(
|
||||
f'Failed to create commit after {max_retries} attempts. Last error: {last_exception}'
|
||||
resp = response.json()
|
||||
oid = resp.get('Data', {}).get('oid', '')
|
||||
logger.info(f'Commit succeeded: {url}')
|
||||
return CommitInfo(
|
||||
commit_url=url,
|
||||
commit_message=commit_message,
|
||||
commit_description=commit_description,
|
||||
oid=oid,
|
||||
)
|
||||
|
||||
def upload_file(
|
||||
@@ -2593,21 +2568,21 @@ class HubApi:
|
||||
|
||||
if isinstance(data, (str, Path)):
|
||||
with open(data, 'rb') as f:
|
||||
response = requests.put(
|
||||
response = self.session.put(
|
||||
upload_object['url'],
|
||||
headers=headers,
|
||||
data=read_in_chunks(f, pbar)
|
||||
)
|
||||
|
||||
elif isinstance(data, bytes):
|
||||
response = requests.put(
|
||||
response = self.session.put(
|
||||
upload_object['url'],
|
||||
headers=headers,
|
||||
data=read_in_chunks(io.BytesIO(data), pbar)
|
||||
)
|
||||
|
||||
elif isinstance(data, io.BufferedIOBase):
|
||||
response = requests.put(
|
||||
response = self.session.put(
|
||||
upload_object['url'],
|
||||
headers=headers,
|
||||
data=read_in_chunks(data, pbar)
|
||||
@@ -2664,7 +2639,7 @@ class HubApi:
|
||||
}
|
||||
|
||||
cookies = self.get_cookies(access_token=token, cookies_required=True)
|
||||
response = requests.post(
|
||||
response = self.session.post(
|
||||
url,
|
||||
headers=self.builder_headers(self.headers),
|
||||
data=json.dumps(payload),
|
||||
@@ -2907,6 +2882,9 @@ class HubApi:
|
||||
file_paths = [f['Path'] for f in files]
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
file_paths = []
|
||||
_owner, _dataset_name = repo_id.split('/')
|
||||
_hub_id, _ = self.get_dataset_id_and_type(
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint, token=token)
|
||||
page_number = 1
|
||||
page_size = 100
|
||||
while True:
|
||||
@@ -2919,6 +2897,7 @@ class HubApi:
|
||||
page_size=page_size,
|
||||
endpoint=endpoint,
|
||||
token=token,
|
||||
dataset_hub_id=_hub_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Get dataset: {repo_id} file list failed, message: {str(e)}')
|
||||
|
||||
@@ -253,6 +253,11 @@ def _repo_file_download(
|
||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||
if not revision:
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
_hub_id, _ = _api.get_dataset_id_and_type(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
endpoint=endpoint,
|
||||
token=token)
|
||||
page_number = 1
|
||||
page_size = 100
|
||||
while True:
|
||||
@@ -265,7 +270,8 @@ def _repo_file_download(
|
||||
page_number=page_number,
|
||||
page_size=page_size,
|
||||
endpoint=endpoint,
|
||||
token=token)
|
||||
token=token,
|
||||
dataset_hub_id=_hub_id)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f'Get dataset: {repo_id} file list failed, error: {e}')
|
||||
|
||||
@@ -428,6 +428,10 @@ def _snapshot_download(
|
||||
|
||||
|
||||
def fetch_repo_files(_api, repo_id, revision, endpoint):
|
||||
_owner, _dataset_name = repo_id.split('/')
|
||||
_hub_id, _ = _api.get_dataset_id_and_type(
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
|
||||
|
||||
page_number = 1
|
||||
page_size = 150
|
||||
repo_files = []
|
||||
@@ -441,7 +445,8 @@ def fetch_repo_files(_api, repo_id, revision, endpoint):
|
||||
recursive=True,
|
||||
page_number=page_number,
|
||||
page_size=page_size,
|
||||
endpoint=endpoint)
|
||||
endpoint=endpoint,
|
||||
dataset_hub_id=_hub_id)
|
||||
except Exception as e:
|
||||
logger.error(f'Error fetching dataset files: {e}')
|
||||
break
|
||||
|
||||
@@ -148,7 +148,6 @@ class OssDownloader(BaseDownloader):
|
||||
data_files=data_files,
|
||||
cache_dir=cache_dir,
|
||||
download_mode=download_mode.value,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**input_kwargs)
|
||||
else:
|
||||
self.dataset = self.data_files_manager.fetch_data_files(
|
||||
|
||||
@@ -87,7 +87,6 @@ class LocalDataLoaderManager(DataLoaderManager):
|
||||
cache_dir=cache_root_dir,
|
||||
download_mode=download_mode.value,
|
||||
streaming=use_streaming,
|
||||
trust_remote_code=trust_remote_code,
|
||||
**input_config_kwargs)
|
||||
raise f'Expected local data loader type: {LocalDataLoaderType.HF_DATA_LOADER.value}.'
|
||||
|
||||
@@ -130,7 +129,6 @@ class RemoteDataLoaderManager(DataLoaderManager):
|
||||
data_files=data_files,
|
||||
download_mode=download_mode_val,
|
||||
streaming=use_streaming,
|
||||
trust_remote_code=trust_remote_code,
|
||||
token=token,
|
||||
**input_config_kwargs)
|
||||
# download statistics
|
||||
|
||||
@@ -13,8 +13,8 @@ from datasets.filesystems import is_remote_filesystem
|
||||
from datasets.info import DatasetInfo
|
||||
from datasets.naming import camelcase_to_snakecase
|
||||
from datasets.packaged_modules import csv
|
||||
from datasets.utils.filelock import FileLock
|
||||
from datasets.utils.py_utils import map_nested
|
||||
from filelock import FileLock
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.msdatasets.context.dataset_context_config import \
|
||||
|
||||
@@ -5,7 +5,7 @@ import shutil
|
||||
from collections import defaultdict
|
||||
|
||||
import json
|
||||
from datasets.utils.filelock import FileLock
|
||||
from filelock import FileLock
|
||||
|
||||
from modelscope.hub.api import HubApi
|
||||
from modelscope.msdatasets.context.dataset_context_config import \
|
||||
|
||||
@@ -283,10 +283,15 @@ class MsDataset:
|
||||
return load_dataset(
|
||||
dataset_name,
|
||||
name=subset_name,
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
split=split,
|
||||
streaming=use_streaming,
|
||||
cache_dir=cache_dir,
|
||||
features=features,
|
||||
download_mode=download_mode.value,
|
||||
trust_remote_code=trust_remote_code,
|
||||
revision=version,
|
||||
token=token,
|
||||
streaming=use_streaming,
|
||||
**config_kwargs)
|
||||
|
||||
# Load from the modelscope hub
|
||||
|
||||
@@ -17,7 +17,16 @@ import requests
|
||||
from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
|
||||
DownloadConfig, DownloadManager, DownloadMode, Features,
|
||||
IterableDataset, IterableDatasetDict, Split,
|
||||
VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf)
|
||||
VerificationMode, Version, config, data_files, LargeList,
|
||||
Sequence as SequenceHf)
|
||||
|
||||
# In datasets 4.0+, Sequence was replaced by List as a feature type.
|
||||
# Use List as the base for ListMs when available, fall back to Sequence for <4.0.
|
||||
try:
|
||||
from datasets import List as DatasetList
|
||||
except ImportError:
|
||||
DatasetList = None
|
||||
|
||||
from datasets.features import features
|
||||
from datasets.features.features import _FEATURE_TYPES
|
||||
from datasets.data_files import (
|
||||
@@ -29,36 +38,63 @@ from datasets.download.streaming_download_manager import (
|
||||
from datasets.exceptions import DataFilesNotFoundError, DatasetNotFoundError
|
||||
from datasets.info import DatasetInfosDict
|
||||
from datasets.load import (
|
||||
ALL_ALLOWED_EXTENSIONS, BuilderConfigsParameters,
|
||||
BuilderConfigsParameters,
|
||||
CachedDatasetModuleFactory, DatasetModule,
|
||||
HubDatasetModuleFactoryWithoutScript,
|
||||
HubDatasetModuleFactoryWithParquetExport,
|
||||
HubDatasetModuleFactoryWithScript, LocalDatasetModuleFactoryWithoutScript,
|
||||
LocalDatasetModuleFactoryWithScript, PackagedDatasetModuleFactory,
|
||||
PackagedDatasetModuleFactory,
|
||||
create_builder_configs_from_metadata_configs, get_dataset_builder_class,
|
||||
import_main_class, infer_module_for_data_files, files_to_hash,
|
||||
_get_importable_file_path, resolve_trust_remote_code, _create_importable_file, _load_importable_file,
|
||||
init_dynamic_modules)
|
||||
import_main_class, infer_module_for_data_files)
|
||||
|
||||
# For compatibility with datasets 4.0+
|
||||
try:
|
||||
from datasets.load import (
|
||||
HubDatasetModuleFactory as HubDatasetModuleFactoryWithoutScript,
|
||||
LocalDatasetModuleFactory as LocalDatasetModuleFactoryWithoutScript)
|
||||
except ImportError:
|
||||
from datasets.load import (
|
||||
HubDatasetModuleFactoryWithoutScript,
|
||||
LocalDatasetModuleFactoryWithoutScript)
|
||||
|
||||
# Script-based dataset loading was removed in datasets 4.0.
|
||||
# These APIs are conditionally imported for backward compatibility with <4.0.
|
||||
try:
|
||||
from datasets.load import (
|
||||
HubDatasetModuleFactoryWithScript,
|
||||
LocalDatasetModuleFactoryWithScript,
|
||||
resolve_trust_remote_code,
|
||||
_get_importable_file_path, _create_importable_file,
|
||||
_load_importable_file, init_dynamic_modules,
|
||||
files_to_hash)
|
||||
from datasets.utils.py_utils import get_imports
|
||||
_HAS_SCRIPT_LOADING = True
|
||||
except ImportError:
|
||||
_HAS_SCRIPT_LOADING = False
|
||||
|
||||
from datasets.naming import camelcase_to_snakecase
|
||||
from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
|
||||
_MODULE_TO_EXTENSIONS,
|
||||
_PACKAGED_DATASETS_MODULES)
|
||||
# ALL_ALLOWED_EXTENSIONS moved to datasets.packaged_modules in datasets 4.0
|
||||
try:
|
||||
from datasets.packaged_modules import _ALL_ALLOWED_EXTENSIONS as ALL_ALLOWED_EXTENSIONS
|
||||
except ImportError:
|
||||
from datasets.load import ALL_ALLOWED_EXTENSIONS
|
||||
from datasets.utils import file_utils
|
||||
from datasets.utils.file_utils import (_raise_if_offline_mode_is_enabled,
|
||||
cached_path, is_local_path,
|
||||
relative_to_absolute_path)
|
||||
from datasets.utils.info_utils import is_small_dataset
|
||||
from datasets.utils.metadata import MetadataConfigs
|
||||
from datasets.utils.py_utils import get_imports
|
||||
from datasets.utils.track import tracked_str
|
||||
|
||||
from fsspec import filesystem
|
||||
from fsspec.core import _un_chain
|
||||
from fsspec.utils import stringify_path
|
||||
from huggingface_hub import (DatasetCard, DatasetCardData)
|
||||
from huggingface_hub import (DatasetCard, DatasetCardData, hf_hub_url)
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
from huggingface_hub.hf_api import DatasetInfo as HfDatasetInfo
|
||||
from huggingface_hub.hf_api import HfApi, RepoFile, RepoFolder
|
||||
from huggingface_hub.hf_file_system import HfFileSystem
|
||||
from packaging import version
|
||||
|
||||
from modelscope import HubApi
|
||||
@@ -94,8 +130,13 @@ ExpandDatasetProperty_T = Literal[
|
||||
|
||||
|
||||
# Patch datasets features
|
||||
# In datasets 4.0+, the List type is the native feature type;
|
||||
# in datasets <4.0, Sequence (a dataclass) serves that role.
|
||||
_ListBase = DatasetList if DatasetList is not None else SequenceHf
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class ListMs(SequenceHf):
|
||||
class ListMs(_ListBase):
|
||||
"""Feature type for large list data composed of child feature data type.
|
||||
|
||||
It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
|
||||
@@ -144,6 +185,15 @@ def generate_from_dict_ms(obj: Any):
|
||||
return {key: generate_from_dict_ms(value) for key, value in obj.items()}
|
||||
obj = dict(obj)
|
||||
_type = obj.pop('_type')
|
||||
|
||||
# Handle legacy 'Sequence' type for backward compatibility.
|
||||
# In datasets 4.0+, Sequence is a utility function (not a feature type),
|
||||
# so it may not be registered in _FEATURE_TYPES.
|
||||
if _type == 'Sequence':
|
||||
feature = obj.pop('feature')
|
||||
length = obj.get('length', -1)
|
||||
return SequenceHf(feature=generate_from_dict_ms(feature), length=length)
|
||||
|
||||
class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
|
||||
|
||||
if class_type is None:
|
||||
@@ -155,9 +205,6 @@ def generate_from_dict_ms(obj: Any):
|
||||
if class_type == ListMs:
|
||||
feature = obj.pop('feature')
|
||||
return ListMs(generate_from_dict_ms(feature), **obj)
|
||||
if class_type == SequenceHf: # backward compatibility, this translates to a List or a dict
|
||||
feature = obj.pop('feature')
|
||||
return SequenceHf(feature=generate_from_dict_ms(feature), **obj)
|
||||
|
||||
field_names = {f.name for f in fields(class_type)}
|
||||
return class_type(**{k: v for k, v in obj.items() if k in field_names})
|
||||
@@ -165,14 +212,12 @@ def generate_from_dict_ms(obj: Any):
|
||||
|
||||
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
|
||||
url_or_filename = str(url_or_filename)
|
||||
# for temp val
|
||||
revision = None
|
||||
if url_or_filename.startswith('hf://'):
|
||||
revision, url_or_filename = url_or_filename.split('@', 1)[-1].split('/', 1)
|
||||
if is_relative_path(url_or_filename):
|
||||
# append the relative path to the base_path
|
||||
# url_or_filename = url_or_path_join(self._base_path, url_or_filename)
|
||||
revision = revision or DEFAULT_DATASET_REVISION
|
||||
# hf:// URLs are handled natively by cached_path via HfApi.hf_hub_download,
|
||||
# which uses config.HF_ENDPOINT (already set to ModelScope endpoint).
|
||||
pass
|
||||
elif is_relative_path(url_or_filename):
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
# Note: make sure the FilePath is the last param
|
||||
params: dict = {'Source': 'SDK', 'Revision': revision, 'FilePath': url_or_filename}
|
||||
params: str = urlencode(params)
|
||||
@@ -274,6 +319,37 @@ def _dataset_info(
|
||||
return HfDatasetInfo(**data_info)
|
||||
|
||||
|
||||
_repo_tree_cache: Dict[tuple, List[Union[RepoFile, RepoFolder]]] = {}
|
||||
|
||||
|
||||
def _derive_from_recursive_cache(
    repo_id: str,
    revision: str,
    path_in_repo: str,
    recursive: bool,
) -> Optional[List[Union[RepoFile, RepoFolder]]]:
    """Answer a tree listing from the cached recursive root listing, if any.

    Looks up ``_repo_tree_cache`` under the key
    ``(repo_id, revision, '/', True)`` and filters that listing down to the
    entries located at or below ``path_in_repo``.

    Args:
        repo_id: Repository id used as the first element of the cache key.
        revision: Revision used as the second element of the cache key.
        path_in_repo: Directory to list. An empty value or ``'/'`` selects
            the whole listing; otherwise leading/trailing slashes are
            stripped before matching.
        recursive: When False, entries whose path relative to
            ``path_in_repo`` still contains a ``'/'`` are left out.

    Returns:
        ``None`` when no recursive root listing is cached, otherwise the
        (possibly empty) list of matching cached entries, in cache order.
    """
    full_listing = _repo_tree_cache.get((repo_id, revision, '/', True))
    if full_listing is None:
        return None

    lists_root = not path_in_repo or path_in_repo == '/'
    base = '' if lists_root else path_in_repo.strip('/')
    child_marker = base + '/'

    def _relative_to_base(entry_path: str) -> Optional[str]:
        # Path of an entry relative to `base`, or None when it lies outside.
        if not base:
            return entry_path
        if entry_path.startswith(child_marker):
            return entry_path[len(child_marker):]
        if entry_path == base:
            # NOTE(review): the entry whose path equals the requested
            # directory is kept in the result as well -- confirm callers
            # expect the directory itself to be listed.
            return ''
        return None

    selected = []
    for entry in full_listing:
        relative = _relative_to_base(entry.path)
        if relative is None:
            continue
        # A '/' left in the relative path means the entry is nested deeper
        # than one level below `base`.
        if recursive or '/' not in relative:
            selected.append(entry)
    return selected
|
||||
|
||||
|
||||
def _list_repo_tree(
|
||||
self,
|
||||
repo_id: str,
|
||||
@@ -286,41 +362,72 @@ def _list_repo_tree(
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
) -> Iterable[Union[RepoFile, RepoFolder]]:
|
||||
|
||||
revision = revision or DEFAULT_DATASET_REVISION
|
||||
normalized_path = path_in_repo or '/'
|
||||
cache_key = (repo_id, revision, normalized_path, recursive)
|
||||
|
||||
cached = _repo_tree_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
yield from cached
|
||||
return
|
||||
|
||||
derived = _derive_from_recursive_cache(repo_id, revision, normalized_path, recursive)
|
||||
if derived is not None:
|
||||
_repo_tree_cache[cache_key] = derived
|
||||
yield from derived
|
||||
return
|
||||
|
||||
_api = HubApi(timeout=3 * 60, max_retries=3)
|
||||
endpoint = _api.get_endpoint_for_read(
|
||||
repo_id=repo_id, repo_type=REPO_TYPE_DATASET)
|
||||
|
||||
# List all files in the repo
|
||||
_owner, _dataset_name = repo_id.split('/')
|
||||
dataset_hub_id, _ = _api.get_dataset_id_and_type(
|
||||
dataset_name=_dataset_name, namespace=_owner, endpoint=endpoint)
|
||||
|
||||
results: List[Union[RepoFile, RepoFolder]] = []
|
||||
page_number = 1
|
||||
page_size = 100
|
||||
while True:
|
||||
# Larger page_size reduces the number of HTTP round-trips for big datasets.
|
||||
# Termination uses `not dataset_files` (empty page) so it is safe even if
|
||||
# the server silently caps the actual page size to a smaller value.
|
||||
page_size = 500
|
||||
max_pages = 10000
|
||||
while page_number <= max_pages:
|
||||
try:
|
||||
dataset_files = _api.get_dataset_files(
|
||||
repo_id=repo_id,
|
||||
revision=revision or DEFAULT_DATASET_REVISION,
|
||||
root_path=path_in_repo or '/',
|
||||
revision=revision,
|
||||
root_path=normalized_path,
|
||||
recursive=recursive,
|
||||
page_number=page_number,
|
||||
page_size=page_size,
|
||||
endpoint=endpoint,
|
||||
dataset_hub_id=dataset_hub_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Get dataset: {repo_id} file list failed, message: {e}')
|
||||
break
|
||||
|
||||
for file_info_d in dataset_files:
|
||||
path_info = {}
|
||||
path_info['type'] = 'directory' if file_info_d['Type'] == 'tree' else 'file'
|
||||
path_info['path'] = file_info_d['Path']
|
||||
path_info['size'] = file_info_d['Size']
|
||||
path_info['oid'] = file_info_d['Sha256']
|
||||
if not dataset_files:
|
||||
break
|
||||
|
||||
yield RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)
|
||||
for file_info_d in dataset_files:
|
||||
path_info = {
|
||||
'type': 'directory' if file_info_d['Type'] == 'tree' else 'file',
|
||||
'path': file_info_d['Path'],
|
||||
'size': file_info_d['Size'],
|
||||
'oid': file_info_d['Sha256'],
|
||||
}
|
||||
item = RepoFile(**path_info) if path_info['type'] == 'file' else RepoFolder(**path_info)
|
||||
results.append(item)
|
||||
yield item
|
||||
|
||||
if len(dataset_files) < page_size:
|
||||
break
|
||||
page_number += 1
|
||||
|
||||
_repo_tree_cache[cache_key] = results
|
||||
|
||||
|
||||
def _get_paths_info(
|
||||
self,
|
||||
@@ -333,7 +440,19 @@ def _get_paths_info(
|
||||
token: Optional[Union[bool, str]] = None,
|
||||
) -> List[Union[RepoFile, RepoFolder]]:
|
||||
|
||||
# Refer to func: `_list_repo_tree()`, for patching `HfApi.list_repo_tree`
|
||||
revision = revision or DEFAULT_DATASET_REVISION
|
||||
if isinstance(paths, str):
|
||||
paths = [paths]
|
||||
paths_set = set(paths)
|
||||
|
||||
# Search within any cached tree data (recursive root is the most comprehensive)
|
||||
root_key = (repo_id, revision, '/', True)
|
||||
root_cached = _repo_tree_cache.get(root_key)
|
||||
if root_cached is not None:
|
||||
matched = [item for item in root_cached if item.path in paths_set]
|
||||
if matched:
|
||||
return matched
|
||||
|
||||
repo_info_iter = self.list_repo_tree(
|
||||
repo_id=repo_id,
|
||||
recursive=False,
|
||||
@@ -831,6 +950,10 @@ def _download_additional_modules(
|
||||
|
||||
|
||||
def get_module_with_script(self) -> DatasetModule:
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
raise RuntimeError(
|
||||
'Script-based dataset loading is not supported with datasets>=4.0. '
|
||||
'Please convert the dataset to a script-free format (e.g. Parquet).')
|
||||
|
||||
repo_id: str = self.name
|
||||
_namespace, _dataset_name = repo_id.split('/')
|
||||
@@ -1000,9 +1123,12 @@ class DatasetsWrapperHF:
|
||||
) if not save_infos else VerificationMode.ALL_CHECKS)
|
||||
|
||||
if trust_remote_code:
|
||||
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
|
||||
'that you can trust the external codes.'
|
||||
)
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
logger.warning('trust_remote_code is ignored: script-based dataset loading '
|
||||
'is no longer supported with datasets>=4.0.')
|
||||
else:
|
||||
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
|
||||
'that you can trust the external codes.')
|
||||
|
||||
# Create a dataset builder
|
||||
builder_instance = DatasetsWrapperHF.load_dataset_builder(
|
||||
@@ -1017,7 +1143,7 @@ class DatasetsWrapperHF:
|
||||
revision=revision,
|
||||
token=token,
|
||||
storage_options=storage_options,
|
||||
trust_remote_code=trust_remote_code,
|
||||
trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None,
|
||||
_require_default_config_name=name is None,
|
||||
**config_kwargs,
|
||||
)
|
||||
@@ -1135,9 +1261,12 @@ class DatasetsWrapperHF:
|
||||
download_config.storage_options.update(storage_options)
|
||||
|
||||
if trust_remote_code:
|
||||
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
|
||||
'that you can trust the external codes.'
|
||||
)
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
logger.warning('trust_remote_code is ignored: script-based dataset loading '
|
||||
'is no longer supported with datasets>=4.0.')
|
||||
else:
|
||||
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
|
||||
'that you can trust the external codes.')
|
||||
|
||||
dataset_module = DatasetsWrapperHF.dataset_module_factory(
|
||||
path,
|
||||
@@ -1147,7 +1276,7 @@ class DatasetsWrapperHF:
|
||||
data_dir=data_dir,
|
||||
data_files=data_files,
|
||||
cache_dir=cache_dir,
|
||||
trust_remote_code=trust_remote_code,
|
||||
trust_remote_code=trust_remote_code if _HAS_SCRIPT_LOADING else None,
|
||||
_require_default_config_name=_require_default_config_name,
|
||||
_require_custom_configs=bool(config_kwargs),
|
||||
name=name,
|
||||
@@ -1250,9 +1379,12 @@ class DatasetsWrapperHF:
|
||||
# - if path has one "/" and is dataset repository on the HF hub without a python file
|
||||
# -> use a packaged module (csv, text etc.) based on content of the repository
|
||||
if trust_remote_code:
|
||||
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
|
||||
'that you can trust the external codes.'
|
||||
)
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
logger.warning('trust_remote_code is ignored: script-based dataset loading '
|
||||
'is no longer supported with datasets>=4.0.')
|
||||
else:
|
||||
logger.warning(f'Use trust_remote_code=True. Will invoke codes from {path}. Please make sure '
|
||||
'that you can trust the external codes.')
|
||||
|
||||
# Try packaged
|
||||
if path in _PACKAGED_DATASETS_MODULES:
|
||||
@@ -1263,9 +1395,13 @@ class DatasetsWrapperHF:
|
||||
download_config=download_config,
|
||||
download_mode=download_mode,
|
||||
).get_module()
|
||||
# Try locally
|
||||
# Try locally with script (requires datasets <4.0)
|
||||
elif path.endswith(filename):
|
||||
if os.path.isfile(path):
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
raise RuntimeError(
|
||||
f'Script-based dataset loading ({path}) is not supported with datasets>=4.0. '
|
||||
'Please convert the dataset to a script-free format (e.g. Parquet).')
|
||||
return LocalDatasetModuleFactoryWithScript(
|
||||
path,
|
||||
download_mode=download_mode,
|
||||
@@ -1277,6 +1413,10 @@ class DatasetsWrapperHF:
|
||||
f"Couldn't find a dataset script at {relative_to_absolute_path(path)}"
|
||||
)
|
||||
elif os.path.isfile(combined_path):
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
raise RuntimeError(
|
||||
f'Script-based dataset loading ({combined_path}) is not supported with datasets>=4.0. '
|
||||
'Please convert the dataset to a script-free format (e.g. Parquet).')
|
||||
return LocalDatasetModuleFactoryWithScript(
|
||||
combined_path,
|
||||
download_mode=download_mode,
|
||||
@@ -1342,24 +1482,8 @@ class DatasetsWrapperHF:
|
||||
sibling.rfilename for sibling in dataset_info.siblings
|
||||
]: # contains a dataset script
|
||||
|
||||
# fs = HfFileSystem(
|
||||
# endpoint=config.HF_ENDPOINT,
|
||||
# token=download_config.token)
|
||||
|
||||
# TODO
|
||||
can_load_config_from_parquet_export = False
|
||||
# if _require_custom_configs:
|
||||
# can_load_config_from_parquet_export = False
|
||||
# elif _require_default_config_name:
|
||||
# with fs.open(
|
||||
# f'datasets/{path}/{filename}',
|
||||
# 'r',
|
||||
# revision=revision,
|
||||
# encoding='utf-8') as f:
|
||||
# can_load_config_from_parquet_export = 'DEFAULT_CONFIG_NAME' not in f.read(
|
||||
# )
|
||||
# else:
|
||||
# can_load_config_from_parquet_export = True
|
||||
if config.USE_PARQUET_EXPORT and can_load_config_from_parquet_export:
|
||||
# If the parquet export is ready (parquet files + info available for the current sha),
|
||||
# we can use it instead
|
||||
@@ -1379,7 +1503,14 @@ class DatasetsWrapperHF:
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
|
||||
# Otherwise we must use the dataset script if the user trusts it
|
||||
# Otherwise we must use the dataset script if the user trusts it.
|
||||
# Script-based loading was removed in datasets 4.0.
|
||||
if not _HAS_SCRIPT_LOADING:
|
||||
raise RuntimeError(
|
||||
f"Dataset '{path}' contains a loading script but script-based dataset loading "
|
||||
'is not supported with datasets>=4.0. Please convert the dataset to a '
|
||||
'script-free format (e.g. Parquet).')
|
||||
|
||||
# To be adapted to the old version of datasets
|
||||
if has_attr_in_class(HubDatasetModuleFactoryWithScript, 'revision'):
|
||||
return HubDatasetModuleFactoryWithScript(
|
||||
@@ -1424,10 +1555,12 @@ class DatasetsWrapperHF:
|
||||
logger.error(f'>> Error loading {path}: {e1}')
|
||||
|
||||
try:
|
||||
# dynamic_modules_path was removed in datasets 4.0
|
||||
_cached_factory_kwargs = {'cache_dir': cache_dir}
|
||||
if _HAS_SCRIPT_LOADING:
|
||||
_cached_factory_kwargs['dynamic_modules_path'] = dynamic_modules_path
|
||||
return CachedDatasetModuleFactory(
|
||||
path,
|
||||
dynamic_modules_path=dynamic_modules_path,
|
||||
cache_dir=cache_dir).get_module()
|
||||
path, **_cached_factory_kwargs).get_module()
|
||||
except Exception:
|
||||
# If it's not in the cache, then it doesn't exist.
|
||||
if isinstance(e1, OfflineModeIsEnabled):
|
||||
@@ -1451,8 +1584,55 @@ class DatasetsWrapperHF:
|
||||
f'any data file in the same directory.')
|
||||
|
||||
|
||||
# Unpatched ``HfFileSystem._open``. ``load_dataset_with_ctx`` stores the
# original method here while its ``_hf_fs_open`` patch is installed and resets
# it to ``None`` when the patch is removed, so ``_hf_fs_open`` can delegate.
_hf_fs_open_original = None
|
||||
|
||||
|
||||
def _hf_fs_open(self, path, mode='rb', **kwargs):
    """Wrapper for HfFileSystem._open that fixes size=0 from ModelScope API.

    The ModelScope tree API may report Size=0 for files. When HfFileSystem
    caches this, AbstractBufferedFile treats the file as empty (0 bytes).
    This wrapper detects size=0 for files opened in read mode and resolves
    the actual size via a HEAD request before creating the file object.

    Args:
        self: The ``HfFileSystem`` instance this function is patched onto.
        path: Path of the file to open, as accepted by ``resolve_path``.
        mode: Open mode. Only ``'rb'`` triggers the size lookup.
        **kwargs: Forwarded to the original ``_open``. When the caller already
            passes ``size``, it is left untouched and no lookup is done.

    Returns:
        Whatever the original ``HfFileSystem._open`` returns.

    Note:
        The size lookup is best-effort: any failure is logged at debug level
        and the call falls through to the original ``_open`` unchanged.
    """
    if mode == 'rb' and 'size' not in kwargs:
        try:
            resolved = self.resolve_path(path)
            resolved_name = resolved.unresolve()
            parent = self._parent(resolved_name)

            # Locate the cached listing entry for this file (if the parent
            # directory has been listed) and keep a reference to it so it can
            # be corrected in place without a second scan.
            cached_entry = next(
                (entry for entry in self.dircache.get(parent, [])
                 if entry['name'] == resolved_name
                 and entry.get('type') == 'file'),
                None)

            if cached_entry is not None and cached_entry.get('size', -1) == 0:
                url = hf_hub_url(
                    repo_id=resolved.repo_id,
                    revision=resolved.revision,
                    filename=resolved.path_in_repo,
                    repo_type=resolved.repo_type,
                    endpoint=self.endpoint,
                )
                headers = self._api._build_hf_headers()
                resp = requests.head(
                    url, headers=headers, allow_redirects=True, timeout=30)
                if resp.status_code == 200:
                    content_length = resp.headers.get('Content-Length')
                    if content_length:
                        actual_size = int(content_length)
                        kwargs['size'] = actual_size
                        # Fix the cached entry so later opens skip the HEAD.
                        cached_entry['size'] = actual_size
        except Exception as e:
            # Deliberately non-fatal: fall back to the original behaviour,
            # but leave a trace instead of swallowing the error silently.
            logger.debug('Failed to resolve actual size for %s: %s', path, e)
    return _hf_fs_open_original(self, path, mode=mode, **kwargs)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def load_dataset_with_ctx(*args, **kwargs):
|
||||
global _hf_fs_open_original
|
||||
|
||||
# Keep the original functions
|
||||
hf_endpoint_origin = config.HF_ENDPOINT
|
||||
@@ -1467,8 +1647,11 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
get_paths_info_origin = HfApi.get_paths_info
|
||||
resolve_pattern_origin = data_files.resolve_pattern
|
||||
get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
|
||||
get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
|
||||
# Script-based loading was removed in datasets 4.0
|
||||
get_module_with_script_origin = (
|
||||
HubDatasetModuleFactoryWithScript.get_module if _HAS_SCRIPT_LOADING else None)
|
||||
generate_from_dict_origin = features.generate_from_dict
|
||||
hf_fs_open_origin = HfFileSystem._open
|
||||
|
||||
# Monkey patching with modelscope functions
|
||||
config.HF_ENDPOINT = get_endpoint()
|
||||
@@ -1483,8 +1666,11 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
HfApi.get_paths_info = _get_paths_info
|
||||
data_files.resolve_pattern = _resolve_pattern
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
|
||||
if _HAS_SCRIPT_LOADING:
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
|
||||
features.generate_from_dict = generate_from_dict_ms
|
||||
_hf_fs_open_original = hf_fs_open_origin
|
||||
HfFileSystem._open = _hf_fs_open
|
||||
|
||||
streaming = kwargs.get('streaming', False)
|
||||
|
||||
@@ -1492,14 +1678,16 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
dataset_res = DatasetsWrapperHF.load_dataset(*args, **kwargs)
|
||||
yield dataset_res
|
||||
finally:
|
||||
# Restore the original functions
|
||||
config.HF_ENDPOINT = hf_endpoint_origin
|
||||
file_utils.get_from_cache = get_from_cache_origin
|
||||
features.generate_from_dict = generate_from_dict_origin
|
||||
# Keep the context during the streaming iteration
|
||||
_repo_tree_cache.clear()
|
||||
HubApi._dataset_id_type_cache.clear()
|
||||
|
||||
HfFileSystem._open = hf_fs_open_origin
|
||||
_hf_fs_open_original = None
|
||||
|
||||
if not streaming:
|
||||
config.HF_ENDPOINT = hf_endpoint_origin
|
||||
file_utils.get_from_cache = get_from_cache_origin
|
||||
features.generate_from_dict = generate_from_dict_origin
|
||||
|
||||
# Compatible with datasets 2.18.0
|
||||
if hasattr(DownloadManager, '_download'):
|
||||
@@ -1512,4 +1700,5 @@ def load_dataset_with_ctx(*args, **kwargs):
|
||||
HfApi.get_paths_info = get_paths_info_origin
|
||||
data_files.resolve_pattern = resolve_pattern_origin
|
||||
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
|
||||
if _HAS_SCRIPT_LOADING:
|
||||
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=3.0.0,<=3.6.0
|
||||
datasets>=4.0.0,<=4.6.1
|
||||
einops
|
||||
oss2
|
||||
Pillow
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
addict
|
||||
attrs
|
||||
datasets>=3.0.0,<=3.6.0
|
||||
datasets>=4.0.0,<=4.6.1
|
||||
einops
|
||||
Pillow
|
||||
python-dateutil>=2.1
|
||||
|
||||
Reference in New Issue
Block a user