mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[Fix] fix dataset util (#1645)
This commit is contained in:
@@ -130,40 +130,7 @@ ExpandDatasetProperty_T = Literal[
|
||||
|
||||
|
||||
# Patch datasets features
|
||||
# In datasets 4.0+, the List type is the native feature type;
|
||||
# in datasets <4.0, Sequence (a dataclass) serves that role.
|
||||
_ListBase = DatasetList if DatasetList is not None else SequenceHf
|
||||
|
||||
|
||||
@dataclass(repr=False)
|
||||
class ListMs(_ListBase):
|
||||
"""Feature type for large list data composed of child feature data type.
|
||||
|
||||
It is backed by `pyarrow.ListType`, which uses 32-bit offsets or a fixed length.
|
||||
|
||||
Args:
|
||||
feature ([`FeatureType`]):
|
||||
Child feature data type of each item within the large list.
|
||||
length (optional `int`, default to -1):
|
||||
Length of the list if it is fixed.
|
||||
Defaults to -1 which means an arbitrary length.
|
||||
"""
|
||||
|
||||
feature: Any
|
||||
length: int = -1
|
||||
id: Optional[str] = field(default=None, repr=False)
|
||||
# Automatically constructed
|
||||
pa_type: ClassVar[Any] = None
|
||||
_type: str = field(default='List', init=False, repr=False)
|
||||
|
||||
def __repr__(self):
|
||||
if self.length != -1:
|
||||
return f'{type(self).__name__}({self.feature}, length={self.length})'
|
||||
else:
|
||||
return f'{type(self).__name__}({self.feature})'
|
||||
|
||||
|
||||
_FEATURE_TYPES['List'] = ListMs
|
||||
_NativeList = DatasetList if DatasetList is not None else SequenceHf
|
||||
|
||||
|
||||
def generate_from_dict_ms(obj: Any):
|
||||
@@ -202,9 +169,10 @@ def generate_from_dict_ms(obj: Any):
|
||||
if class_type == LargeList:
|
||||
feature = obj.pop('feature')
|
||||
return LargeList(generate_from_dict_ms(feature), **obj)
|
||||
if class_type == ListMs:
|
||||
# Handle the native List type (datasets 4.0+) as well as Sequence-based
|
||||
if _NativeList is not None and (class_type is _NativeList or issubclass(class_type, _NativeList)):
|
||||
feature = obj.pop('feature')
|
||||
return ListMs(generate_from_dict_ms(feature), **obj)
|
||||
return _NativeList(generate_from_dict_ms(feature), **obj)
|
||||
|
||||
field_names = {f.name for f in fields(class_type)}
|
||||
return class_type(**{k: v for k, v in obj.items() if k in field_names})
|
||||
@@ -213,9 +181,30 @@ def generate_from_dict_ms(obj: Any):
|
||||
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
|
||||
url_or_filename = str(url_or_filename)
|
||||
if url_or_filename.startswith('hf://'):
|
||||
# hf:// URLs are handled natively by cached_path via HfApi.hf_hub_download,
|
||||
# which uses config.HF_ENDPOINT (already set to ModelScope endpoint).
|
||||
pass
|
||||
# hf:// URLs (e.g. hf://datasets/{owner}/{name}@{revision}/{file_path})
|
||||
hf_path = url_or_filename[len('hf://'):]
|
||||
# Strip leading resource type prefix (e.g. "datasets/")
|
||||
for _prefix in ('datasets/', 'models/'):
|
||||
if hf_path.startswith(_prefix):
|
||||
hf_path = hf_path[len(_prefix):]
|
||||
break
|
||||
# Extract revision and file_path from "{owner}/{name}@{revision}/{file_path}"
|
||||
if '@' in hf_path:
|
||||
at_idx = hf_path.index('@')
|
||||
after_at = hf_path[at_idx + 1:]
|
||||
slash_idx = after_at.find('/')
|
||||
if slash_idx == -1:
|
||||
revision = after_at
|
||||
file_path = ''
|
||||
else:
|
||||
revision = after_at[:slash_idx]
|
||||
file_path = after_at[slash_idx + 1:]
|
||||
else:
|
||||
parts = hf_path.split('/', 2)
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
file_path = parts[2] if len(parts) > 2 else ''
|
||||
params = urlencode({'Source': 'SDK', 'Revision': revision, 'FilePath': file_path})
|
||||
url_or_filename = self._base_path + params
|
||||
elif is_relative_path(url_or_filename):
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
# Note: make sure the FilePath is the last param
|
||||
|
||||
@@ -7,6 +7,7 @@ import unittest
|
||||
|
||||
from modelscope.hub.file_download import dataset_file_download
|
||||
from modelscope.hub.snapshot_download import dataset_snapshot_download
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class DownloadDatasetTest(unittest.TestCase):
|
||||
@@ -14,6 +15,7 @@ class DownloadDatasetTest(unittest.TestCase):
|
||||
def setUp(self):
|
||||
pass
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_dataset_file_download(self):
|
||||
dataset_id = 'citest/test_dataset_download'
|
||||
file_path = 'open_qa.jsonl'
|
||||
@@ -67,6 +69,7 @@ class DownloadDatasetTest(unittest.TestCase):
|
||||
file_modify_time2 = os.path.getmtime(cache_file_path)
|
||||
assert file_modify_time == file_modify_time2
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
def test_dataset_snapshot_download(self):
|
||||
dataset_id = 'citest/test_dataset_download'
|
||||
file_path = 'open_qa.jsonl'
|
||||
|
||||
Reference in New Issue
Block a user