# Files
# modelscope/tests/pipelines/test_image_matting.py
#
# 76 lines
# 2.9 KiB
# Python
# Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import os.path as osp
import shutil
import tempfile
import unittest
import cv2
from modelscope.fileio import File
from modelscope.pipelines import pipeline
from modelscope.pydatasets import PyDataset
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.hub import get_model_cache_dir
class ImageMattingTest(unittest.TestCase):
    """Smoke tests for the ``Tasks.image_matting`` pipeline.

    Every active test builds a pipeline, runs it on a remote sample image
    and writes the ``output_png`` entry of the result to ``result.png`` in
    the current working directory. No assertions are made on the output.
    """

    def setUp(self) -> None:
        """Set the model id and remove its cache directory before each test."""
        self.model_id = 'damo/cv_unet_image-matting_damo'
        # Set to False if re-downloading the model for every test is not desired.
        purge_cache = True
        if not purge_cache:
            return
        cache_dir = get_model_cache_dir(self.model_id)
        # ignore_errors=True: a missing cache directory is not a failure.
        shutil.rmtree(cache_dir, ignore_errors=True)

    @unittest.skip('deprecated, download model from model hub instead')
    def test_run_with_direct_file_download(self):
        """Run the pipeline from a graph file fetched over HTTP (skipped)."""
        model_path = (
            'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs'
            '.com/data/test/maas/image_matting/matting_person.pb')
        image_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
        with tempfile.TemporaryDirectory() as model_dir:
            graph_path = osp.join(model_dir, ModelFile.TF_GRAPH_FILE)
            with open(graph_path, 'wb') as graph_file:
                graph_file.write(File.read(model_path))
            # Build and run the pipeline while the temporary model directory
            # still exists, and only after the graph file has been closed.
            matting = pipeline(Tasks.image_matting, model=model_dir)
            output = matting(image_url)
            cv2.imwrite('result.png', output['output_png'])

    def test_run_with_dataset(self):
        """Run the pipeline on a ``PyDataset`` and save the first result."""
        image_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
        input_location = [image_url]
        # Alternatively, point at a local directory of images:
        # input_location = '/dir/to/images'
        dataset = PyDataset.load(input_location, target='image')
        matting = pipeline(Tasks.image_matting, model=self.model_id)
        # For dataset input the inference output is a generator that can be
        # iterated; only its first item is consumed here.
        outputs = matting(dataset)
        first_output = next(outputs)
        cv2.imwrite('result.png', first_output['output_png'])
        print(f'Output written to {osp.abspath("result.png")}')

    def test_run_modelhub(self):
        """Run the pipeline with the model identified by ``self.model_id``."""
        image_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
        matting = pipeline(Tasks.image_matting, model=self.model_id)
        output = matting(image_url)
        cv2.imwrite('result.png', output['output_png'])
        print(f'Output written to {osp.abspath("result.png")}')

    def test_run_modelhub_default_model(self):
        """Run the pipeline without passing a ``model`` argument."""
        image_url = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png'
        matting = pipeline(Tasks.image_matting)
        output = matting(image_url)
        cv2.imwrite('result.png', output['output_png'])
        print(f'Output written to {osp.abspath("result.png")}')
if __name__ == '__main__':
    # Hand control to unittest's command-line runner when this file is
    # executed directly rather than imported.
    unittest.main()