Update UT (unit tests)

This commit is contained in:
xingjun.wang
2024-07-21 02:52:37 +08:00
parent db6add8e1c
commit f8285e1fd7
3 changed files with 14 additions and 25 deletions

View File

@@ -330,6 +330,7 @@ class IterableDatasetBuilder(csv.Csv):
super().__init__(
cache_dir=self.cache_build_dir,
dataset_name=self.dataset_name,
config_name=self.namespace,
hash=sub_dir_hash,
data_files=None, # TODO: self.meta_data_files,

View File

@@ -268,17 +268,15 @@ class MsDataset:
return dataset_inst
# Load from the huggingface hub
elif hub == Hubs.huggingface:
dataset_inst = RemoteDataLoaderManager(
dataset_context_config).load_dataset(
RemoteDataLoaderType.HF_DATA_LOADER)
dataset_inst = MsDataset.to_ms_dataset(dataset_inst, target=target)
if isinstance(dataset_inst, MsDataset):
dataset_inst._dataset_context_config = dataset_context_config
if custom_cfg:
dataset_inst.to_custom_dataset(
custom_cfg=custom_cfg, **config_kwargs)
dataset_inst.is_custom = True
return dataset_inst
from datasets import load_dataset
return load_dataset(
dataset_name,
name=subset_name,
split=split,
streaming=use_streaming,
download_mode=download_mode.value,
**config_kwargs)
# Load from the modelscope hub
elif hub == Hubs.modelscope:

View File

@@ -169,17 +169,6 @@ class MsDatasetTest(unittest.TestCase):
'speech_asr_aishell1_trainsets', namespace='speech_asr')
print(next(iter(ms_ds_asr['train'])))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@require_torch
def test_to_torch_dataset_img(self):
ms_image_train = MsDataset.load(
'fixtures_image_utils', namespace='damotest', split='test')
pt_dataset = ms_image_train.to_torch_dataset(
preprocessors=ImgPreprocessor(image_path='file'))
import torch
dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
print(next(iter(dataloader)))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@require_tf
def test_to_tf_dataset_img(self):
@@ -229,7 +218,7 @@ class MsDatasetTest(unittest.TestCase):
print(data_example)
assert data_example.values()
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
@unittest.skipUnless(test_level() >= 3, 'skip test in current test level')
def test_streaming_load_img_object(self):
"""Test case for iterating PIL object."""
from PIL.PngImagePlugin import PngImageFile
@@ -238,7 +227,7 @@ class MsDatasetTest(unittest.TestCase):
subset_name='default',
namespace='huizheng',
split='train',
use_streaming=True)
use_streaming=False)
data_example = next(iter(dataset))
print(data_example)
assert data_example.values()
@@ -247,7 +236,8 @@ class MsDatasetTest(unittest.TestCase):
def test_to_ms_dataset(self):
"""Test case for converting huggingface dataset to `MsDataset` instance."""
from datasets.load import load_dataset
hf_dataset = load_dataset('beans', split='train', streaming=True)
hf_dataset = load_dataset(
'AI-Lab-Makerere/beans', split='train', streaming=True)
ms_dataset = MsDataset.to_ms_dataset(hf_dataset)
data_example = next(iter(ms_dataset))
print(data_example)