Mirror of https://github.com/modelscope/modelscope.git, synced 2025-12-16 00:07:42 +01:00
Upgrade datasets (#921)
* del _datasets_server import in hf_dataset_util
* fix streaming for youku-mplug and adapt to the latest datasets
* fix download config copy
* update ut
* add youku in test_general_datasets
* update UT for general dataset
* adapt to datasets version: 2.19.0 or later
* add assert for youku data UT
* fix disable_tqdm in some functions for 2.19.0 or later
* update get_module_with_script
* set trust_remote_code to True in load_dataset_with_ctx
* update print info
* update requirements for datasets version restriction
* fix _dataset_info
* add pillow
* update comments
* update comment
* reuse _download function in DataDownloadManager
* remove unused code
* update test_run_modelhub in Human3DAnimationTest
* set datasets>=2.18.0
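The user-visible entry point touched by most of these hunks is `load_dataset_with_ctx`, which the `MsDataset.load` hunk below uses as a context manager with `trust_remote_code=True`. A minimal usage sketch; the import path is assumed from the commit message's `hf_dataset_util` wording and the dataset id is the one the new unit tests use:

```python
# Sketch only: module path and accepted kwargs are assumptions inferred
# from the hunks below, not a documented API surface.
from modelscope.msdatasets.utils.hf_datasets_util import load_dataset_with_ctx

# Inside the context, HF endpoints and download hooks are redirected to
# ModelScope; on exit they are restored (see the finally-block hunk below).
with load_dataset_with_ctx(
        'wangxingjun778/aya_dataset_mini',
        split='train',
        streaming=False,
        trust_remote_code=True) as dataset_res:
    print(next(iter(dataset_res)))
```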
@@ -1,2 +1,2 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
-from .ms_dataset import MsDataset
+from modelscope.msdatasets.ms_dataset import MsDataset
@@ -149,6 +149,7 @@ class NativeIterableDataset(IterableDataset):
                 if isinstance(ex_cache_path, str):
                     ex_cache_path = [ex_cache_path]
                 ret[k] = ex_cache_path
                 ret[k.strip(':FILE')] = v

         except Exception as e:
             logger.error(e)
@@ -330,6 +330,7 @@ class IterableDatasetBuilder(csv.Csv):

         super().__init__(
             cache_dir=self.cache_build_dir,
             dataset_name=self.dataset_name,
             config_name=self.namespace,
             hash=sub_dir_hash,
             data_files=None,  # TODO: self.meta_data_files,
@@ -6,16 +6,18 @@ from datasets.download.download_config import DownloadConfig


 class DataDownloadConfig(DownloadConfig):
-
-    def __init__(self):
-        self.dataset_name: Optional[str] = None
-        self.namespace: Optional[str] = None
-        self.version: Optional[str] = None
-        self.split: Optional[Union[str, list]] = None
-        self.data_dir: Optional[str] = None
-        self.oss_config: Optional[dict] = {}
-        self.meta_args_map: Optional[dict] = {}
-        self.num_proc: int = 4
+    """
+    Extends `DownloadConfig` with additional attributes for data download.
+    """
+    dataset_name: Optional[str] = None
+    namespace: Optional[str] = None
+    version: Optional[str] = None
+    split: Optional[Union[str, list]] = None
+    data_dir: Optional[str] = None
+    oss_config: Optional[dict] = {}
+    meta_args_map: Optional[dict] = {}
+    num_proc: int = 4

     def copy(self) -> 'DataDownloadConfig':
         return self
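Moving the defaults to class-level attributes lets every new `DataDownloadConfig` start from the same values while still allowing per-instance overrides; note that `copy()` returns `self`, so a "copy" shares state with the original. A minimal sketch of populating such a config (import path assumed from the repo layout; values are illustrative):

```python
# Sketch only: the module path below is an assumption, and the field
# values are made up for illustration.
from modelscope.msdatasets.download.download_config import DataDownloadConfig

config = DataDownloadConfig()
config.dataset_name = 'aya_dataset_mini'
config.namespace = 'wangxingjun778'
config.version = 'master'
config.split = ['train']
config.num_proc = 4

# Because copy() returns self, both names point at the same object:
same = config.copy()
assert same is config
```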
@@ -36,6 +36,11 @@ class DataDownloadManager(DownloadManager):
         return cached_path(
             url_or_filename, download_config=download_config)

+    def _download_single(self, url_or_filename: str,
+                         download_config: DataDownloadConfig) -> str:
+        # Note: _download_single function is available for datasets>=2.19.0
+        return self._download(url_or_filename, download_config)


 class DataStreamingDownloadManager(StreamingDownloadManager):
     """The data streaming download manager."""
@@ -62,3 +67,7 @@ class DataStreamingDownloadManager(StreamingDownloadManager):
         else:
             return cached_path(
                 url_or_filename, download_config=self.download_config)

+    def _download_single(self, url_or_filename: str) -> str:
+        # Note: _download_single function is available for datasets>=2.19.0
+        return self._download(url_or_filename)
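datasets 2.19.0 renamed the download managers' `_download` hook to `_download_single`; the two hunks above keep both spellings alive by forwarding the new name to the existing implementation. A self-contained sketch of that forwarding pattern (the class here is a stand-in, not a modelscope class):

```python
# Minimal sketch of the compatibility pattern above. `_Downloader` is a
# hypothetical stand-in; only the forwarding idea is taken from the diff.
class _Downloader:

    def _download(self, url_or_filename: str) -> str:
        # The real download logic lives here (stubbed for illustration).
        return f"/cache/{url_or_filename.split('/')[-1]}"

    def _download_single(self, url_or_filename: str) -> str:
        # datasets>=2.19.0 calls this name; delegate to the old one.
        return self._download(url_or_filename)


if __name__ == '__main__':
    dl = _Downloader()
    # Both entry points resolve to the same implementation.
    assert dl._download('http://x/a.csv') == dl._download_single('http://x/a.csv')
```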
@@ -268,17 +268,15 @@ class MsDataset:
             return dataset_inst
         # Load from the huggingface hub
         elif hub == Hubs.huggingface:
-            dataset_inst = RemoteDataLoaderManager(
-                dataset_context_config).load_dataset(
-                    RemoteDataLoaderType.HF_DATA_LOADER)
-            dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
-            if isinstance(dataset_inst, MsDataset):
-                dataset_inst._dataset_context_config = dataset_context_config
-            if custom_cfg:
-                dataset_inst.to_custom_dataset(
-                    custom_cfg=custom_cfg, **config_kwargs)
-                dataset_inst.is_custom = True
-            return dataset_inst
+            from datasets import load_dataset
+            return load_dataset(
+                dataset_name,
+                name=subset_name,
+                split=split,
+                streaming=use_streaming,
+                download_mode=download_mode.value,
+                **config_kwargs)

         # Load from the modelscope hub
         elif hub == Hubs.modelscope:
@@ -305,6 +303,7 @@ class MsDataset:
                 token=token,
                 streaming=use_streaming,
                 dataset_info_only=dataset_info_only,
+                trust_remote_code=True,
                 **config_kwargs) as dataset_res:

             return dataset_res
@@ -7,7 +7,7 @@ import os
 import warnings
 from functools import partial
 from pathlib import Path
-from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple
+from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Union, Tuple, Literal

 from urllib.parse import urlencode
@@ -40,7 +40,7 @@ from datasets.packaged_modules import (_EXTENSION_TO_MODULE,
                                        _MODULE_SUPPORTS_METADATA,
                                        _MODULE_TO_EXTENSIONS,
                                        _PACKAGED_DATASETS_MODULES)
-from datasets.utils import _datasets_server, file_utils
+from datasets.utils import file_utils
 from datasets.utils.file_utils import (OfflineModeIsEnabled,
                                        _raise_if_offline_mode_is_enabled,
                                        cached_path, is_local_path,
@@ -68,6 +68,26 @@ from modelscope.utils.logger import get_logger
 logger = get_logger()


+ExpandDatasetProperty_T = Literal[
+    'author',
+    'cardData',
+    'citation',
+    'createdAt',
+    'disabled',
+    'description',
+    'downloads',
+    'downloadsAllTime',
+    'gated',
+    'lastModified',
+    'likes',
+    'paperswithcode_id',
+    'private',
+    'siblings',
+    'sha',
+    'tags',
+]
+
+
 def _download_ms(self, url_or_filename: str, download_config: DownloadConfig) -> str:
     url_or_filename = str(url_or_filename)
     # for temp val
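The `Literal` alias added above gives `_dataset_info`'s new `expand` parameter a closed vocabulary that static type checkers can enforce. A short sketch of the same idiom with a trimmed alias and a hypothetical caller (names below are illustrative, not modelscope APIs):

```python
from typing import List, Literal, Optional

# Same pattern as ExpandDatasetProperty_T above, trimmed to three members
# for brevity; the real alias lists every expandable hub property.
ExpandProp = Literal['author', 'downloads', 'tags']


def dataset_info(repo_id: str, expand: Optional[List[ExpandProp]] = None) -> dict:
    # A type checker rejects expand=['license'] because 'license' is not a
    # member of the Literal; at runtime it is just an ordinary list.
    return {'repo_id': repo_id, 'expand': expand or []}


print(dataset_info('modelscope/Youku-AliceMind', expand=['downloads', 'tags']))
```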
@@ -97,6 +117,7 @@ def _dataset_info(
     timeout: Optional[float] = None,
     files_metadata: bool = False,
     token: Optional[Union[bool, str]] = None,
+    expand: Optional[List[ExpandDatasetProperty_T]] = None,
 ) -> HfDatasetInfo:
     """
     Get info on one specific dataset on huggingface.co.
@@ -728,19 +749,6 @@ def _download_additional_modules(


 def get_module_with_script(self) -> DatasetModule:
-    if config.HF_DATASETS_TRUST_REMOTE_CODE and self.trust_remote_code is None:
-        warnings.warn(
-            f'The repository for {self.name} contains custom code which must be executed to correctly '
-            f'load the dataset. You can inspect the repository content at https://hf.co/datasets/{self.name}\n'
-            f'You can avoid this message in future by passing the argument `trust_remote_code=True`.\n'
-            f'Passing `trust_remote_code=True` will be mandatory '
-            f'to load this dataset from the next major release of `datasets`.',
-            FutureWarning,
-        )
-    # get script and other files
-    # local_path = self.download_loading_script()
-    # dataset_infos_path = self.download_dataset_infos_file()
-    # dataset_readme_path = self.download_dataset_readme_file()

     _api = HubApi()
     _dataset_name: str = self.name.split('/')[-1]
@@ -1260,8 +1268,9 @@ class DatasetsWrapperHF:
                     path,
                     download_config=download_config,
                     revision=dataset_info.sha).get_module()
-            except _datasets_server.DatasetsServerError:
-                pass
+            except Exception as e:
+                logger.error(e)

         # Otherwise we must use the dataset script if the user trusts it
         return HubDatasetModuleFactoryWithScript(
             path,
@@ -1314,7 +1323,11 @@ class DatasetsWrapperHF:
 def load_dataset_with_ctx(*args, **kwargs):
     hf_endpoint_origin = config.HF_ENDPOINT
     get_from_cache_origin = file_utils.get_from_cache
-    _download_origin = DownloadManager._download
+
+    # Compatible with datasets 2.18.0
+    _download_origin = DownloadManager._download if hasattr(DownloadManager, '_download') \
+        else DownloadManager._download_single

     dataset_info_origin = HfApi.dataset_info
     list_repo_tree_origin = HfApi.list_repo_tree
     get_paths_info_origin = HfApi.get_paths_info
@@ -1324,7 +1337,13 @@ def load_dataset_with_ctx(*args, **kwargs):

     config.HF_ENDPOINT = get_endpoint()
     file_utils.get_from_cache = get_from_cache_ms
-    DownloadManager._download = _download_ms
+
+    # Compatible with datasets 2.18.0
+    if hasattr(DownloadManager, '_download'):
+        DownloadManager._download = _download_ms
+    else:
+        DownloadManager._download_single = _download_ms

     HfApi.dataset_info = _dataset_info
     HfApi.list_repo_tree = _list_repo_tree
     HfApi.get_paths_info = _get_paths_info
@@ -1338,12 +1357,16 @@ def load_dataset_with_ctx(*args, **kwargs):
     finally:
         config.HF_ENDPOINT = hf_endpoint_origin
         file_utils.get_from_cache = get_from_cache_origin
-        DownloadManager._download = _download_origin
+        # Compatible with datasets 2.18.0
+        if hasattr(DownloadManager, '_download'):
+            DownloadManager._download = _download_origin
+        else:
+            DownloadManager._download_single = _download_origin

         HfApi.dataset_info = dataset_info_origin
         HfApi.list_repo_tree = list_repo_tree_origin
         HfApi.get_paths_info = get_paths_info_origin
         data_files.resolve_pattern = resolve_pattern_origin
         HubDatasetModuleFactoryWithoutScript.get_module = get_module_without_script_origin
         HubDatasetModuleFactoryWithScript.get_module = get_module_with_script_origin

         logger.info('Context manager of ms-dataset exited.')
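The three hunks above follow one save/patch/restore discipline: capture the original attribute (probing with `hasattr` because datasets 2.19.0 renamed `_download` to `_download_single`), install the replacement, and restore the original in a `finally` block so the patch cannot leak. A generic sketch of that discipline; the class and functions here are illustrative stand-ins, not modelscope APIs:

```python
from contextlib import contextmanager


class Downloader:
    """Stand-in for DownloadManager; only one of the two method names exists."""

    def _download(self, url: str) -> str:
        return f'original:{url}'


def _download_patched(self, url: str) -> str:
    return f'patched:{url}'


@contextmanager
def patched_downloader(cls, replacement):
    # Probe for the pre-2.19.0 name first, mirroring the hasattr checks above.
    name = '_download' if hasattr(cls, '_download') else '_download_single'
    original = getattr(cls, name)
    setattr(cls, name, replacement)
    try:
        yield cls
    finally:
        # Restore unconditionally, even if the body raised.
        setattr(cls, name, original)


with patched_downloader(Downloader, _download_patched):
    assert Downloader()._download('x') == 'patched:x'
assert Downloader()._download('x') == 'original:x'
```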
@@ -7,6 +7,7 @@ import os
 import re
 import shutil
 import warnings
+import inspect
 from contextlib import contextmanager
 from functools import partial
 from pathlib import Path
@@ -41,6 +42,7 @@ def get_from_cache_ms(
     ignore_url_params=False,
     storage_options=None,
     download_desc=None,
+    disable_tqdm=False,
 ) -> str:
     """
     Given a URL, look for the corresponding file in the local cache.
@@ -209,18 +211,42 @@ def get_from_cache_ms(
         if scheme == 'ftp':
             ftp_get(url, temp_file)
         elif scheme not in ('http', 'https'):
-            fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
+            fsspec_get_sig = inspect.signature(fsspec_get)
+            if 'disable_tqdm' in fsspec_get_sig.parameters:
+                fsspec_get(url,
+                           temp_file,
+                           storage_options=storage_options,
+                           desc=download_desc,
+                           disable_tqdm=disable_tqdm
+                           )
+            else:
+                fsspec_get(url, temp_file, storage_options=storage_options, desc=download_desc)
         else:
-            http_get(
-                url,
-                temp_file=temp_file,
-                proxies=proxies,
-                resume_size=resume_size,
-                headers=headers,
-                cookies=cookies,
-                max_retries=max_retries,
-                desc=download_desc,
-            )
+            http_get_sig = inspect.signature(http_get)
+
+            if 'disable_tqdm' in http_get_sig.parameters:
+                http_get(
+                    url,
+                    temp_file=temp_file,
+                    proxies=proxies,
+                    resume_size=resume_size,
+                    headers=headers,
+                    cookies=cookies,
+                    max_retries=max_retries,
+                    desc=download_desc,
+                    disable_tqdm=disable_tqdm,
+                )
+            else:
+                http_get(
+                    url,
+                    temp_file=temp_file,
+                    proxies=proxies,
+                    resume_size=resume_size,
+                    headers=headers,
+                    cookies=cookies,
+                    max_retries=max_retries,
+                    desc=download_desc,
+                )

     logger.info(f'storing {url} in cache at {cache_path}')
     shutil.move(temp_file.name, cache_path)
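The hunk above probes the callee's signature before passing `disable_tqdm`, so the same wrapper runs against datasets versions whose `http_get`/`fsspec_get` may or may not accept that keyword. A generic sketch of this feature-detection idiom; `download_file` is a hypothetical callee, not a datasets API:

```python
import inspect


def download_file(url, desc=None):
    # Hypothetical old-style callee that predates the disable_tqdm keyword.
    return f'{url} ({desc})'


def call_with_supported_kwargs(func, *args, **kwargs):
    # Drop any keyword the callee's signature does not declare -- the same
    # check get_from_cache_ms applies to disable_tqdm above. (A callee
    # taking **kwargs would need a looser check; kept simple for a sketch.)
    params = inspect.signature(func).parameters
    supported = {k: v for k, v in kwargs.items() if k in params}
    return func(*args, **supported)


# disable_tqdm is silently dropped because download_file lacks it:
print(call_with_supported_kwargs(
    download_file, 'http://example.com/a.csv',
    desc='downloading', disable_tqdm=True))
```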
@@ -1,8 +1,9 @@
 addict
 attrs
-datasets>=2.16.0,<2.19.0
+datasets>=2.18.0
 einops
 oss2
+Pillow
 python-dateutil>=2.1
 scipy
 # latest version has some compatible issue.
@@ -1,6 +1,6 @@
 addict
 attrs
-datasets>=2.16.0,<2.19.0
+datasets>=2.18.0
 einops
 oss2
 Pillow
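Both requirement files now floor `datasets` at 2.18.0, while the code paths above branch further on 2.19.0 behavior. A hedged sketch of verifying the installed version at runtime, using stdlib `importlib.metadata` plus `packaging` (already pulled in transitively by `datasets`; not something this diff adds):

```python
from importlib.metadata import version

from packaging.version import Version

installed = Version(version('datasets'))

# Mirrors the new requirements floor; the hasattr/signature probes in the
# code above then absorb the 2.18 vs >=2.19 API differences at runtime.
assert installed >= Version('2.18.0'), (
    f'datasets {installed} is too old; this codebase expects >=2.18.0')
print(f"datasets {installed}: "
      f"{'>=2.19 API' if installed >= Version('2.19.0') else '2.18 API'}")
```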
@@ -1,4 +1,5 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+import os
 import unittest

 from modelscope import MsDataset
@@ -7,9 +8,6 @@ from modelscope.utils.test_utils import test_level

 logger = get_logger()

-# Note: MODELSCOPE_DOMAIN is set to 'test.modelscope.cn' in the environment variable
-# TODO: ONLY FOR TEST ENVIRONMENT, to be replaced by the online domain
-
 TEST_INNER_LEVEL = 1
@@ -19,32 +17,33 @@ class GeneralMsDatasetTest(unittest.TestCase):
                          'skip test in current test level')
     def test_return_dataset_info_only(self):
         ds = MsDataset.load(
-            'wangxingjun778test/aya_dataset_mini', dataset_info_only=True)
-        print(f'>>output of test_return_dataset_info_only:\n {ds}')
+            'wangxingjun778/aya_dataset_mini', dataset_info_only=True)
+        logger.info(f'>>output of test_return_dataset_info_only:\n {ds}')

     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_fashion_mnist(self):
         # inner means the dataset is on the test.modelscope.cn environment
         ds = MsDataset.load(
-            'xxxxtest0004/ms_test_0308_py',
+            'wangxingjun778/ms_test_0308_py',
             subset_name='fashion_mnist',
             split='train')
-        print(f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')
+        logger.info(
+            f'>>output of test_inner_fashion_mnist:\n {next(iter(ds))}')

     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_clue(self):
         ds = MsDataset.load(
-            'wangxingjun778test/clue', subset_name='afqmc', split='train')
-        print(f'>>output of test_inner_clue:\n {next(iter(ds))}')
+            'wangxingjun778/clue', subset_name='afqmc', split='train')
+        logger.info(f'>>output of test_inner_clue:\n {next(iter(ds))}')

     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_cats_and_dogs_mini(self):
-        ds = MsDataset.load(
-            'wangxingjun778test/cats_and_dogs_mini', split='train')
-        print(f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')
+        ds = MsDataset.load('wangxingjun778/cats_and_dogs_mini', split='train')
+        logger.info(
+            f'>>output of test_inner_cats_and_dogs_mini:\n {next(iter(ds))}')

     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
@@ -53,14 +52,14 @@ class GeneralMsDatasetTest(unittest.TestCase):
         # data/train-xxx-of-xxx.parquet; data/test-xxx-of-xxx.parquet
         # demographics/train-xxx-of-xxx.parquet

-        ds = MsDataset.load(
-            'wangxingjun778test/aya_dataset_mini', split='train')
-        print(f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')
+        ds = MsDataset.load('wangxingjun778/aya_dataset_mini', split='train')
+        logger.info(
+            f'>>output of test_inner_aya_dataset_mini:\n {next(iter(ds))}')

         ds = MsDataset.load(
-            'wangxingjun778test/aya_dataset_mini', subset_name='demographics')
+            'wangxingjun778/aya_dataset_mini', subset_name='demographics')
         assert next(iter(ds['train']))
-        print(
+        logger.info(
             f">>output of test_inner_aya_dataset_mini:\n {next(iter(ds['train']))}"
         )
@@ -68,36 +67,46 @@ class GeneralMsDatasetTest(unittest.TestCase):
                          'skip test in current test level')
     def test_inner_no_standard_imgs(self):
         infos = MsDataset.load(
-            'xxxxtest0004/png_jpg_txt_test', dataset_info_only=True)
+            'wangxingjun778/png_jpg_txt_test', dataset_info_only=True)
         assert infos['default']

-        ds = MsDataset.load('xxxxtest0004/png_jpg_txt_test', split='train')
-        print(f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
-        assert next(iter(ds))
-
-    @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
-                         'skip test in current test level')
-    def test_inner_hf_pictures(self):
-        ds = MsDataset.load('xxxxtest0004/hf_Pictures')
-        print(ds)
+        ds = MsDataset.load('wangxingjun778/png_jpg_txt_test', split='train')
+        logger.info(
+            f'>>>output of test_inner_no_standard_imgs: \n{next(iter(ds))}')
+        assert next(iter(ds))

     @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
     def test_inner_speech_yinpin(self):
-        ds = MsDataset.load('xxxxtest0004/hf_lj_speech_yinpin_test')
-        print(ds)
+        ds = MsDataset.load('wangxingjun778/hf_lj_speech_yinpin_test')
+        logger.info(ds)
+        assert next(iter(ds))

     @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
                          'skip test in current test level')
     def test_inner_yuancheng_picture(self):
         ds = MsDataset.load(
-            'xxxxtest0004/yuancheng_picture',
+            'wangxingjun778/yuancheng_picture',
             subset_name='remote_images',
             split='train')
-        print(next(iter(ds)))
+        logger.info(next(iter(ds)))
         assert next(iter(ds))

+    @unittest.skipUnless(test_level() >= TEST_INNER_LEVEL,
+                         'skip test in current test level')
+    def test_youku_mplug_dataset(self):
+        # To test the Youku-AliceMind dataset with new sdk version
+        ds = MsDataset.load(
+            'modelscope/Youku-AliceMind',
+            subset_name='classification',
+            split='validation',  # Options: train, test, validation
+            use_streaming=True)
+
+        logger.info(next(iter(ds)))
+        data_sample = next(iter(ds))
+
+        assert data_sample['video_id'][0]
+        assert os.path.exists(data_sample['video_id:FILE'][0])


 if __name__ == '__main__':
     unittest.main()
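Streaming loads like the Youku test above return a lazy iterable rather than a materialized dataset, so samples arrive one at a time as they are downloaded. A hedged sketch of taking a bounded peek with `itertools.islice`; the dataset id is reused from the test, and any streaming-capable dataset would do:

```python
from itertools import islice

from modelscope import MsDataset

ds = MsDataset.load(
    'modelscope/Youku-AliceMind',
    subset_name='classification',
    split='validation',
    use_streaming=True)

# islice bounds the walk over the stream: no full download, no len().
for sample in islice(iter(ds), 3):
    print(sorted(sample.keys()))
```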
@@ -169,17 +169,6 @@ class MsDatasetTest(unittest.TestCase):
             'speech_asr_aishell1_trainsets', namespace='speech_asr')
         print(next(iter(ms_ds_asr['train'])))

-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
-    @require_torch
-    def test_to_torch_dataset_img(self):
-        ms_image_train = MsDataset.load(
-            'fixtures_image_utils', namespace='damotest', split='test')
-        pt_dataset = ms_image_train.to_torch_dataset(
-            preprocessors=ImgPreprocessor(image_path='file'))
-        import torch
-        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
-        print(next(iter(dataloader)))
-
     @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
     @require_tf
     def test_to_tf_dataset_img(self):
@@ -229,7 +218,7 @@ class MsDatasetTest(unittest.TestCase):
         print(data_example)
         assert data_example.values()

-    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
+    @unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
     def test_streaming_load_img_object(self):
         """Test case for iterating PIL object."""
         from PIL.PngImagePlugin import PngImageFile
@@ -238,7 +227,7 @@ class MsDatasetTest(unittest.TestCase):
             subset_name='default',
             namespace='huizheng',
             split='train',
-            use_streaming=True)
+            use_streaming=False)
         data_example = next(iter(dataset))
         print(data_example)
         assert data_example.values()
@@ -247,7 +236,8 @@ class MsDatasetTest(unittest.TestCase):
     def test_to_ms_dataset(self):
         """Test case for converting huggingface dataset to `MsDataset` instance."""
         from datasets.load import load_dataset
-        hf_dataset = load_dataset('beans', split='train', streaming=True)
+        hf_dataset = load_dataset(
+            'AI-Lab-Makerere/beans', split='train', streaming=True)
         ms_dataset = MsDataset.to_ms_dataset(hf_dataset)
         data_example = next(iter(ms_dataset))
         print(data_example)
@@ -17,7 +17,7 @@ class Human3DAnimationTest(unittest.TestCase):
         human3d = pipeline(self.task, model=self.model_id)
         input = {
             'dataset_id': 'damo/3DHuman_synthetic_dataset',
-            'case_id': '3f2a7538253e42a8',
+            'case_id': '000146',  # 3f2a7538253e42a8
             'action_dataset': 'damo/3DHuman_action_dataset',
             'action': 'SwingDancing',
             'save_dir': 'outputs',