Files
modelscope/tests/pipelines/test_builder.py
wenmeng.zwm 4814b198f0 [to #43112534] taskdataset refine and auto placement for data and model
* refine taskdataset interface
 * add device placement for trainer
 * add device placement for pipeline
 * add config checker and fix model placement bug
 * fix circular import
 * refactor model init for translation_pipeline
 * cv pipelines support kwargs


Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9463076
2022-07-23 11:08:43 +08:00

91 lines
2.9 KiB
Python

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import tempfile
import unittest
from asyncio import Task
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL

from modelscope.fileio import io
from modelscope.models.base import Model
from modelscope.pipelines import Pipeline, pipeline
from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
from modelscope.utils.constant import (ConfigFields, Frameworks, ModelFile,
                                       Tasks)
from modelscope.utils.logger import get_logger
from modelscope.utils.registry import default_group
# Module-level logger; the test pipeline below reports the model it receives
# through it instead of printing to stdout.
logger = get_logger()


@PIPELINES.register_module(
    group_key=Tasks.image_tagging, module_name='custom_single_model')
class CustomSingleModelPipeline(Pipeline):
    """Test-only pipeline for ``Tasks.image_tagging`` built from one model.

    Registered under the module name ``'custom_single_model'``. It asserts
    that the pipeline builder hands it ``model`` as a single ``str``.
    """

    def __init__(self,
                 config_file: Optional[str] = None,
                 model: Optional[Union[str, Model]] = None,
                 preprocessor=None,
                 **kwargs):
        """Validate ``model`` and delegate construction to ``Pipeline``.

        Args:
            config_file: forwarded positionally to ``Pipeline.__init__``.
            model: must be a ``str``; any other value fails the assertion.
            preprocessor: forwarded positionally to ``Pipeline.__init__``.
            **kwargs: forwarded to ``Pipeline.__init__`` unchanged.

        Raises:
            AssertionError: if ``model`` is not a ``str``.
        """
        # Check the builder's contract before delegating to the base class,
        # so a wrong type fails with this explicit message.
        assert isinstance(model, str), 'model is not str'
        super().__init__(config_file, model, preprocessor, **kwargs)
        logger.info('single model pipeline received model: %s', model)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Delegate post-processing to the base ``Pipeline`` unchanged."""
        return super().postprocess(inputs)
@PIPELINES.register_module(
    group_key=Tasks.image_tagging, module_name='model1_model2')
class CustomMultiModelPipeline(Pipeline):
    """Test-only pipeline for ``Tasks.image_tagging`` built from several models.

    Registered under the module name ``'model1_model2'``. It asserts that the
    pipeline builder hands it ``model`` as a ``list`` whose items are all
    ``str``.
    """

    def __init__(self,
                 config_file: Optional[str] = None,
                 model: Optional[List[Union[str, Model]]] = None,
                 preprocessor=None,
                 **kwargs):
        """Validate ``model`` and delegate construction to ``Pipeline``.

        Args:
            config_file: forwarded positionally to ``Pipeline.__init__``.
            model: must be a ``list`` of ``str``; anything else fails an
                assertion.
            preprocessor: forwarded positionally to ``Pipeline.__init__``.
            **kwargs: forwarded to ``Pipeline.__init__`` unchanged.

        Raises:
            AssertionError: if ``model`` is not a list, or any item is not
                a ``str``.
        """
        # Check the builder's contract before delegating to the base class,
        # so a wrong type fails with these explicit messages.
        assert isinstance(model, list), 'model is not list'
        for sub_model in model:
            assert isinstance(sub_model, str), 'submodel is not str'
        super().__init__(config_file, model, preprocessor, **kwargs)
        logger.info('multi model pipeline received models: %s', model)

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """Delegate post-processing to the base ``Pipeline`` unchanged."""
        return super().postprocess(inputs)
class PipelineInterfaceTest(unittest.TestCase):
    """Check that ``pipeline()`` builds the custom test pipelines from local
    model directories containing a generated configuration file."""

    def prepare_dir(self, dirname, pipeline_name):
        """Create ``dirname`` (with parents) and write a minimal config into it.

        Args:
            dirname: directory to create if it does not already exist.
            pipeline_name: value written to the config's ``pipeline.type``
                field, i.e. the registered pipeline module name.
        """
        # exist_ok avoids the check-then-create race of exists()/makedirs().
        os.makedirs(dirname, exist_ok=True)
        cfg_file = os.path.join(dirname, ModelFile.CONFIGURATION)
        cfg = {
            ConfigFields.framework: Frameworks.torch,
            ConfigFields.task: Tasks.image_tagging,
            ConfigFields.pipeline: {
                'type': pipeline_name,
            }
        }
        io.dump(cfg, cfg_file)

    def setUp(self) -> None:
        # Work in a private temporary directory instead of fixed /tmp paths:
        # portable across platforms, isolated from concurrent runs and stale
        # files, and removed automatically after each test.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        # Keep the original directory basenames inside the temporary root.
        self.single_model_dir = os.path.join(tmp_dir.name,
                                             'custom_single_model')
        self.model1_dir = os.path.join(tmp_dir.name, 'model1')
        self.model2_dir = os.path.join(tmp_dir.name, 'model2')
        self.prepare_dir(self.single_model_dir, 'custom_single_model')
        self.prepare_dir(self.model1_dir, 'model1_model2')
        self.prepare_dir(self.model2_dir, 'model1_model2')

    def test_single_model(self):
        pipe = pipeline(Tasks.image_tagging, model=self.single_model_dir)
        self.assertIsInstance(pipe, CustomSingleModelPipeline)

    def test_multi_model(self):
        pipe = pipeline(
            Tasks.image_tagging, model=[self.model1_dir, self.model2_dir])
        self.assertIsInstance(pipe, CustomMultiModelPipeline)


if __name__ == '__main__':
    unittest.main()