[to #42322933]add semantic-segmentation task output is numpy mask for demo-service

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10265856
This commit is contained in:
wendi.hwd
2022-09-27 15:01:05 +08:00
committed by Yingda Chen
parent a17b49db97
commit 6d29460eb8
5 changed files with 13 additions and 10 deletions

View File

@@ -13,7 +13,8 @@ from modelscope.utils.constant import ModelFile, Tasks
from .models import U2NET
@MODELS.register_module(Tasks.image_segmentation, module_name=Models.detection)
@MODELS.register_module(
Tasks.semantic_segmentation, module_name=Models.detection)
class SalientDetection(TorchModel):
def __init__(self, model_dir: str, *args, **kwargs):

View File

@@ -151,6 +151,12 @@ TASK_OUTPUTS = {
Tasks.image_segmentation:
[OutputKeys.SCORES, OutputKeys.LABELS, OutputKeys.MASKS],
# semantic segmentation result for single sample
# {
# "masks": [np.array # 2D array containing only 0, 255]
# }
Tasks.semantic_segmentation: [OutputKeys.MASKS],
# image matting result for single sample
# {
# "output_img": np.array with shape(h, w, 4)

View File

@@ -9,7 +9,7 @@ from modelscope.utils.constant import Tasks
@PIPELINES.register_module(
Tasks.image_segmentation, module_name=Pipelines.salient_detection)
Tasks.semantic_segmentation, module_name=Pipelines.salient_detection)
class ImageSalientDetectionPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
@@ -39,9 +39,5 @@ class ImageSalientDetectionPipeline(Pipeline):
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
data = self.model.postprocess(inputs)
outputs = {
OutputKeys.SCORES: None,
OutputKeys.LABELS: None,
OutputKeys.MASKS: data
}
outputs = {OutputKeys.MASKS: data}
return outputs

View File

@@ -38,6 +38,7 @@ class CVTasks(object):
image_object_detection = 'image-object-detection'
image_segmentation = 'image-segmentation'
semantic_segmentation = 'semantic-segmentation'
portrait_matting = 'portrait-matting'
text_driven_segmentation = 'text-driven-segmentation'
shop_segmentation = 'shop-segmentation'

View File

@@ -11,17 +11,16 @@ from modelscope.utils.test_utils import test_level
class SalientDetectionTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.image_segmentation
self.task = Tasks.semantic_segmentation
self.model_id = 'damo/cv_u2net_salient-detection'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_salient_detection(self):
input_location = 'data/test/images/image_salient_detection.jpg'
model_id = 'damo/cv_u2net_salient-detection'
salient_detect = pipeline(Tasks.image_segmentation, model=model_id)
salient_detect = pipeline(Tasks.semantic_segmentation, model=model_id)
result = salient_detect(input_location)
import cv2
# result[OutputKeys.MASKS] is salient map result,other keys are not used
cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])
@unittest.skip('demo compatibility test is only enabled on a needed-basis')