[to #42670107] pydataset fetch data from datahub

* pydataset fetch data from datahub
  Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9060856
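
For context, a minimal sketch of the PyDataset flow these tests exercise
(method names and arguments are taken from the diff below; my_preprocessor
is an illustrative placeholder, not part of the change):

    from modelscope.pydatasets import PyDataset

    # Fetch a dataset from the hub by name, optionally picking a split.
    ds = PyDataset.load('squad', split='train')

    # Convert for downstream frameworks, as the tests do.
    pt_ds = ds.to_torch_dataset(preprocessors=my_preprocessor)
    tf_ds = ds.to_tf_dataset(batch_size=5, shuffle=True,
                             preprocessors=my_preprocessor,
                             drop_remainder=True)
    hf_ds = ds.to_hf_dataset()  # back to a datasets.Dataset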
Authored by feiwu.yfw on 2022-06-21 11:10:28 +08:00; committed by yingda.chen
parent 97a0087976
commit c7238a470b
9 changed files with 581 additions and 104 deletions

@@ -2,42 +2,111 @@ import unittest
import datasets as hfdata

from modelscope.models import Model
from modelscope.preprocessors import SequenceClassificationPreprocessor
from modelscope.preprocessors.base import Preprocessor
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import Hubs
from modelscope.utils.test_utils import require_tf, require_torch, test_level


class ImgPreprocessor(Preprocessor):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Field names used to look up the image path and target size.
        self.path_field = kwargs.pop('image_path', 'image_path')
        self.width = kwargs.pop('width', 'width')
        self.height = kwargs.pop('height', 'height')

    def __call__(self, data):
        import cv2
        image_path = data.get(self.path_field)
        if not image_path:
            return None
        img = cv2.imread(image_path)
        # cv2.resize expects dsize as (width, height).
        return {
            'image':
            cv2.resize(img,
                       (data.get(self.width, 128), data.get(self.height, 128)))
        }
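
# Usage sketch (mirrors the image tests below): with the defaults above,
#   ImgPreprocessor(image_path='image_file_path')(record)
# reads record['image_file_path'] with cv2 and returns
# {'image': <ndarray resized to 128x128>}; records with no path yield None.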


class PyDatasetTest(unittest.TestCase):

    def setUp(self):
        # ds1 initialized from in-memory json
        self.json_data = {
            'dummy': [{
                'a': i,
                'x': i * 10,
                'c': i * 100
            } for i in range(1, 11)]
        }
        hfds1 = hfdata.Dataset.from_dict(self.json_data)
        self.ds1 = PyDataset.from_hf_dataset(hfds1)
        # ds2 initialized from the hf hub
        hfds2 = hfdata.load_dataset(
            'glue', 'mrpc', revision='2.0.0', split='train')
        self.ds2 = PyDataset.from_hf_dataset(hfds2)

    def tearDown(self):
        pass

    def test_ds_basic(self):
        ms_ds_full = PyDataset.load('squad')
        ms_ds_full_hf = hfdata.load_dataset('squad')
        ms_ds_train = PyDataset.load('squad', split='train')
        ms_ds_train_hf = hfdata.load_dataset('squad', split='train')
        ms_image_train = PyDataset.from_hf_dataset(
            hfdata.load_dataset('beans', split='train'))
        self.assertEqual(ms_ds_full['train'][0], ms_ds_full_hf['train'][0])
        self.assertEqual(ms_ds_full['validation'][0],
                         ms_ds_full_hf['validation'][0])
        self.assertEqual(ms_ds_train[0], ms_ds_train_hf[0])
        print(next(iter(ms_ds_full['train'])))
        print(next(iter(ms_ds_train)))
        print(next(iter(ms_image_train)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_torch
    def test_to_torch_dataset_text(self):
        model_id = 'damo/bert-base-sst2'
        nlp_model = Model.from_pretrained(model_id)
        preprocessor = SequenceClassificationPreprocessor(
            nlp_model.model_dir,
            first_sequence='context',
            second_sequence=None)
        ms_ds_train = PyDataset.load('squad', split='train')
        pt_dataset = ms_ds_train.to_torch_dataset(preprocessors=preprocessor)

        import torch
        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
        print(next(iter(dataloader)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_tf
    def test_to_tf_dataset_text(self):
        import tensorflow as tf
        tf.compat.v1.enable_eager_execution()
        model_id = 'damo/bert-base-sst2'
        nlp_model = Model.from_pretrained(model_id)
        preprocessor = SequenceClassificationPreprocessor(
            nlp_model.model_dir,
            first_sequence='context',
            second_sequence=None)
        ms_ds_train = PyDataset.load('squad', split='train')
        tf_dataset = ms_ds_train.to_tf_dataset(
            batch_size=5,
            shuffle=True,
            preprocessors=preprocessor,
            drop_remainder=True)
        print(next(iter(tf_dataset)))

    def test_to_hf_dataset(self):
        hfds = self.ds1.to_hf_dataset()
        hfds1 = hfdata.Dataset.from_dict(self.json_data)
        self.assertEqual(hfds.data, hfds1.data)
        # simple map function
        hfds = hfds.map(lambda e: {'new_feature': e['dummy']['a']})
        self.assertEqual(len(hfds['new_feature']), 10)
        hfds2 = self.ds2.to_hf_dataset()
        self.assertTrue(hfds2[0]['sentence1'].startswith('Amrozi'))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_torch
    def test_to_torch_dataset_img(self):
        ms_image_train = PyDataset.from_hf_dataset(
            hfdata.load_dataset('beans', split='train'))
        pt_dataset = ms_image_train.to_torch_dataset(
            preprocessors=ImgPreprocessor(
                image_path='image_file_path', label='labels'))

        import torch
        dataloader = torch.utils.data.DataLoader(pt_dataset, batch_size=5)
        print(next(iter(dataloader)))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    @require_tf
    def test_to_tf_dataset_img(self):
        import tensorflow as tf
        tf.compat.v1.enable_eager_execution()
        ms_image_train = PyDataset.load('beans', split='train')
        tf_dataset = ms_image_train.to_tf_dataset(
            batch_size=5,
            shuffle=True,
            preprocessors=ImgPreprocessor(image_path='image_file_path'),
            drop_remainder=True,
            label_cols='labels')
        print(next(iter(tf_dataset)))


if __name__ == '__main__':
    unittest.main()
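
Note: the to_torch_dataset/to_tf_dataset tests above are gated both by
@unittest.skipUnless(test_level() >= 1, ...) and by @require_torch /
@require_tf, so a default run exercises only test_ds_basic and
test_to_hf_dataset; raising the level returned by
modelscope.utils.test_utils.test_level() (likely configured via an
environment variable, though that is an assumption) enables the rest.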