Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
Upgrade datasets (#921)
* del _datasets_server import in hf_dataset_util
* fix streaming for youku-mplug and adopt latest datasets
* fix download config copy
* update ut
* add youku in test_general_datasets
* update UT for general dataset
* adapt to datasets version: 2.19.0 or later
* add assert for youku data UT
* fix disable_tqdm in some functions for 2.19.0 or later
* update get_module_with_script
* set trust_remote_code to True in load_dataset_with_ctx
* update print info
* update requirements for datasets version restriction
* fix _dataset_info
* add pillow
* update comments
* update comment
* reuse _download function in DataDownloadManager
* remove unused code
* update test_run_modelhub in Human3DAnimationTest
* set datasets>=2.18.0
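
For orientation before the hunks: a minimal sketch of the streaming path this commit fixes, reusing the repo id and arguments from the Youku test added below (the sketch itself is not part of the diff):

from modelscope import MsDataset

# Stream the validation split lazily; ':FILE' columns resolve to locally
# cached files (see the NativeIterableDataset hunk below).
ds = MsDataset.load(
    'modelscope/Youku-AliceMind',
    subset_name='classification',
    split='validation',
    use_streaming=True)
sample = next(iter(ds))
print(sample['video_id'][0])
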
@@ -1,2 +1,2 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from .ms_dataset import MsDataset
+from modelscope.msdatasets.ms_dataset import MsDataset
@@ -149,6 +149,7 @@ class NativeIterableDataset(IterableDataset):
                 if isinstance(ex_cache_path, str):
                     ex_cache_path = [ex_cache_path]
                 ret[k] = ex_cache_path
+                ret[k.strip(':FILE')] = v
 
         except Exception as e:
             logger.error(e)
@@ -330,6 +330,7 @@ class IterableDatasetBuilder(csv.Csv):
 
         super().__init__(
             cache_dir=self.cache_build_dir,
+            dataset_name=self.dataset_name,
             config_name=self.namespace,
             hash=sub_dir_hash,
             data_files=None,  # TODO: self.meta_data_files,
@@ -6,16 +6,18 @@ from datasets.download.download_config import DownloadConfig
 
 
 class DataDownloadConfig(DownloadConfig):
+    """
+    Extends `DownloadConfig` with additional attributes for data download.
+    """
 
-    def __init__(self):
-        self.dataset_name: Optional[str] = None
-        self.namespace: Optional[str] = None
-        self.version: Optional[str] = None
-        self.split: Optional[Union[str, list]] = None
-        self.data_dir: Optional[str] = None
-        self.oss_config: Optional[dict] = {}
-        self.meta_args_map: Optional[dict] = {}
-        self.num_proc: int = 4
+    dataset_name: Optional[str] = None
+    namespace: Optional[str] = None
+    version: Optional[str] = None
+    split: Optional[Union[str, list]] = None
+    data_dir: Optional[str] = None
+    oss_config: Optional[dict] = {}
+    meta_args_map: Optional[dict] = {}
+    num_proc: int = 4
 
     def copy(self) -> 'DataDownloadConfig':
         return self
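
A note on the copy() override retained above: unlike the base DownloadConfig.copy(), which returns a fresh object, this version hands back the same instance, so whoever mutates the "copy" mutates the shared config (relevant to the "fix download config copy" item in the commit message). A small illustration; the import path is an assumption based on the repo layout, and the values are hypothetical:

# Sketch; import path and values are assumptions, not part of the diff.
from modelscope.msdatasets.download.download_config import DataDownloadConfig

cfg = DataDownloadConfig()
cfg.dataset_name = 'demo_ds'    # hypothetical value
cfg_copy = cfg.copy()

assert cfg_copy is cfg          # copy() returns the same object
cfg_copy.num_proc = 8           # ...so this change is visible through cfg too
assert cfg.num_proc == 8
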
@@ -36,6 +36,11 @@ class DataDownloadManager(DownloadManager):
         return cached_path(
             url_or_filename, download_config=download_config)
 
+    def _download_single(self, url_or_filename: str,
+                         download_config: DataDownloadConfig) -> str:
+        # Note: _download_single function is available for datasets>=2.19.0
+        return self._download(url_or_filename, download_config)
+
 
 class DataStreamingDownloadManager(StreamingDownloadManager):
     """The data streaming download manager."""
@@ -62,3 +67,7 @@ class DataStreamingDownloadManager(StreamingDownloadManager):
         else:
             return cached_path(
                 url_or_filename, download_config=self.download_config)
+
+    def _download_single(self, url_or_filename: str) -> str:
+        # Note: _download_single function is available for datasets>=2.19.0
+        return self._download(url_or_filename)
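
Background for the two overrides above: datasets 2.19.0 renamed the private DownloadManager._download hook to _download_single, so defining both lets the ModelScope managers work on either side of the rename. A sketch of the dispatch idiom the later hunks rely on, using the names as they appear in this commit:

from datasets import DownloadManager

# Resolve whichever private hook the installed datasets version defines:
# _download for < 2.19.0, _download_single for >= 2.19.0.
if hasattr(DownloadManager, '_download'):
    download_hook = DownloadManager._download
else:
    download_hook = DownloadManager._download_single
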
@@ -268,17 +268,15 @@ class MsDataset:
             return dataset_inst
         # Load from the huggingface hub
        elif hub == Hubs.huggingface:
-            dataset_inst = RemoteDataLoaderManager(
-                dataset_context_config).load_dataset(
-                    RemoteDataLoaderType.HF_DATA_LOADER)
-            dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
-            if isinstance(dataset_inst, MsDataset):
-                dataset_inst._dataset_context_config = dataset_context_config
-                if custom_cfg:
-                    dataset_inst.to_custom_dataset(
-                        custom_cfg=custom_cfg, **config_kwargs)
-                    dataset_inst.is_custom = True
-            return dataset_inst
+            from datasets import load_dataset
+            return load_dataset(
+                dataset_name,
+                name=subset_name,
+                split=split,
+                streaming=use_streaming,
+                download_mode=download_mode.value,
+                **config_kwargs)
+
         # Load from the modelscope hub
         elif hub == Hubs.modelscope:
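
With the hunk above, the Hugging Face branch of MsDataset.load becomes a thin passthrough to datasets.load_dataset and returns its result directly instead of wrapping it in an MsDataset. A hedged caller-side sketch; the dataset id is illustrative only, and Hubs is assumed importable from modelscope.utils.constant:

from modelscope import MsDataset
from modelscope.utils.constant import Hubs

# Now returns a datasets.Dataset / IterableDataset directly.
ds = MsDataset.load(
    'squad',                  # hypothetical Hugging Face dataset id
    split='train',
    hub=Hubs.huggingface)
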
@@ -305,6 +303,7 @@ class MsDataset:
                 token=token,
                 streaming=use_streaming,
                 dataset_info_only=dataset_info_only,
+                trust_remote_code=True,
                 **config_kwargs) as dataset_res:
 
             return dataset_res
@@ -7,7 +7,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple
+from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal
 
 from urllib.parse import urlencode
 
@@ -40,7 +40,7 @@ from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
                                        _MODULE_SUPPORTS_METADATA,
                                        _MODULE_TO_EXTENSIONS,
                                        _PACKAGED_DATASETS_MODULES)
-from datasets.utils import _datasets_server, file_utils
+from datasets.utils import file_utils
 from datasets.utils.file_utils import (OfflineModeIsEnabled,
                                        _raise_if_offline_mode_is_enabled,
                                        cached_path, is_local_path,
@@ -68,6 +68,26 @@ from modelscope.utils.logger import get_logger
 logger = get_logger()
 
 
+ExpandDatasetProperty_T = Literal[
+    'author',
+    'cardData',
+    'citation',
+    'createdAt',
+    'disabled',
+    'description',
+    'downloads',
+    'downloadsAllTime',
+    'gated',
+    'lastModified',
+    'likes',
+    'paperswithcode_id',
+    'private',
+    'siblings',
+    'sha',
+    'tags',
+]
+
+
 def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
     url_or_filename = str(url_or_filename)
     # for temp val
@@ -97,6 +117,7 @@ def _dataset_info(
     timeout: Optional[float] = None,
     files_metadata: bool = False,
     token: Optional[Union[bool, str]] = None,
+    expand: Optional[List[ExpandDatasetProperty_T]] = None,
 ) -> HfDatasetInfo:
     """
     Get info on one specific dataset on huggingface.co.
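
The new expand parameter mirrors the option of the same name on huggingface_hub's HfApi.dataset_info: callers can request only specific metadata properties, restricted to the names enumerated in the ExpandDatasetProperty_T literal added earlier. A sketch of a call through the patched API; the repo id is hypothetical, and a huggingface_hub version that accepts expand is assumed:

from huggingface_hub import HfApi

api = HfApi()
# Request only selected metadata fields (names from ExpandDatasetProperty_T).
info = api.dataset_info(
    'namespace/dataset_name',                      # hypothetical repo id
    expand=['downloads', 'lastModified', 'tags'])
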
@@ -728,19 +749,6 @@ def _download_additional_modules(
 
 
 def get_module_with_script(self) -> DatasetModule:
-    if config.HF_DATASETS_TRUST_REMOTE_CODE and self.trust_remote_code is None:
-        warnings.warn(
-            f'The repository for {self.name} contains custom code which must be executed to correctly '
-            f'load the dataset. You can inspect the repository content at https://hf.co/datasets/{self.name}\n'
-            f'You can avoid this message in future by passing the argument `trust_remote_code=True`.\n'
-            f'Passing `trust_remote_code=True` will be mandatory '
-            f'to load this dataset from the next major release of `datasets`.',
-            FutureWarning,
-        )
-    # get script and other files
-    # local_path = self.download_loading_script()
-    # dataset_infos_path = self.download_dataset_infos_file()
-    # dataset_readme_path = self.download_dataset_readme_file()
 
     _api = HubApi()
     _dataset_name: str = self.name.split('/')[-1]
@@ -1260,8 +1268,9 @@ class DatasetsWrapperHF:
                     path,
                     download_config=download_config,
                     revision=dataset_info.sha).get_module()
-            except _datasets_server.DatasetsServerError:
-                pass
+            except Exception as e:
+                logger.error(e)
+
         # Otherwise we must use the dataset script if the user trusts it
         return HubDatasetModuleFactoryWithScript(
             path,
@@ -1314,7 +1323,11 @@ class DatasetsWrapperHF:
 def load_dataset_with_ctx(*args, **kwargs):
     hf_endpoint_origin = config.HF_ENDPOINT
     get_from_cache_origin = file_utils.get_from_cache
-    _download_origin = DownloadManager._download
+
+    # Compatible with datasets 2.18.0
+    _download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
+        else DownloadManager._download_single
+
     dataset_info_origin = HfApi.dataset_info
     list_repo_tree_origin = HfApi.list_repo_tree
     get_paths_info_origin = HfApi.get_paths_info
@@ -1324,7 +1337,13 @@ def load_dataset_with_ctx(*args, **kwargs):
 
         config.HF_ENDPOINT = get_endpoint()
         file_utils.get_from_cache = get_from_cache_ms
-        DownloadManager._download = _download_ms
+
+        # Compatible with datasets 2.18.0
+        if hasattr(DownloadManager, '_download'):
+            DownloadManager._download = _download_ms
+        else:
+            DownloadManager._download_single = _download_ms
+
         HfApi.dataset_info = _dataset_info
         HfApi.list_repo_tree = _list_repo_tree
         HfApi.get_paths_info = _get_paths_info
@@ -1338,12 +1357,16 @@ def load_dataset_with_ctx(*args, **kwargs):
     finally:
         config.HF_ENDPOINT = hf_endpoint_origin
         file_utils.get_from_cache = get_from_cache_origin
-        DownloadManager._download = _download_origin
+
+        # Compatible with datasets 2.18.0
+        if hasattr(DownloadManager, '_download'):
+            DownloadManager._download = _download_origin
+        else:
+            DownloadManager._download_single = _download_origin
+
         HfApi.dataset_info = dataset_info_origin
         HfApi.list_repo_tree = list_repo_tree_origin
         HfApi.get_paths_info = get_paths_info_origin
         data_files.resolve_pattern = resolve_pattern_origin
         HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
         HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin
-
-        logger.info('Context manager of ms-dataset exited.')
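
The three hunks above follow a save / patch / restore pattern: originals are captured before the try block, the ModelScope replacements are installed inside it, and the finally block reinstates every original even if loading fails. The same idiom condensed into a generic helper; this is a sketch, not code from this commit:

from contextlib import contextmanager

@contextmanager
def temporarily_patched(obj, attr, replacement):
    """Swap obj.attr for replacement, restoring the original on exit."""
    original = getattr(obj, attr)
    setattr(obj, attr, replacement)
    try:
        yield
    finally:
        setattr(obj, attr, original)
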
@@ -7,6 +7,7 @@ import os
 import re
 import shutil
 import warnings
+import inspect
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
@@ -41,6 +42,7 @@ def get_from_cache_ms(
     ignore_url_params=False,
     storage_options=None,
     download_desc=None,
+    disable_tqdm=False,
 ) -> str:
     """
     Given a URL, look for the corresponding file in the local cache.
@@ -209,7 +211,31 @@ def get_from_cache_ms(
             if scheme == 'ftp':
                 ftp_get(url, temp_file)
             elif scheme not in ('http', 'https'):
+                fsspec_get_sig = inspect.signature(fsspec_get)
+                if 'disable_tqdm' in fsspec_get_sig.parameters:
+                    fsspec_get(url,
+                               temp_file,
+                               storage_options=storage_options,
+                               desc=download_desc,
+                               disable_tqdm=disable_tqdm
+                               )
+                else:
                     fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
             else:
+                http_get_sig = inspect.signature(http_get)
+
+                if 'disable_tqdm' in http_get_sig.parameters:
+                    http_get(
+                        url,
+                        temp_file=temp_file,
+                        proxies=proxies,
+                        resume_size=resume_size,
+                        headers=headers,
+                        cookies=cookies,
+                        max_retries=max_retries,
+                        desc=download_desc,
+                        disable_tqdm=disable_tqdm,
+                    )
+                else:
                     http_get(
                         url,
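
Instead of branching on a version string, the hunk above asks the callee itself whether it accepts disable_tqdm, which also tolerates backports and intermediate releases. The probe in isolation, as a sketch:

import inspect

def accepts_kwarg(fn, name: str) -> bool:
    """Return True if fn's signature declares a parameter called name."""
    return name in inspect.signature(fn).parameters

# e.g. accepts_kwarg(http_get, 'disable_tqdm') decides whether the extra
# keyword is passed through to http_get / fsspec_get.
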
@@ -1,8 +1,9 @@
 addict
 attrs
-datasets>=2.16.0,<2.19.0
+datasets>=2.18.0
 einops
 oss2
+Pillow
 python-dateutil>=2.1
 scipy
 # latest version has some compatible issue.
@@ -1,6 +1,6 @@
 addict
 attrs
-datasets>=2.16.0,<2.19.0
+datasets>=2.18.0
 einops
 oss2
 Pillow
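
Both requirements files drop the <2.19.0 ceiling and keep only the 2.18.0 floor, with Pillow now an explicit dependency. A hedged runtime equivalent of the new constraint, assuming the packaging library is available:

from importlib.metadata import version
from packaging.version import Version

# Mirrors the requirements change: datasets>=2.18.0, no upper bound.
assert Version(version('datasets')) >= Version('2.18.0')
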
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import unittest
 
 from modelscope import MsDataset
@@ -7,9 +8,6 @@ from modelscope.utils.test_utils import test_level
 
 logger = get_logger()
 
-# Note: MODELSCOPE_DOMAIN is set to 'test.modelscope.cn' in the environment variable
-# TODO: ONLY FOR TEST ENVIRONMENT, to be replaced by the online domain
-
 TEST_INNER_LEVEL = 1
 
 
@@ -19,32 +17,33 @@ class GeneralMsDatasetTest(unittest.TestCase):
                          'skip test in current test level')
     def test_return_dataset_info_only(self):
         ds = MsDataset.load(
-            'wangxingjun778test/aya_dataset_mini', dataset_info_only=True)
-        print(f'>>output of test_return_dataset_info_only:\n {ds}')
+            'wangxingjun778/aya_dataset_mini', dataset_info_only=True)
+        logger.info(f'>>output of test_return_dataset_info_only:\n {ds}')
 
     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_fashion_mnist(self):
         # inner means the dataset is on the test.modelscope.cn environment
         ds = MsDataset.load(
-            'xxxxtest0004/ms_test_0308_py',
+            'wangxingjun778/ms_test_0308_py',
             subset_name='fashion_mnist',
             split='train')
-        print(f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
+        logger.info(
+            f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
 
     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_clue(self):
         ds = MsDataset.load(
-            'wangxingjun778test/clue', subset_name='afqmc', split='train')
-        print(f'>>output of test_inner_clue:\n {next(iter(ds))}')
+            'wangxingjun778/clue', subset_name='afqmc', split='train')
+        logger.info(f'>>output of test_inner_clue:\n {next(iter(ds))}')
 
     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_cats_and_dogs_mini(self):
-        ds = MsDataset.load(
-            'wangxingjun778test/cats_and_dogs_mini', split='train')
-        print(f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
+        ds = MsDataset.load('wangxingjun778/cats_and_dogs_mini', split='train')
+        logger.info(
+            f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
 
     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
@@ -53,14 +52,14 @@ class GeneralMsDatasetTest(unittest.TestCase):
         # data/train-xxx-of-xxx.parquet; data/test-xxx-of-xxx.parquet
         # demographics/train-xxx-of-xxx.parquet
 
-        ds = MsDataset.load(
-            'wangxingjun778test/aya_dataset_mini', split='train')
-        print(f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
+        ds = MsDataset.load('wangxingjun778/aya_dataset_mini', split='train')
+        logger.info(
+            f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
 
         ds = MsDataset.load(
-            'wangxingjun778test/aya_dataset_mini', subset_name='demographics')
+            'wangxingjun778/aya_dataset_mini', subset_name='demographics')
         assert next(iter(ds['train']))
-        print(
+        logger.info(
             f">>output of test_inner_aya_dataset_mini:\n {next(iter(ds['train']))}"
         )
 
@@ -68,36 +67,46 @@ class GeneralMsDatasetTest(unittest.TestCase):
                          'skip test in current test level')
     def test_inner_no_standard_imgs(self):
         infos = MsDataset.load(
-            'xxxxtest0004/png_jpg_txt_test', dataset_info_only=True)
+            'wangxingjun778/png_jpg_txt_test', dataset_info_only=True)
         assert infos['default']
 
-        ds = MsDataset.load('xxxxtest0004/png_jpg_txt_test', split='train')
-        print(f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
-        assert next(iter(ds))
-
-    @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
-                         'skip test in current test level')
-    def test_inner_hf_pictures(self):
-        ds = MsDataset.load('xxxxtest0004/hf_Pictures')
-        print(ds)
+        ds = MsDataset.load('wangxingjun778/png_jpg_txt_test', split='train')
+        logger.info(
+            f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
         assert next(iter(ds))
 
     @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
     def test_inner_speech_yinpin(self):
-        ds = MsDataset.load('xxxxtest0004/hf_lj_speech_yinpin_test')
-        print(ds)
+        ds = MsDataset.load('wangxingjun778/hf_lj_speech_yinpin_test')
+        logger.info(ds)
         assert next(iter(ds))
 
     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_yuancheng_picture(self):
         ds = MsDataset.load(
-            'xxxxtest0004/yuancheng_picture',
+            'wangxingjun778/yuancheng_picture',
             subset_name='remote_images',
             split='train')
-        print(next(iter(ds)))
+        logger.info(next(iter(ds)))
         assert next(iter(ds))
 
+    @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+                         'skip test in current test level')
+    def test_youku_mplug_dataset(self):
+        # To test the Youku-AliceMind dataset with new sdk version
+        ds = MsDataset.load(
+            'modelscope/Youku-AliceMind',
+            subset_name='classification',
+            split='validation',  # Options: train, test, validation
+            use_streaming=True)
+
+        logger.info(next(iter(ds)))
+        data_sample = next(iter(ds))
+
+        assert data_sample['video_id'][0]
+        assert os.path.exists(data_sample['video_id:FILE'][0])
+
+
 if __name__ == '__main__':
     unittest.main()
@@ -169,17 +169,6 @@ class MsDatasetTest(unittest.TestCase):
             'speech_asr_aishell1_trainsets', namespace='speech_asr')
         print(next(iter(ms_ds_asr['train'])))
 
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    @require_torch
-    def test_to_torch_dataset_img(self):
-        ms_image_train = MsDataset.load(
-            'fixtures_image_utils', namespace='damotest', split='test')
-        pt_dataset = ms_image_train.to_torch_dataset(
-            preprocessors=ImgPreprocessor(image_path='file'))
-        import torch
-        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
-        print(next(iter(dataloader)))
-
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     @require_tf
     def test_to_tf_dataset_img(self):
@@ -229,7 +218,7 @@ class MsDatasetTest(unittest.TestCase):
         print(data_example)
         assert data_example.values()
 
-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
     def test_streaming_load_img_object(self):
         """Test case for iterating PIL object."""
         from PIL.PngImagePlugin import PngImageFile
@@ -238,7 +227,7 @@ class MsDatasetTest(unittest.TestCase):
             subset_name='default',
             namespace='huizheng',
             split='train',
-            use_streaming=True)
+            use_streaming=False)
         data_example = next(iter(dataset))
         print(data_example)
         assert data_example.values()
@@ -247,7 +236,8 @@ class MsDatasetTest(unittest.TestCase):
     def test_to_ms_dataset(self):
         """Test case for converting huggingface dataset to `MsDataset` instance."""
         from datasets.load import load_dataset
-        hf_dataset = load_dataset('beans', split='train', streaming=True)
+        hf_dataset = load_dataset(
+            'AI-Lab-Makerere/beans', split='train', streaming=True)
         ms_dataset = MsDataset.to_ms_dataset(hf_dataset)
         data_example = next(iter(ms_dataset))
         print(data_example)
@@ -17,7 +17,7 @@ class Human3DAnimationTest(unittest.TestCase):
         human3d = pipeline(self.task, model=self.model_id)
         input = {
             'dataset_id': 'damo/3DHuman_synthetic_dataset',
-            'case_id': '3f2a7538253e42a8',
+            'case_id': '000146',  # 3f2a7538253e42a8
             'action_dataset': 'damo/3DHuman_action_dataset',
             'action': 'SwingDancing',
             'save_dir': 'outputs',