mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 20:19:51 +01:00
Update unit tests (ut)
This commit is contained in:
@@ -330,6 +330,7 @@ class IterableDatasetBuilder(csv.Csv):
|
||||
|
||||
super().__init__(
|
||||
cache_dir=self.cache_build_dir,
|
||||
dataset_name=self.dataset_name,
|
||||
config_name=self.namespace,
|
||||
hash=sub_dir_hash,
|
||||
data_files=None, # TODO: self.meta_data_files,
|
||||
|
||||
@@ -268,17 +268,15 @@ class MsDataset:
|
||||
return dataset_inst
|
||||
# Load from the huggingface hub
|
||||
elif hub == Hubs.huggingface:
|
||||
dataset_inst = RemoteDataLoaderManager(
|
||||
dataset_context_config).load_dataset(
|
||||
RemoteDataLoaderType.HF_DATA_LOADER)
|
||||
dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
|
||||
if isinstance(dataset_inst, MsDataset):
|
||||
dataset_inst._dataset_context_config = dataset_context_config
|
||||
if custom_cfg:
|
||||
dataset_inst.to_custom_dataset(
|
||||
custom_cfg=custom_cfg, **config_kwargs)
|
||||
dataset_inst.is_custom = True
|
||||
return dataset_inst
|
||||
from datasets import load_dataset
|
||||
return load_dataset(
|
||||
dataset_name,
|
||||
name=subset_name,
|
||||
split=split,
|
||||
streaming=use_streaming,
|
||||
download_mode=download_mode.value,
|
||||
**config_kwargs)
|
||||
|
||||
# Load from the modelscope hub
|
||||
elif hub == Hubs.modelscope:
|
||||
|
||||
|
||||
@@ -169,17 +169,6 @@ class MsDatasetTest(unittest.TestCase):
|
||||
'speech_asr_aishell1_trainsets', namespace='speech_asr')
|
||||
print(next(iter(ms_ds_asr['train'])))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@require_torch
|
||||
def test_to_torch_dataset_img(self):
|
||||
ms_image_train = MsDataset.load(
|
||||
'fixtures_image_utils', namespace='damotest', split='test')
|
||||
pt_dataset = ms_image_train.to_torch_dataset(
|
||||
preprocessors=ImgPreprocessor(image_path='file'))
|
||||
import torch
|
||||
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
|
||||
print(next(iter(dataloader)))
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@require_tf
|
||||
def test_to_tf_dataset_img(self):
|
||||
@@ -229,7 +218,7 @@ class MsDatasetTest(unittest.TestCase):
|
||||
print(data_example)
|
||||
assert data_example.values()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
|
||||
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
|
||||
def test_streaming_load_img_object(self):
|
||||
"""Test case for iterating PIL object."""
|
||||
from PIL.PngImagePlugin import PngImageFile
|
||||
@@ -238,7 +227,7 @@ class MsDatasetTest(unittest.TestCase):
|
||||
subset_name='default',
|
||||
namespace='huizheng',
|
||||
split='train',
|
||||
use_streaming=True)
|
||||
use_streaming=False)
|
||||
data_example = next(iter(dataset))
|
||||
print(data_example)
|
||||
assert data_example.values()
|
||||
@@ -247,7 +236,8 @@ class MsDatasetTest(unittest.TestCase):
|
||||
def test_to_ms_dataset(self):
|
||||
"""Test case for converting huggingface dataset to `MsDataset` instance."""
|
||||
from datasets.load import load_dataset
|
||||
hf_dataset = load_dataset('beans', split='train', streaming=True)
|
||||
hf_dataset = load_dataset(
|
||||
'AI-Lab-Makerere/beans', split='train', streaming=True)
|
||||
ms_dataset = MsDataset.to_ms_dataset(hf_dataset)
|
||||
data_example = next(iter(ms_dataset))
|
||||
print(data_example)
|
||||
|
||||
Reference in New Issue
Block a user