2022-05-19 22:18:35 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
|
|
|
|
|
import unittest
|
|
|
|
|
from typing import Any, Dict, List, Tuple, Union
|
|
|
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
import PIL
|
|
|
|
|
|
2022-06-09 20:16:26 +08:00
|
|
|
from modelscope.pipelines import Pipeline, pipeline
|
|
|
|
|
from modelscope.pipelines.builder import PIPELINES, add_default_pipeline_info
|
|
|
|
|
from modelscope.utils.constant import Tasks
|
|
|
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
|
from modelscope.utils.registry import default_group
|
2022-05-19 22:18:35 +08:00
|
|
|
|
|
|
|
|
# Module-level logger; not referenced elsewhere in the visible part of this file.
logger = get_logger()
|
|
|
|
|
|
|
|
|
|
# Type alias for values a pipeline may be called with (unused in the visible
# part of this file). The forward references must name objects reachable from
# this module's globals:
# - the image class is PIL.Image.Image (PIL.Image is a module);
# - numpy is bound here as `np`, not `numpy`.
Input = Union[str, 'PIL.Image.Image', 'np.ndarray']
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class CustomPipelineTest(unittest.TestCase):
    """Tests for defining, registering and building user-defined pipelines."""

    def test_abstract(self):
        """A Pipeline subclass that only overrides ``__init__`` cannot be
        instantiated."""

        @PIPELINES.register_module()
        class CustomPipeline1(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

        # Instantiation is expected to fail with TypeError, presumably because
        # Pipeline declares abstract hooks (preprocess/forward/postprocess)
        # that CustomPipeline1 leaves unimplemented -- confirm in Pipeline.
        with self.assertRaises(TypeError):
            CustomPipeline1()

    def test_custom(self):
        """Register an image pipeline under a dummy task and exercise it.

        The pipeline is built both by name and by task, and the two builds
        must produce the same type. It is then run on a single URL and on a
        batch of four. Requires network access to fetch the test image.
        """
        dummy_task = 'dummy-task'

        @PIPELINES.register_module(
            group_key=dummy_task, module_name='custom-image')
        class CustomImagePipeline(Pipeline):

            def __init__(self,
                         config_file: str = None,
                         model=None,
                         preprocessor=None,
                         **kwargs):
                super().__init__(config_file, model, preprocessor, **kwargs)

            def preprocess(
                    self, input: Union[str,
                                       'PIL.Image.Image']) -> Dict[str, Any]:
                """Turn ``input`` into ``{'img': <PIL image>, ...}``.

                A PIL image is passed through unchanged. Anything else (e.g.
                a URL) is loaded with ``load_image`` and also kept under
                ``'url'`` so ``forward`` can echo it back.
                """
                # `import PIL` alone does not load the PIL.Image submodule, so
                # import it explicitly rather than relying on a side effect of
                # some other import having loaded it.
                from PIL import Image
                if isinstance(input, Image.Image):
                    return {'img': input}
                from modelscope.preprocessors import load_image
                return {'img': load_image(input), 'url': input}

            def forward(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Halve the image in both dimensions.

                Returns the result as a numpy array under ``'output_png'``.
                The source URL, when present, is echoed as ``'filename'``.
                """
                outputs = {}
                if 'url' in inputs:
                    outputs['filename'] = inputs['url']
                img = inputs['img']
                new_image = img.resize((img.width // 2, img.height // 2))
                outputs['output_png'] = np.array(new_image)
                return outputs

            def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
                """Return the forward outputs unchanged."""
                return inputs

        # NOTE(review): the module is registered under `dummy_task`, yet is
        # looked up in `default_group`. This presumably relies on the registry
        # also indexing task-grouped modules there -- confirm in Registry.
        self.assertIn('custom-image', PIPELINES.modules[default_group])
        add_default_pipeline_info(dummy_task, 'custom-image', overwrite=True)
        pipe = pipeline(pipeline_name='custom-image')
        pipe2 = pipeline(dummy_task)
        # Building by name and by task must resolve to the same pipeline type.
        self.assertIs(type(pipe), type(pipe2))

        img_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.' \
                  'aliyuncs.com/data/test/images/image1.jpg'
        # The expected shape is the remote test image at half size
        # (see forward).
        output = pipe(img_url)
        self.assertEqual(output['filename'], img_url)
        self.assertEqual(output['output_png'].shape, (318, 512, 3))

        # A list input is processed item by item into a list of outputs.
        outputs = pipe([img_url] * 4)
        self.assertEqual(len(outputs), 4)
        for out in outputs:
            self.assertEqual(out['filename'], img_url)
            self.assertEqual(out['output_png'].shape, (318, 512, 3))
|
|
if __name__ == '__main__':
    # Run this module's test cases when it is executed as a script.
    unittest.main()
|