# Copyright (c) Alibaba, Inc. and its affiliates. import os.path as osp import tempfile import unittest from typing import Any, Dict, List, Tuple, Union import cv2 import numpy as np import PIL from maas_lib.fileio import File from maas_lib.pipelines import pipeline from maas_lib.utils.constant import Tasks class ImageMattingTest(unittest.TestCase): def test_run(self): model_path = 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs' \ '.com/data/test/maas/image_matting/matting_person.pb' with tempfile.TemporaryDirectory() as tmp_dir: model_file = osp.join(tmp_dir, 'matting_person.pb') with open(model_file, 'wb') as ofile: ofile.write(File.read(model_path)) img_matting = pipeline(Tasks.image_matting, model=tmp_dir) result = img_matting( 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' ) cv2.imwrite('result.png', result['output_png']) def test_run_modelhub(self): img_matting = pipeline( Tasks.image_matting, model='damo/image-matting-person') result = img_matting( 'http://pai-vision-data-hz.oss-cn-zhangjiakou.aliyuncs.com/data/test/maas/image_matting/test.png' ) cv2.imwrite('result.png', result['output_png']) if __name__ == '__main__': unittest.main()