Upgrade datasets (#921)

* del _datasets_server import in hf_dataset_util

* fix streaming for youku-mplug and adopt latest datasets

* fix download config copy

* update ut

* add youku in test_general_datasets

* update UT for general dataset

* adapt to datasets version: 2.19.0 or later

* add assert for youku data UT

* fix disable_tqdm in some functions for 2.19.0 or later

* update get_module_with_script

* set trust_remote_code to True in load_dataset_with_ctx

* update print info

* update requirements for datasets version restriction

* fix _dataset_info

* add pillow

* update comments

* update comment

* reuse _download function in DataDownloadManager

* remove unused code

* update test_run_modelhub in Human3DAnimationTest

* set datasets>=2.18.0
Author: Xingjun.Wang
Date: 2024-07-23 22:26:12 +08:00 (committed by GitHub)
Parent: 4e2555c5a3
Commit: 210ab40c54
13 changed files with 163 additions and 102 deletions
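The changes below are exercised end to end by the new Youku streaming test further down. A minimal usage sketch of the upgraded loader (the dataset id, subset, and the ':FILE' column convention are taken from that test):

from modelscope import MsDataset

# Streaming load; remote dataset scripts are now trusted by default
# inside load_dataset_with_ctx (see the MsDataset.load diff below).
ds = MsDataset.load(
    'modelscope/Youku-AliceMind',
    subset_name='classification',
    split='validation',
    use_streaming=True)

sample = next(iter(ds))
print(sample['video_id'][0])       # sample id
print(sample['video_id:FILE'][0])  # local path of the downloaded file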

@@ -1,2 +1,2 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .ms_dataset import MsDataset
from modelscope.msdatasets.ms_dataset import MsDataset

@@ -149,6 +149,7 @@ class NativeIterableDataset(IterableDataset):
if isinstance(ex_cache_path, str):
ex_cache_path = [ex_cache_path]
ret[k] = ex_cache_path
ret[k.strip(':FILE')] = v
except Exception as e:
logger.error(e)
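The hunk above normalizes cached paths for ':FILE' columns. A small sketch of the intended behavior, assuming ex_cache_path is whatever the download manager returned for such a column:

# A single cached path is wrapped in a list, so ':FILE' columns
# always hold list values regardless of how many files were fetched.
ex_cache_path = '/cache/videos/abc.mp4'
if isinstance(ex_cache_path, str):
    ex_cache_path = [ex_cache_path]
ret = {'video_id:FILE': ex_cache_path}  # -> ['/cache/videos/abc.mp4']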

@@ -330,6 +330,7 @@ class IterableDatasetBuilder(csv.Csv):
super().__init__(
cache_dir=self.cache_build_dir,
dataset_name=self.dataset_name,
config_name=self.namespace,
hash=sub_dir_hash,
data_files=None, # TODO: self.meta_data_files,

@@ -6,16 +6,18 @@ from datasets.download.download_config import DownloadConfig
class DataDownloadConfig(DownloadConfig):
"""
Extends `DownloadConfig` with additional attributes for data download.
"""
def __init__(self):
self.dataset_name: Optional[str] = None
self.namespace: Optional[str] = None
self.version: Optional[str] = None
self.split: Optional[Union[str, list]] = None
self.data_dir: Optional[str] = None
self.oss_config: Optional[dict] = {}
self.meta_args_map: Optional[dict] = {}
self.num_proc: int = 4
dataset_name: Optional[str] = None
namespace: Optional[str] = None
version: Optional[str] = None
split: Optional[Union[str, list]] = None
data_dir: Optional[str] = None
oss_config: Optional[dict] = {}
meta_args_map: Optional[dict] = {}
num_proc: int = 4
def copy(self) -> 'DataDownloadConfig':
return self
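Moving the fields from __init__ to class-level declarations follows the dataclass style of DownloadConfig in recent datasets releases, and copy() now deliberately returns the same instance rather than a fresh object (the "fix download config copy" item above), so later mutations stay visible to code holding a copy. A hedged usage sketch; the import path is assumed:

from modelscope.msdatasets.download.download_config import DataDownloadConfig

config = DataDownloadConfig()
config.dataset_name = 'Youku-AliceMind'
config.namespace = 'modelscope'
config.num_proc = 8

# copy() returns the same object, so the "copy" sees the update below.
copied = config.copy()
config.version = '1.0.0'
assert copied is config and copied.version == '1.0.0'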

@@ -36,6 +36,11 @@ class DataDownloadManager(DownloadManager):
return cached_path(
url_or_filename, download_config=download_config)
def _download_single(self, url_or_filename: str,
download_config: DataDownloadConfig) -> str:
# Note: _download_single function is available for datasets>=2.19.0
return self._download(url_or_filename, download_config)
class DataStreamingDownloadManager(StreamingDownloadManager):
"""The data streaming download manager."""
@@ -62,3 +67,7 @@ class DataStreamingDownloadManager(StreamingDownloadManager):
else:
return cached_path(
url_or_filename, download_config=self.download_config)
def _download_single(self, url_or_filename: str) -> str:
# Note: _download_single function is available for datasets>=2.19.0
return self._download(url_or_filename)
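Both managers gain a _download_single alias because datasets>=2.19.0 renamed the internal _download hook; the alias simply delegates to the reused _download implementation. A small probe of that rename, assuming only that one of the two names exists:

import datasets
from datasets.download.download_manager import DownloadManager

# datasets<2.19.0 exposes DownloadManager._download;
# datasets>=2.19.0 renamed the hook to _download_single.
hook = '_download' if hasattr(DownloadManager, '_download') else '_download_single'
print(datasets.__version__, '->', hook)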

@@ -268,17 +268,15 @@ class MsDataset:
return dataset_inst
# Load from the huggingface hub
elif hub == Hubs.huggingface:
dataset_inst = RemoteDataLoaderManager(
dataset_context_config).load_dataset(
RemoteDataLoaderType.HF_DATA_LOADER)
dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
if isinstance(dataset_inst, MsDataset):
dataset_inst._dataset_context_config = dataset_context_config
if custom_cfg:
dataset_inst.to_custom_dataset(
custom_cfg=custom_cfg, **config_kwargs)
dataset_inst.is_custom = True
return dataset_inst
from datasets import load_dataset
return load_dataset(
dataset_name,
name=subset_name,
split=split,
streaming=use_streaming,
download_mode=download_mode.value,
**config_kwargs)
# Load from the modelscope hub
elif hub == Hubs.modelscope:
@@ -305,6 +303,7 @@ class MsDataset:
token=token,
streaming=use_streaming,
dataset_info_only=dataset_info_only,
trust_remote_code=True,
**config_kwargs) as dataset_res:
return dataset_res
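With this hunk, loading from the Hugging Face hub no longer routes through RemoteDataLoaderManager but forwards straight to datasets.load_dataset. A hedged sketch of the equivalent call (the Hubs import path is assumed; the dataset id comes from the updated test below):

from modelscope import MsDataset
from modelscope.utils.constant import Hubs  # import path assumed

# Thin passthrough: subset/split/streaming arguments map directly
# onto datasets.load_dataset.
ds = MsDataset.load(
    'AI-Lab-Makerere/beans',
    split='train',
    use_streaming=True,
    hub=Hubs.huggingface)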

@@ -7,7 +7,7 @@ import os
import warnings
from functools import partial
from pathlib import Path
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple
from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal
from urllib.parse import urlencode
@@ -40,7 +40,7 @@ from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
_MODULE_SUPPORTS_METADATA,
_MODULE_TO_EXTENSIONS,
_PACKAGED_DATASETS_MODULES)
from datasets.utils import _datasets_server, file_utils
from datasets.utils import file_utils
from datasets.utils.file_utils import (OfflineModeIsEnabled,
_raise_if_offline_mode_is_enabled,
cached_path, is_local_path,
@@ -68,6 +68,26 @@ from modelscope.utils.logger import get_logger
logger = get_logger()
ExpandDatasetProperty_T = Literal[
'author',
'cardData',
'citation',
'createdAt',
'disabled',
'description',
'downloads',
'downloadsAllTime',
'gated',
'lastModified',
'likes',
'paperswithcode_id',
'private',
'siblings',
'sha',
'tags',
]
def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
url_or_filename = str(url_or_filename)
# for temp val
@@ -97,6 +117,7 @@ def _dataset_info(
timeout: Optional[float] = None,
files_metadata: bool = False,
token: Optional[Union[bool, str]] = None,
expand: Optional[List[ExpandDatasetProperty_T]] = None,
) -> HfDatasetInfo:
"""
Get info on one specific dataset on huggingface.co.
@@ -728,19 +749,6 @@ def _download_additional_modules(
def get_module_with_script(self) -> DatasetModule:
if config.HF_DATASETS_TRUST_REMOTE_CODE and self.trust_remote_code is None:
warnings.warn(
f'The repository for {self.name} contains custom code which must be executed to correctly '
f'load the dataset. You can inspect the repository content at https://hf.co/datasets/{self.name}\n'
f'You can avoid this message in future by passing the argument `trust_remote_code=True`.\n'
f'Passing `trust_remote_code=True` will be mandatory '
f'to load this dataset from the next major release of `datasets`.',
FutureWarning,
)
# get script and other files
# local_path = self.download_loading_script()
# dataset_infos_path = self.download_dataset_infos_file()
# dataset_readme_path = self.download_dataset_readme_file()
_api = HubApi()
_dataset_name: str = self.name.split('/')[-1]
@@ -1260,8 +1268,9 @@ class DatasetsWrapperHF:
path,
download_config=download_config,
revision=dataset_info.sha).get_module()
except _datasets_server.DatasetsServerError:
pass
except Exception as e:
logger.error(e)
# Otherwise we must use the dataset script if the user trusts it
return HubDatasetModuleFactoryWithScript(
path,
@@ -1314,7 +1323,11 @@ class DatasetsWrapperHF:
def load_dataset_with_ctx(*args, **kwargs):
hf_endpoint_origin = config.HF_ENDPOINT
get_from_cache_origin = file_utils.get_from_cache
_download_origin = DownloadManager._download
# Compatible with datasets 2.18.0
_download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
else DownloadManager._download_single
dataset_info_origin = HfApi.dataset_info
list_repo_tree_origin = HfApi.list_repo_tree
get_paths_info_origin = HfApi.get_paths_info
@@ -1324,7 +1337,13 @@ def load_dataset_with_ctx(*args, **kwargs):
config.HF_ENDPOINT = get_endpoint()
file_utils.get_from_cache = get_from_cache_ms
# Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
DownloadManager._download = _download_ms
else:
DownloadManager._download_single = _download_ms
HfApi.dataset_info = _dataset_info
HfApi.list_repo_tree = _list_repo_tree
HfApi.get_paths_info = _get_paths_info
@@ -1338,12 +1357,16 @@ def load_dataset_with_ctx(*args, **kwargs):
finally:
config.HF_ENDPOINT = hf_endpoint_origin
file_utils.get_from_cache = get_from_cache_origin
# Compatible with datasets 2.18.0
if hasattr(DownloadManager, '_download'):
DownloadManager._download = _download_origin
else:
DownloadManager._download_single = _download_origin
HfApi.dataset_info = dataset_info_origin
HfApi.list_repo_tree = list_repo_tree_origin
HfApi.get_paths_info = get_paths_info_origin
data_files.resolve_pattern = resolve_pattern_origin
HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
logger.info('Context manager of ms-dataset exited.')
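load_dataset_with_ctx swaps in ModelScope-aware implementations on entry and restores the originals in the finally block; the new hasattr checks extend that pattern to the renamed download hook. A condensed, stdlib-only sketch of the same patch-and-restore pattern:

from contextlib import contextmanager

@contextmanager
def patched(obj, name, replacement):
    # Temporarily replace obj.<name>, restoring it even if the body raises.
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)

# e.g. with patched(DownloadManager, hook, _download_ms): ...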

@@ -7,6 +7,7 @@ import os
import re
import shutil
import warnings
import inspect
from contextlib import contextmanager
from functools import partial
from pathlib import Path
@@ -41,6 +42,7 @@ def get_from_cache_ms(
ignore_url_params=False,
storage_options=None,
download_desc=None,
disable_tqdm=False,
) -> str:
"""
Given a URL, look for the corresponding file in the local cache.
@@ -209,7 +211,31 @@ def get_from_cache_ms(
if scheme == 'ftp':
ftp_get(url, temp_file)
elif scheme not in ('http', 'https'):
fsspec_get_sig = inspect.signature(fsspec_get)
if 'disable_tqdm' in fsspec_get_sig.parameters:
fsspec_get(url,
temp_file,
storage_options=storage_options,
desc=download_desc,
disable_tqdm=disable_tqdm
)
else:
fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
else:
http_get_sig = inspect.signature(http_get)
if 'disable_tqdm' in http_get_sig.parameters:
http_get(
url,
temp_file=temp_file,
proxies=proxies,
resume_size=resume_size,
headers=headers,
cookies=cookies,
max_retries=max_retries,
desc=download_desc,
disable_tqdm=disable_tqdm,
)
else:
http_get(
url,
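http_get and fsspec_get only accept disable_tqdm in newer datasets releases, so the code probes each signature before passing it. The probe in isolation (assuming a datasets version in the supported 2.18+ range, where http_get lives in datasets.utils.file_utils):

import inspect
from datasets.utils.file_utils import http_get

if 'disable_tqdm' in inspect.signature(http_get).parameters:
    print('http_get accepts disable_tqdm (datasets>=2.19.0)')
else:
    print('older http_get: call without disable_tqdm')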

@@ -1,8 +1,9 @@
addict
attrs
datasets>=2.16.0,<2.19.0
datasets>=2.18.0
einops
oss2
Pillow
python-dateutil>=2.1
scipy
# latest version has some compatibility issues.
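A hedged runtime guard that mirrors the updated pin, using only the standard library:

import importlib.metadata

# Mirror the requirements pin at runtime: datasets must be >= 2.18.0.
version = importlib.metadata.version('datasets')
major, minor = (int(x) for x in version.split('.')[:2])
assert (major, minor) >= (2, 18), f'datasets>=2.18.0 required, found {version}'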

@@ -1,6 +1,6 @@
addict
attrs
datasets>=2.16.0,<2.19.0
datasets>=2.18.0
einops
oss2
Pillow

@@ -1,4 +1,5 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import unittest
from modelscope import MsDataset
@@ -7,9 +8,6 @@ from modelscope.utils.test_utils import test_level
logger = get_logger()
# Note: MODELSCOPE_DOMAIN is set to 'test.modelscope.cn' in the environment variable
# TODO: ONLY FOR TEST ENVIRONMENT, to be replaced by the online domain
TEST_INNER_LEVEL = 1
@@ -19,32 +17,33 @@ class GeneralMsDatasetTest(unittest.TestCase):
'skip test in current test level')
def test_return_dataset_info_only(self):
ds = MsDataset.load(
'wangxingjun778test/aya_dataset_mini', dataset_info_only=True)
print(f'>>output of test_return_dataset_info_only:\n {ds}')
'wangxingjun778/aya_dataset_mini', dataset_info_only=True)
logger.info(f'>>output of test_return_dataset_info_only:\n {ds}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_fashion_mnist(self):
# inner means the dataset is on the test.modelscope.cn environment
ds = MsDataset.load(
'xxxxtest0004/ms_test_0308_py',
'wangxingjun778/ms_test_0308_py',
subset_name='fashion_mnist',
split='train')
print(f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
logger.info(
f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_clue(self):
ds = MsDataset.load(
'wangxingjun778test/clue', subset_name='afqmc', split='train')
print(f'>>output of test_inner_clue:\n {next(iter(ds))}')
'wangxingjun778/clue', subset_name='afqmc', split='train')
logger.info(f'>>output of test_inner_clue:\n {next(iter(ds))}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_cats_and_dogs_mini(self):
ds = MsDataset.load(
'wangxingjun778test/cats_and_dogs_mini', split='train')
print(f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
ds = MsDataset.load('wangxingjun778/cats_and_dogs_mini', split='train')
logger.info(
f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
@@ -53,14 +52,14 @@ class GeneralMsDatasetTest(unittest.TestCase):
# data/train-xxx-of-xxx.parquet; data/test-xxx-of-xxx.parquet
# demographics/train-xxx-of-xxx.parquet
ds = MsDataset.load(
'wangxingjun778test/aya_dataset_mini', split='train')
print(f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
ds = MsDataset.load('wangxingjun778/aya_dataset_mini', split='train')
logger.info(
f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
ds = MsDataset.load(
'wangxingjun778test/aya_dataset_mini', subset_name='demographics')
'wangxingjun778/aya_dataset_mini', subset_name='demographics')
assert next(iter(ds['train']))
print(
logger.info(
f">>output of test_inner_aya_dataset_mini:\n {next(iter(ds['train']))}"
)
@@ -68,36 +67,46 @@ class GeneralMsDatasetTest(unittest.TestCase):
'skip test in current test level')
def test_inner_no_standard_imgs(self):
infos = MsDataset.load(
'xxxxtest0004/png_jpg_txt_test', dataset_info_only=True)
'wangxingjun778/png_jpg_txt_test', dataset_info_only=True)
assert infos['default']
ds = MsDataset.load('xxxxtest0004/png_jpg_txt_test', split='train')
print(f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
assert next(iter(ds))
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_hf_pictures(self):
ds = MsDataset.load('xxxxtest0004/hf_Pictures')
print(ds)
ds = MsDataset.load('wangxingjun778/png_jpg_txt_test', split='train')
logger.info(
f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
assert next(iter(ds))
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
def test_inner_speech_yinpin(self):
ds = MsDataset.load('xxxxtest0004/hf_lj_speech_yinpin_test')
print(ds)
ds = MsDataset.load('wangxingjun778/hf_lj_speech_yinpin_test')
logger.info(ds)
assert next(iter(ds))
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_inner_yuancheng_picture(self):
ds = MsDataset.load(
'xxxxtest0004/yuancheng_picture',
'wangxingjun778/yuancheng_picture',
subset_name='remote_images',
split='train')
print(next(iter(ds)))
logger.info(next(iter(ds)))
assert next(iter(ds))
@unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
'skip test in current test level')
def test_youku_mplug_dataset(self):
# Test the Youku-AliceMind dataset with the new SDK version
ds = MsDataset.load(
'modelscope/Youku-AliceMind',
subset_name='classification',
split='validation', # Options: train, test, validation
use_streaming=True)
logger.info(next(iter(ds)))
data_sample = next(iter(ds))
assert data_sample['video_id'][0]
assert os.path.exists(data_sample['video_id:FILE'][0])
if __name__ == '__main__':
unittest.main()

@@ -169,17 +169,6 @@ class MsDatasetTest(unittest.TestCase):
'speech_asr_aishell1_trainsets', namespace='speech_asr')
print(next(iter(ms_ds_asr['train'])))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_img(self):
ms_image_train = MsDataset.load(
'fixtures_image_utils', namespace='damotest', split='test')
pt_dataset = ms_image_train.to_torch_dataset(
preprocessors=ImgPreprocessor(image_path='file'))
import torch
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
print(next(iter(dataloader)))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@require_tf
def test_to_tf_dataset_img(self):
@@ -229,7 +218,7 @@ class MsDatasetTest(unittest.TestCase):
print(data_example)
assert data_example.values()
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
def test_streaming_load_img_object(self):
"""Test case for iterating PIL object."""
from PIL.PngImagePlugin import PngImageFile
@@ -238,7 +227,7 @@ class MsDatasetTest(unittest.TestCase):
subset_name='default',
namespace='huizheng',
split='train',
use_streaming=True)
use_streaming=False)
data_example = next(iter(dataset))
print(data_example)
assert data_example.values()
@@ -247,7 +236,8 @@ class MsDatasetTest(unittest.TestCase):
def test_to_ms_dataset(self):
"""Test case for converting huggingface dataset to `MsDataset` instance."""
from datasets.load import load_dataset
hf_dataset = load_dataset('beans', split='train', streaming=True)
hf_dataset = load_dataset(
'AI-Lab-Makerere/beans', split='train', streaming=True)
ms_dataset = MsDataset.to_ms_dataset(hf_dataset)
data_example = next(iter(ms_dataset))
print(data_example)

@@ -17,7 +17,7 @@ class Human3DAnimationTest(unittest.TestCase):
human3d = pipeline(self.task, model=self.model_id)
input = {
'dataset_id': 'damo/3DHuman_synthetic_dataset',
'case_id': '3f2a7538253e42a8',
'case_id': '000146', # 3f2a7538253e42a8
'action_dataset': 'damo/3DHuman_action_dataset',
'action': 'SwingDancing',
'save_dir': 'outputs',