diff --git a/modelscope/msdatasets/utils/hf_datasets_util.py b/modelscope/msdatasets/utils/hf_datasets_util.py
index 224964f4..9053d062 100644
--- a/modelscope/msdatasets/utils/hf_datasets_util.py
+++ b/modelscope/msdatasets/utils/hf_datasets_util.py
@@ -6,9 +6,10 @@ import contextlib
 import inspect
 import os
 import warnings
+from dataclasses import dataclass, field, fields
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal
+from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal, Any, ClassVar
 from urllib.parse import urlencode
 
 
@@ -16,7 +17,9 @@ import requests
 from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
                       DownloadConfig, DownloadManager, DownloadMode, Features,
                       IterableDataset, IterableDatasetDict, Split,
-                      VerificationMode, Version, config, data_files)
+                      VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf)
+from datasets.features import features
+from datasets.features.features import _FEATURE_TYPES
 from datasets.data_files import (
     FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError,
     _get_data_files_patterns, _is_inside_unrequested_special_dir,
@@ -49,6 +52,7 @@ from datasets.utils.info_utils import is_small_dataset
 from datasets.utils.metadata import MetadataConfigs
 from datasets.utils.py_utils import get_imports
 from datasets.utils.track import tracked_str
+
 from fsspec import filesystem
 from fsspec.core import _un_chain
 from fsspec.utils import stringify_path
@@ -62,7 +66,7 @@ from modelscope import HubApi
 from modelscope.hub.utils.utils import get_endpoint
 from modelscope.msdatasets.utils.hf_file_utils import get_from_cache_ms
 from modelscope.utils.config_ds import MS_DATASETS_CACHE
-from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET
+from modelscope.utils.constant import DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET
 from modelscope.utils.import_utils import has_attr_in_class
 from modelscope.utils.logger import get_logger
 
@@ -89,6 +93,76 @@ ExpandDatasetProperty_T = Literal[
 ]
 
 
+# Patch datasets features
+@dataclass(repr=False)
+class ListMs(SequenceHf):
+    """Feature type for list data composed of a child feature data type.
+
+    It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
+
+    Args:
+        feature ([`FeatureType`]):
+            Child feature data type of each item within the list.
+        length (`int`, *optional*, defaults to -1):
+            Length of the list if it is fixed.
+            Defaults to -1, which means an arbitrary length.
+    """
+
+    feature: Any
+    length: int = -1
+    id: Optional[str] = field(default=None, repr=False)
+    # Automatically constructed
+    pa_type: ClassVar[Any] = None
+    _type: str = field(default='List', init=False, repr=False)
+
+    def __repr__(self):
+        if self.length != -1:
+            return f'{type(self).__name__}({self.feature}, length={self.length})'
+        else:
+            return f'{type(self).__name__}({self.feature})'
+
+
+_FEATURE_TYPES['List'] = ListMs
+
+
+def generate_from_dict_ms(obj: Any):
+    """Regenerate the nested feature object from a deserialized dict.
+    We use the '_type' field to get the dataclass name to load.
+
+    generate_from_dict is the recursive helper for Features.from_dict, and allows for a convenient constructor syntax
+    to define features from deserialized JSON dictionaries. This function is used in particular when deserializing
+    a :class:`DatasetInfo` that was dumped to a JSON object. This acts as an analogue to
+    :meth:`Features.from_arrow_schema` and handles the recursive field-by-field instantiation, but doesn't require any
+    mapping to/from pyarrow, except for the fact that it takes advantage of the mapping of pyarrow primitive dtypes
+    that :class:`Value` automatically performs.
+    """
+    # Nested structures: we allow dict, list/tuples, sequences
+    if isinstance(obj, list):
+        return [generate_from_dict_ms(value) for value in obj]
+    # Otherwise we have a dict or a dataclass
+    if '_type' not in obj or isinstance(obj['_type'], dict):
+        return {key: generate_from_dict_ms(value) for key, value in obj.items()}
+    obj = dict(obj)
+    _type = obj.pop('_type')
+    class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
+
+    if class_type is None:
+        raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")
+
+    if class_type == LargeList:
+        feature = obj.pop('feature')
+        return LargeList(generate_from_dict_ms(feature), **obj)
+    if class_type == ListMs:
+        feature = obj.pop('feature')
+        return ListMs(generate_from_dict_ms(feature), **obj)
+    if class_type == SequenceHf:  # backward compatibility, this translates to a List or a dict
+        feature = obj.pop('feature')
+        return SequenceHf(feature=generate_from_dict_ms(feature), **obj)
+
+    field_names = {f.name for f in fields(class_type)}
+    return class_type(**{k: v for k, v in obj.items() if k in field_names})
+
+
 def _download_ms(self, url_or_filename: str,
                  download_config: DownloadConfig) -> str:
     url_or_filename = str(url_or_filename)  # for temp val
@@ -1377,6 +1451,7 @@ def load_dataset_with_ctx(*args, **kwargs):
     resolve_pattern_origin = data_files.resolve_pattern
     get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
     get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
+    generate_from_dict_origin = features.generate_from_dict
 
     # Monkey patching with modelscope functions
     config.HF_ENDPOINT = get_endpoint()
@@ -1392,6 +1467,7 @@ def load_dataset_with_ctx(*args, **kwargs):
     data_files.resolve_pattern = _resolve_pattern
     HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
     HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
+    features.generate_from_dict = generate_from_dict_ms
 
     streaming = kwargs.get('streaming', False)
 
@@ -1402,6 +1478,7 @@ def load_dataset_with_ctx(*args, **kwargs):
         # Restore the original functions
         config.HF_ENDPOINT = hf_endpoint_origin
         file_utils.get_from_cache = get_from_cache_origin
+        features.generate_from_dict = generate_from_dict_origin
         # Keep the context during the streaming iteration
         if not streaming:
             config.HF_ENDPOINT = hf_endpoint_origin
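
For review convenience, a minimal sketch (not part of the patch) of the case the patched deserializer handles: newer `datasets` releases serialize plain list features with `_type: 'List'`, a tag older releases cannot resolve, and `generate_from_dict_ms` maps it to the `ListMs` shim registered in `_FEATURE_TYPES` above. The `serialized` payload below is illustrative, not taken from a real dataset.

    from modelscope.msdatasets.utils.hf_datasets_util import generate_from_dict_ms

    # Feature metadata as dumped by a newer `datasets` version: the 'tokens'
    # column uses the 'List' feature type that older releases do not define.
    serialized = {
        'tokens': {'feature': {'dtype': 'string', '_type': 'Value'}, '_type': 'List'},
        'label': {'dtype': 'int64', '_type': 'Value'},
    }

    decoded = generate_from_dict_ms(serialized)
    print(decoded)  # expected: {'tokens': ListMs(Value(dtype='string')), 'label': Value(dtype='int64')}

Inside `load_dataset_with_ctx` the same conversion happens transparently: `features.generate_from_dict` is swapped for `generate_from_dict_ms` while the context is active and restored in the `finally` block, so the patch does not leak outside the context manager.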