modelscope/tests/msdatasets/test_ms_dataset.py
xingjun.wxj 43edddd31f [to #42322933] msdataset module refactor and add 1230 features
1. Optimize the local dataset loading pipeline
2. Decouple local and remote loading so the SDK also works without network access
3. Upgrade hf datasets and its related dependencies to the latest version (2.7.0+)
4. Fix metadata not picking up changes to the data files
5. Layered system design
6. Fix local cache management issues
7. Improve error log output
8. Support streaming load (example below)
* a. Support streaming of data files in zip format
* b. Support iterating over Image/Text/Audio/Biodata datasets
* c. Support streaming load of legacy datasets whose training data is stored in the meta
* d. Support streaming load of data files organized as folders

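A minimal sketch of streaming load, mirroring test_streaming_load_coco in the file below (dataset name, split and download mode are taken from that test):

    from modelscope.msdatasets import MsDataset
    from modelscope.utils.constant import DownloadMode

    # Samples are yielded lazily; the dataset is not materialized locally first.
    ds = MsDataset.load(
        'EasyCV/small_coco_for_test',
        split='train',
        use_streaming=True,
        download_mode=DownloadMode.FORCE_REDOWNLOAD)
    print(next(iter(ds)))
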
9. Further standardize the finetune task wiring (sketch below)
* a. Avoid usages like to_hf_dataset by wrapping the commonly used tf-related functions
* b. Remove logic that mixed MsDataset with hf usage; it is now wrapped uniformly inside MsDataset

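A sketch of the intended call pattern (mirroring test_to_torch_dataset_text in the file below): callers stay on MsDataset conversion APIs instead of unwrapping the hf dataset object:

    ds = MsDataset.load(
        'xcopa', subset_name='translation-et', namespace='damotest',
        split='test')
    # `preprocessor` is built as in the text tests below.
    pt_dataset = ds.to_torch_dataset(preprocessors=preprocessor)
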
10. Optimizations for very large datasets
* a. list oss objects: pull the csv mapping from the meta directly, so no list_oss_objects API call is needed (implemented in an earlier commit)
* b. Fix loading after sts token expiration (implemented in an earlier commit)

11. Support passing dataset_name in the namespace/dataset_name format (example below)
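
Assuming the two addressing styles are equivalent, as this item states (the combined form is exercised by test_streaming_load_coco below):

    MsDataset.load('small_coco_for_test', namespace='EasyCV', split='train')
    MsDataset.load('EasyCV/small_coco_for_test', split='train')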

Aone reference link: https://aone.alibaba-inc.com/v2/project/1162242/task/46262894
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11264406
2023-01-10 07:01:34 +08:00


# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from modelscope.models import Model
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors import TextClassificationTransformersPreprocessor
from modelscope.preprocessors.base import Preprocessor
from modelscope.utils.constant import DEFAULT_DATASET_NAMESPACE, DownloadMode
from modelscope.utils.test_utils import require_tf, require_torch, test_level
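

# Small image preprocessor used by the image dataset tests below: it reads
# the configured path field from each sample, loads the image with cv2 and
# resizes it to the sample's width/height fields (falling back to 128).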
class ImgPreprocessor(Preprocessor):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.path_field = kwargs.pop('image_path', 'image_path')
        self.width = kwargs.pop('width', 'width')
        self.height = kwargs.pop('height', 'height')
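
    # Load the image at the configured path field and return it resized;
    # samples without an image path yield None.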
    def __call__(self, data):
        import cv2
        image_path = data.get(self.path_field)
        if not image_path:
            return None
        img = cv2.imread(image_path)
        return {
            'image':
            cv2.resize(img,
                       (data.get(self.height, 128), data.get(self.width, 128)))
        }


class MsDatasetTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_movie_scene_seg_toydata(self):
        ms_ds_train = MsDataset.load('movie_scene_seg_toydata', split='train')
        print(ms_ds_train._hf_ds.config_kwargs)
        assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_coco(self):
        ms_ds_train = MsDataset.load(
            'pets_small',
            namespace=DEFAULT_DATASET_NAMESPACE,
            download_mode=DownloadMode.FORCE_REDOWNLOAD,
            split='train')
        print(ms_ds_train.config_kwargs)
        assert next(iter(ms_ds_train.config_kwargs['split_config'].values()))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ms_csv_basic(self):
        ms_ds_train = MsDataset.load(
            'clue', subset_name='afqmc',
            split='train').to_hf_dataset().select(range(5))
        print(next(iter(ms_ds_train)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_ds_basic(self):
        ms_ds_full = MsDataset.load(
            'xcopa', subset_name='translation-et', namespace='damotest')
        ms_ds = MsDataset.load(
            'xcopa',
            subset_name='translation-et',
            namespace='damotest',
            split='test')
        print(next(iter(ms_ds_full['test'])))
        print(next(iter(ms_ds)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_torch
    def test_to_torch_dataset_text(self):
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
        nlp_model = Model.from_pretrained(model_id)
        preprocessor = TextClassificationTransformersPreprocessor(
            nlp_model.model_dir,
            first_sequence='premise',
            second_sequence=None,
            padding='max_length')
        ms_ds_train = MsDataset.load(
            'xcopa',
            subset_name='translation-et',
            namespace='damotest',
            split='test')
        pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)
        import torch
        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
        print(next(iter(dataloader)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_tf
    def test_to_tf_dataset_text(self):
        import tensorflow as tf
        tf.compat.v1.enable_eager_execution()
        model_id = 'damo/nlp_structbert_sentence-similarity_chinese-tiny'
        nlp_model = Model.from_pretrained(model_id)
        preprocessor = TextClassificationTransformersPreprocessor(
            nlp_model.model_dir,
            first_sequence='premise',
            second_sequence=None)
        ms_ds_train = MsDataset.load(
            'xcopa',
            subset_name='translation-et',
            namespace='damotest',
            split='test')
        tf_dataset = ms_ds_train.to_tf_dataset(
            batch_size=5,
            shuffle=True,
            preprocessors=preprocessor,
            drop_remainder=True)
        print(next(iter(tf_dataset)))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @require_torch
    def test_to_torch_dataset_img(self):
        ms_image_train = MsDataset.load(
            'fixtures_image_utils', namespace='damotest', split='test')
        pt_dataset = ms_image_train.to_torch_dataset(
            preprocessors=ImgPreprocessor(image_path='file'))
        import torch
        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
        print(next(iter(dataloader)))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    @require_tf
    def test_to_tf_dataset_img(self):
        import tensorflow as tf
        tf.compat.v1.enable_eager_execution()
        ms_image_train = MsDataset.load(
            'fixtures_image_utils', namespace='damotest', split='test')
        tf_dataset = ms_image_train.to_tf_dataset(
            batch_size=5,
            shuffle=True,
            preprocessors=ImgPreprocessor(image_path='file'),
            drop_remainder=True,
        )
        print(next(iter(tf_dataset)))
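
    # Streaming load (commit item 8): samples are iterated lazily instead of
    # being fully downloaded and materialized first.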
    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_streaming_load_coco(self):
        small_coco_for_test = MsDataset.load(
            dataset_name='EasyCV/small_coco_for_test',
            split='train',
            use_streaming=True,
            download_mode=DownloadMode.FORCE_REDOWNLOAD)
        dataset_sample_dict = next(iter(small_coco_for_test))
        print(dataset_sample_dict)
        assert dataset_sample_dict.values()


if __name__ == '__main__':
    unittest.main()