Files
modelscope/tests/preprocessors/test_common.py
2022-08-16 12:04:07 +08:00

65 lines
1.7 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
import torch
from modelscope.preprocessors import (PREPROCESSORS, Compose, Filter,
Preprocessor, ToTensor)
class ComposeTest(unittest.TestCase):
def test_compose(self):
@PREPROCESSORS.register_module()
class Tmp1(Preprocessor):
def __call__(self, input):
input['tmp1'] = 'tmp1'
return input
@PREPROCESSORS.register_module()
class Tmp2(Preprocessor):
def __call__(self, input):
input['tmp2'] = 'tmp2'
return input
pipeline = [
dict(type='Tmp1'),
dict(type='Tmp2'),
]
trans = Compose(pipeline)
input = {}
output = trans(input)
self.assertEqual(output['tmp1'], 'tmp1')
self.assertEqual(output['tmp2'], 'tmp2')
class ToTensorTest(unittest.TestCase):
def test_totensor(self):
to_tensor_op = ToTensor(keys=['img'])
inputs = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
inputs = to_tensor_op(inputs)
self.assertIsInstance(inputs['img'], torch.Tensor)
self.assertEqual(inputs['label'], 1)
self.assertEqual(inputs['path'], 'test.jpg')
class FilterTest(unittest.TestCase):
def test_filter(self):
filter_op = Filter(reserved_keys=['img', 'label'])
inputs = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
inputs = filter_op(inputs)
self.assertIn('img', inputs)
self.assertIn('label', inputs)
self.assertNotIn('path', inputs)
if __name__ == '__main__':
    # Run the test cases above when this file is executed as a script.
    unittest.main(module='__main__')