mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9644184 * fix distributed training and eval
65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
import unittest
|
|
|
|
import torch
|
|
|
|
from modelscope.preprocessors import (PREPROCESSORS, Compose, Filter,
|
|
Preprocessor, ToTensor)
|
|
|
|
|
|
class ComposeTest(unittest.TestCase):
    """Checks that ``Compose`` applies every configured preprocessor."""

    def test_compose(self):
        """Chain two stub preprocessors and verify both left their mark."""

        @PREPROCESSORS.register_module()
        class Tmp1(Preprocessor):
            """Stub that stores ``'tmp1'`` under the ``tmp1`` key."""

            def __call__(self, input):
                input['tmp1'] = 'tmp1'
                return input

        @PREPROCESSORS.register_module()
        class Tmp2(Preprocessor):
            """Stub that stores ``'tmp2'`` under the ``tmp2`` key."""

            def __call__(self, input):
                input['tmp2'] = 'tmp2'
                return input

        # One config dict per stage; the ``type`` value matches the name
        # of the stub class decorated above.
        stage_cfgs = [{'type': stub_name} for stub_name in ('Tmp1', 'Tmp2')]
        composed = Compose(stage_cfgs)

        # Start from an empty sample so every key present afterwards was
        # written by one of the stubs.
        result = composed({})

        self.assertEqual(result['tmp1'], 'tmp1')
        self.assertEqual(result['tmp2'], 'tmp2')
|
|
|
|
|
|
class ToTensorTest(unittest.TestCase):
    """Checks ``ToTensor`` when it is restricted to a subset of keys."""

    def test_totensor(self):
        """Only ``img`` is listed in ``keys``; other fields stay as-is."""
        sample = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
        converted = ToTensor(keys=['img'])(sample)

        self.assertIsInstance(converted['img'], torch.Tensor)

        # Fields not named in ``keys`` must compare equal to their inputs.
        for field, expected in (('label', 1), ('path', 'test.jpg')):
            self.assertEqual(converted[field], expected)
|
|
|
|
|
|
class FilterTest(unittest.TestCase):
    """Checks that ``Filter`` keeps reserved keys and drops the others."""

    def test_filter(self):
        """``img`` and ``label`` are reserved, so ``path`` must be gone."""
        sample = {'img': [1, 2, 3], 'label': 1, 'path': 'test.jpg'}
        kept = Filter(reserved_keys=['img', 'label'])(sample)

        for reserved in ('img', 'label'):
            self.assertIn(reserved, kept)
        self.assertNotIn('path', kept)
|
|
|
|
|
|
# Run the test cases above when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|