fix features for datasets<=3.6.0

Author: xingjun.wxj
Date:   2025-08-06 15:10:56 +08:00
Parent: 595f3ea263
Commit: 924ad0822a


@@ -6,9 +6,10 @@ import contextlib
 import inspect
 import os
 import warnings
+from dataclasses import dataclass, field, fields
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal
+from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal, Any, ClassVar
 from urllib.parse import urlencode
@@ -16,7 +17,9 @@ import requests
 from datasets import (BuilderConfig, Dataset, DatasetBuilder, DatasetDict,
                       DownloadConfig, DownloadManager, DownloadMode, Features,
                       IterableDataset, IterableDatasetDict, Split,
-                      VerificationMode, Version, config, data_files)
+                      VerificationMode, Version, config, data_files, LargeList, Sequence as SequenceHf)
+from datasets.features import features
+from datasets.features.features import _FEATURE_TYPES
 from datasets.data_files import (
     FILES_TO_IGNORE, DataFilesDict, EmptyDatasetError,
     _get_data_files_patterns, _is_inside_unrequested_special_dir,
@@ -49,6 +52,7 @@ from datasets.utils.info_utils import is_small_dataset
 from datasets.utils.metadata import MetadataConfigs
 from datasets.utils.py_utils import get_imports
 from datasets.utils.track import tracked_str
 from fsspec import filesystem
 from fsspec.core import _un_chain
 from fsspec.utils import stringify_path
@@ -62,7 +66,7 @@ from modelscope import HubApi
 from modelscope.hub.utils.utils import get_endpoint
 from modelscope.msdatasets.utils.hf_file_utils import get_from_cache_ms
 from modelscope.utils.config_ds import MS_DATASETS_CACHE
-from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET
+from modelscope.utils.constant import DEFAULT_DATASET_REVISION, REPO_TYPE_DATASET
 from modelscope.utils.import_utils import has_attr_in_class
 from modelscope.utils.logger import get_logger
@@ -89,6 +93,76 @@ ExpandDatasetProperty_T = Literal[
 ]
 
 
+# Patch datasets features
+@dataclass(repr=False)
+class ListMs(SequenceHf):
+    """Feature type for large list data composed of child feature data type.
+
+    It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
+
+    Args:
+        feature ([`FeatureType`]):
+            Child feature data type of each item within the large list.
+        length (optional `int`, default to -1):
+            Length of the list if it is fixed.
+            Defaults to -1 which means an arbitrary length.
+    """
+
+    feature: Any
+    length: int = -1
+    id: Optional[str] = field(default=None, repr=False)
+    # Automatically constructed
+    pa_type: ClassVar[Any] = None
+    _type: str = field(default='List', init=False, repr=False)
+
+    def __repr__(self):
+        if self.length != -1:
+            return f'{type(self).__name__}({self.feature}, length={self.length})'
+        else:
+            return f'{type(self).__name__}({self.feature})'
+
+
+_FEATURE_TYPES['List'] = ListMs
+
+
+def generate_from_dict_ms(obj: Any):
+    """Regenerate the nested feature object from a deserialized dict.
+
+    We use the '_type' fields to get the dataclass name to load.
+
+    generate_from_dict is the recursive helper for Features.from_dict, and allows for a convenient constructor syntax
+    to define features from deserialized JSON dictionaries. This function is used in particular when deserializing
+    a :class:`DatasetInfo` that was dumped to a JSON object. This acts as an analogue to
+    :meth:`Features.from_arrow_schema` and handles the recursive field-by-field instantiation, but doesn't require any
+    mapping to/from pyarrow, except for the fact that it takes advantage of the mapping of pyarrow primitive dtypes
+    that :class:`Value` automatically performs.
+    """
+    # Nested structures: we allow dict, list/tuples, sequences
+    if isinstance(obj, list):
+        return [generate_from_dict_ms(value) for value in obj]
+    # Otherwise we have a dict or a dataclass
+    if '_type' not in obj or isinstance(obj['_type'], dict):
+        return {key: generate_from_dict_ms(value) for key, value in obj.items()}
+    obj = dict(obj)
+    _type = obj.pop('_type')
+    class_type = _FEATURE_TYPES.get(_type, None) or globals().get(_type, None)
+
+    if class_type is None:
+        raise ValueError(f"Feature type '{_type}' not found. Available feature types: {list(_FEATURE_TYPES.keys())}")
+
+    if class_type == LargeList:
+        feature = obj.pop('feature')
+        return LargeList(generate_from_dict_ms(feature), **obj)
+    if class_type == ListMs:
+        feature = obj.pop('feature')
+        return ListMs(generate_from_dict_ms(feature), **obj)
+    if class_type == SequenceHf:  # backward compatibility, this translates to a List or a dict
+        feature = obj.pop('feature')
+        return SequenceHf(feature=generate_from_dict_ms(feature), **obj)
+
+    field_names = {f.name for f in fields(class_type)}
+    return class_type(**{k: v for k, v in obj.items() if k in field_names})
+
+
 def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
     url_or_filename = str(url_or_filename)
     # for temp val
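For context (outside the diff): feature schemas serialized by newer `datasets` releases use a 'List' type name that the stock `generate_from_dict` in `datasets<=3.6.0` does not recognize. A minimal sketch of how the patched helper is expected to behave, assuming `ListMs` and `generate_from_dict_ms` are in scope as defined above; the serialized mapping is illustrative only, not taken from the commit:

# Hypothetical features dict, as a newer `datasets` release would serialize it.
serialized = {
    'tokens': {'feature': {'dtype': 'string', '_type': 'Value'}, '_type': 'List'},
    'label': {'dtype': 'int64', '_type': 'Value'},
}

# '_type': 'List' resolves to ListMs via the patched _FEATURE_TYPES registry.
# Because ListMs subclasses Sequence, the reconstructed schema stays usable
# with datasets<=3.6.0.
reconstructed = generate_from_dict_ms(serialized)
print(reconstructed['tokens'])  # a ListMs wrapping Value('string')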
@@ -1377,6 +1451,7 @@ def load_dataset_with_ctx(*args, **kwargs):
     resolve_pattern_origin = data_files.resolve_pattern
     get_module_without_script_origin = HubDatasetModuleFactoryWithoutScript.get_module
     get_module_with_script_origin = HubDatasetModuleFactoryWithScript.get_module
+    generate_from_dict_origin = features.generate_from_dict
 
     # Monkey patching with modelscope functions
     config.HF_ENDPOINT = get_endpoint()
@@ -1392,6 +1467,7 @@ def load_dataset_with_ctx(*args, **kwargs):
     data_files.resolve_pattern = _resolve_pattern
     HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script
     HubDatasetModuleFactoryWithScript.get_module = get_module_with_script
+    features.generate_from_dict = generate_from_dict_ms
 
     streaming = kwargs.get('streaming', False)
@@ -1402,6 +1478,7 @@ def load_dataset_with_ctx(*args, **kwargs):
         # Restore the original functions
         config.HF_ENDPOINT = hf_endpoint_origin
         file_utils.get_from_cache = get_from_cache_origin
+        features.generate_from_dict = generate_from_dict_origin
         # Keep the context during the streaming iteration
         if not streaming:
             config.HF_ENDPOINT = hf_endpoint_origin
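The three hunks above follow a save/patch/restore pattern: the original `features.generate_from_dict` is captured, swapped for `generate_from_dict_ms` for the duration of the load, then put back. A simplified sketch of that pattern (not the actual `load_dataset_with_ctx` implementation; the helper name is hypothetical):

import contextlib

from datasets.features import features


@contextlib.contextmanager
def _patched_generate_from_dict(replacement):
    # Save the original helper, install the replacement, and always restore it.
    original = features.generate_from_dict
    features.generate_from_dict = replacement
    try:
        yield
    finally:
        features.generate_from_dict = original

Any Features deserialization performed inside such a `with` block goes through the replacement, and the original helper is restored even if loading raises, mirroring the restore step shown in the last hunk.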