[Fix] fix dataset util (#1645)

This commit is contained in:
Xingjun.Wang
2026-03-13 10:36:11 +08:00
committed by GitHub
parent 467a2206e3
commit f4dbe65110
2 changed files with 31 additions and 39 deletions

View File

@@ -130,40 +130,7 @@ ExpandDatasetProperty_T = Literal[
# Patch datasets features
# In datasets 4.0+, the List type is the native feature type;
# in datasets <4.0, Sequence (a dataclass) serves that role.
_ListBase = DatasetList if DatasetList is not None else SequenceHf
@dataclass(repr=False)
class ListMs(_ListBase):
    """List feature whose items all share one child feature type.

    The parent class is whatever ``_ListBase`` resolves to when the module is
    imported, so this type tracks the list feature of the installed
    ``datasets`` release.

    NOTE(review): an earlier version of this docstring described the type as
    backed by `pyarrow.ListType` (32-bit offsets or a fixed length). Nothing
    in this class configures that (``pa_type`` is ``None`` here) -- confirm
    against the parent class.

    Args:
        feature ([`FeatureType`]):
            Child feature data type of each item within the list.
        length (optional `int`, default to -1):
            Length of the list if it is fixed.
            Defaults to -1 which means an arbitrary length.
        id (optional `str`, default to `None`):
            Optional string; excluded from the dataclass repr.
    """

    feature: Any
    length: int = -1
    id: Optional[str] = field(default=None, repr=False)
    # Automatically constructed: neither of the two names below is a
    # constructor argument (`ClassVar` / `init=False`).
    pa_type: ClassVar[Any] = None
    _type: str = field(default='List', init=False, repr=False)

    def __repr__(self):
        # -1 is the "arbitrary length" default, so it is left out of the repr.
        if self.length == -1:
            return f'{type(self).__name__}({self.feature})'
        return f'{type(self).__name__}({self.feature}, length={self.length})'
_FEATURE_TYPES['List'] = ListMs
_NativeList = DatasetList if DatasetList is not None else SequenceHf
def generate_from_dict_ms(obj: Any):
@@ -202,9 +169,10 @@ def generate_from_dict_ms(obj: Any):
if class_type == LargeList:
feature = obj.pop('feature')
return LargeList(generate_from_dict_ms(feature), **obj)
if class_type == ListMs:
# Handle the native List type (datasets 4.0+) as well as the Sequence-based fallback used by datasets <4.0.
if _NativeList is not None and (class_type is _NativeList or issubclass(class_type, _NativeList)):
feature = obj.pop('feature')
return ListMs(generate_from_dict_ms(feature), **obj)
return _NativeList(generate_from_dict_ms(feature), **obj)
field_names = {f.name for f in fields(class_type)}
return class_type(**{k: v for k, v in obj.items() if k in field_names})
@@ -213,9 +181,30 @@ def generate_from_dict_ms(obj: Any):
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
if url_or_filename.startswith('hf://'):
# hf:// URLs are handled natively by cached_path via HfApi.hf_hub_download,
# which uses config.HF_ENDPOINT (already set to ModelScope endpoint).
pass
# hf:// URLs (e.g. hf://datasets/{owner}/{name}@{revision}/{file_path})
hf_path = url_or_filename[len('hf://'):]
# Strip leading resource type prefix (e.g. "datasets/")
for _prefix in ('datasets/', 'models/'):
if hf_path.startswith(_prefix):
hf_path = hf_path[len(_prefix):]
break
# Extract revision and file_path from "{owner}/{name}@{revision}/{file_path}"
if '@' in hf_path:
at_idx = hf_path.index('@')
after_at = hf_path[at_idx + 1:]
slash_idx = after_at.find('/')
if slash_idx == -1:
revision = after_at
file_path = ''
else:
revision = after_at[:slash_idx]
file_path = after_at[slash_idx + 1:]
else:
parts = hf_path.split('/', 2)
revision = DEFAULT_DATASET_REVISION
file_path = parts[2] if len(parts) > 2 else ''
params = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_path})
url_or_filename = self._base_path + params
elif is_relative_path(url_or_filename):
revision = DEFAULT_DATASET_REVISION
# Note: make sure the FilePath is the last param

View File

@@ -7,6 +7,7 @@ import unittest
from modelscope.hub.file_download import dataset_file_download
from modelscope.hub.snapshot_download import dataset_snapshot_download
from modelscope.utils.test_utils import test_level
class DownloadDatasetTest(unittest.TestCase):
@@ -14,6 +15,7 @@ class DownloadDatasetTest(unittest.TestCase):
def setUp(self):
pass
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_dataset_file_download(self):
dataset_id = 'citest/test_dataset_download'
file_path = 'open_qa.jsonl'
@@ -67,6 +69,7 @@ class DownloadDatasetTest(unittest.TestCase):
file_modify_time2 = os.path.getmtime(cache_file_path)
assert file_modify_time == file_modify_time2
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_dataset_snapshot_download(self):
dataset_id = 'citest/test_dataset_download'
file_path = 'open_qa.jsonl'