Files
modelscope/tests/utils/test_registry.py

95 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.utils.constant import Tasks
from modelscope.utils.registry import Registry, build_from_cfg, default_group
class RegistryTest(unittest.TestCase):
    """Tests for ``Registry`` registration/lookup and ``build_from_cfg``."""

    def test_register_class_no_task(self):
        """A class registered with only a module name is found by that name."""
        MODELS = Registry('models')

        # A fresh registry keeps its name and holds exactly one group,
        # the (still empty) default group.
        self.assertEqual(MODELS.name, 'models')
        self.assertIn(default_group, MODELS.modules)
        self.assertEqual(MODELS.modules[default_group], {})
        self.assertEqual(len(MODELS.modules), 1)

        @MODELS.register_module(module_name='cls-resnet')
        class ResNetForCls(object):
            pass

        self.assertIn(default_group, MODELS.modules)
        # Identity check: the registry must hand back the very same class.
        self.assertIs(MODELS.get('cls-resnet'), ResNetForCls)

    def test_register_class_with_task(self):
        """Classes registered under a task are found via (name, task)."""
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classification, 'SwinT')
        class SwinTForCls(object):
            pass

        self.assertIn(Tasks.image_classification, MODELS.modules)
        self.assertIs(
            MODELS.get('SwinT', Tasks.image_classification), SwinTForCls)

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        self.assertIn(Tasks.sentiment_analysis, MODELS.modules)
        self.assertIs(
            MODELS.get('Bert', Tasks.sentiment_analysis),
            BertForSentimentAnalysis)

        # No module name is passed here; the lookup below expects the
        # class to be reachable under its own name, 'DETR'.
        @MODELS.register_module(Tasks.image_object_detection)
        class DETR(object):
            pass

        self.assertIn(Tasks.image_object_detection, MODELS.modules)
        self.assertIs(MODELS.get('DETR', Tasks.image_object_detection), DETR)

        # Three task groups were added on top of the single group a new
        # registry starts with (see test_register_class_no_task).
        self.assertEqual(len(MODELS.modules), 4)

    def test_list(self):
        """Smoke test: listing and printing a populated registry must not raise.

        No output is asserted; the test only exercises ``MODELS.list()`` and
        the registry's string conversion.
        """
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classification, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        MODELS.list()
        print(MODELS)

    def test_build(self):
        """``build_from_cfg`` instantiates the class registered for a task."""
        MODELS = Registry('models')

        @MODELS.register_module(Tasks.image_classification, 'SwinT')
        class SwinTForCls(object):
            pass

        @MODELS.register_module(Tasks.sentiment_analysis, 'Bert')
        class BertForSentimentAnalysis(object):
            pass

        swin_model = build_from_cfg(
            dict(type='SwinT'), MODELS, Tasks.image_classification)
        self.assertIsInstance(swin_model, SwinTForCls)

        bert_model = build_from_cfg(
            dict(type='Bert'), MODELS, Tasks.sentiment_analysis)
        self.assertIsInstance(bert_model, BertForSentimentAnalysis)

        # 'Bert' was registered under sentiment_analysis only, so asking for
        # it under image_classification is expected to raise KeyError. Only
        # the call that should raise lives inside the assertRaises block.
        mismatched_cfg = dict(type='Bert')
        with self.assertRaises(KeyError):
            build_from_cfg(mismatched_cfg, MODELS, Tasks.image_classification)
# Run the test suite only when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()