Mirror of https://github.com/modelscope/modelscope.git
[to #42322933] panoptic segmentation model integration

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9758389
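For orientation: end-to-end usage of the new task mirrors the test file at the bottom of this diff. A minimal sketch (the model id and test image path are copied verbatim from that test):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

# model id as used in the added test, resolved from the ModelScope hub
pan_segmentor = pipeline(Tasks.image_segmentation,
                         model='damo/cv_swinL_panoptic-segmentation_cocopan')
result = pan_segmentor('data/test/images/image_panoptic_segmentation.jpg')
print(result[OutputKeys.LABELS])  # per-segment class names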
data/test/images/image_panoptic_segmentation.jpg (Normal file, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:59b1da30af12f76b691990363e0d221050a59cf53fc4a97e776bcb00228c6c2a
size 245864
modelscope/metainfo.py

@@ -20,6 +20,7 @@ class Models(object):
     product_retrieval_embedding = 'product-retrieval-embedding'
     body_2d_keypoints = 'body-2d-keypoints'
     crowd_counting = 'HRNetCrowdCounting'
+    panoptic_segmentation = 'swinL-panoptic-segmentation'
     image_reid_person = 'passvitb'
     video_summarization = 'pgl-video-summarization'
@@ -114,6 +115,7 @@ class Pipelines(object):
     tinynas_classification = 'tinynas-classification'
     crowd_counting = 'hrnet-crowd-counting'
     video_single_object_tracking = 'ostrack-vitb-video-single-object-tracking'
+    image_panoptic_segmentation = 'image-panoptic-segmentation'
     video_summarization = 'googlenet_pgl_video_summarization'
     image_reid_person = 'passvitb-image-reid-person'
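The two strings above are registry keys, not class names: the Models entry is what @MODELS.register_module binds the new model class to, and the Pipelines entry is what @PIPELINES.register_module binds the new pipeline to (both decorators appear in the files added below). A small illustration, with values copied from the two hunks above:

from modelscope.metainfo import Models, Pipelines

assert Models.panoptic_segmentation == 'swinL-panoptic-segmentation'
assert Pipelines.image_panoptic_segmentation == 'image-panoptic-segmentation'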
modelscope/models/cv/__init__.py

@@ -3,8 +3,9 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                cartoon, cmdssl_video_embedding, crowd_counting, face_detection,
                face_generation, image_classification, image_color_enhance,
                image_colorization, image_denoise, image_instance_segmentation,
-               image_portrait_enhancement, image_reid_person,
-               image_to_image_generation, image_to_image_translation,
-               object_detection, product_retrieval_embedding,
-               salient_detection, super_resolution,
-               video_single_object_tracking, video_summarization, virual_tryon)
+               image_panoptic_segmentation, image_portrait_enhancement,
+               image_reid_person, image_to_image_generation,
+               image_to_image_translation, object_detection,
+               product_retrieval_embedding, salient_detection,
+               super_resolution, video_single_object_tracking,
+               video_summarization, virual_tryon)
modelscope/models/cv/image_panoptic_segmentation/__init__.py (Normal file, 22 lines)
@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .panseg_model import SwinLPanopticSegmentation

else:
    _import_structure = {
        'panseg_model': ['SwinLPanopticSegmentation'],
    }

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
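A note on the lazy-import pattern above: swapping the package's entry in sys.modules for a LazyImportModule defers importing panseg_model, and with it the heavy mmcv/mmdet stack, until the symbol is first touched. A sketch of the intended effect, assuming modelscope and its dependencies are installed:

# importing the package does not yet pull in panseg_model or mmdet
from modelscope.models.cv import image_panoptic_segmentation

# the first attribute access resolves 'panseg_model' via _import_structure
model_cls = image_panoptic_segmentation.SwinLPanopticSegmentation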
modelscope/models/cv/image_panoptic_segmentation/panseg_model.py (Normal file)

@@ -0,0 +1,54 @@
import os.path as osp

import torch

from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks


@MODELS.register_module(
    Tasks.image_segmentation, module_name=Models.panoptic_segmentation)
class SwinLPanopticSegmentation(TorchModel):

    def __init__(self, model_dir: str, **kwargs):
        """Initialize the model from `model_dir`, the model file root."""
        super().__init__(model_dir, **kwargs)

        import mmcv
        from mmcv.runner import load_checkpoint
        from mmdet.models import build_detector

        config = osp.join(model_dir, 'config.py')

        cfg = mmcv.Config.fromfile(config)
        if 'pretrained' in cfg.model:
            cfg.model.pretrained = None
        elif 'init_cfg' in cfg.model.backbone:
            cfg.model.backbone.init_cfg = None

        # build the detector without its training components
        cfg.model.train_cfg = None
        self.model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))

        # load the checkpoint and the class list stored in its metadata
        model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        checkpoint = load_checkpoint(
            self.model, model_path, map_location='cpu')

        self.CLASSES = checkpoint['meta']['CLASSES']
        self.num_classes = len(self.CLASSES)
        self.cfg = cfg

    def inference(self, data):
        """`data` is a dict with `img` and `img_metas`, following mmdet."""

        with torch.no_grad():
            results = self.model(return_loss=False, rescale=True, **data)
        return results

    def forward(self, inputs):
        return self.model(**inputs)
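Given a local model directory laid out the way the constructor expects (an mmdet-style config.py next to the checkpoint named by ModelFile.TORCH_MODEL_FILE), the class can also be built directly. A hedged sketch; the path is hypothetical:

# hypothetical local snapshot of damo/cv_swinL_panoptic-segmentation_cocopan
model = SwinLPanopticSegmentation('/path/to/model_dir')
print(model.num_classes)  # expected to be 133 for COCO panoptic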
modelscope/pipelines/cv/__init__.py

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from .image_denoise_pipeline import ImageDenoisePipeline
     from .image_instance_segmentation_pipeline import ImageInstanceSegmentationPipeline
     from .image_matting_pipeline import ImageMattingPipeline
+    from .image_panoptic_segmentation_pipeline import ImagePanopticSegmentationPipeline
     from .image_portrait_enhancement_pipeline import ImagePortraitEnhancementPipeline
     from .image_reid_person_pipeline import ImageReidPersonPipeline
     from .image_style_transfer_pipeline import ImageStyleTransferPipeline
@@ -37,6 +38,7 @@ if TYPE_CHECKING:
     from .tinynas_classification_pipeline import TinynasClassificationPipeline
     from .video_category_pipeline import VideoCategoryPipeline
     from .virtual_try_on_pipeline import VirtualTryonPipeline

 else:
     _import_structure = {
         'action_recognition_pipeline': ['ActionRecognitionPipeline'],
@@ -59,6 +61,8 @@ else:
         'image_instance_segmentation_pipeline':
         ['ImageInstanceSegmentationPipeline'],
         'image_matting_pipeline': ['ImageMattingPipeline'],
+        'image_panoptic_segmentation_pipeline':
+        ['ImagePanopticSegmentationPipeline'],
         'image_portrait_enhancement_pipeline':
         ['ImagePortraitEnhancementPipeline'],
         'image_reid_person_pipeline': ['ImageReidPersonPipeline'],
modelscope/pipelines/cv/image_panoptic_segmentation_pipeline.py (Normal file, 103 lines)
@@ -0,0 +1,103 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union

import cv2
import numpy as np
import PIL

from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()


@PIPELINES.register_module(
    Tasks.image_segmentation,
    module_name=Pipelines.image_panoptic_segmentation)
class ImagePanopticSegmentationPipeline(Pipeline):

    def __init__(self, model: str, **kwargs):
        """Use `model` to create an image panoptic segmentation pipeline for prediction.

        Args:
            model: model id on the ModelScope hub.
        """
        super().__init__(model=model, **kwargs)

        logger.info('panoptic segmentation model, pipeline init')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        from mmcv.parallel import collate, scatter
        from mmdet.datasets import replace_ImageToTensor
        from mmdet.datasets.pipelines import Compose

        cfg = self.model.cfg
        # build the data pipeline
        if isinstance(input, str):
            # input is a file name; the default LoadImageFromFile step reads it
            data = dict(img_info=dict(filename=input), img_prefix=None)
        elif isinstance(input, PIL.Image.Image):
            cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
            img = np.array(input.convert('RGB'))
            data = dict(img=img)
        elif isinstance(input, np.ndarray):
            cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
            if len(input.shape) == 2:
                img = cv2.cvtColor(input, cv2.COLOR_GRAY2BGR)
            else:
                img = input
            img = img[:, :, ::-1]  # to RGB order
            data = dict(img=img)
        else:
            raise TypeError(f'input should be either str, PIL.Image'
                            f' or np.ndarray, but got {type(input)}')

        cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
        test_pipeline = Compose(cfg.data.test.pipeline)

        data = test_pipeline(data)
        # collate and unwrap the DataContainers, as in mmdet's inference helpers
        data = collate([data], samples_per_gpu=1)
        data['img_metas'] = [
            img_metas.data[0] for img_metas in data['img_metas']
        ]
        data['img'] = [img.data[0] for img in data['img']]
        if next(self.model.parameters()).is_cuda:
            # scatter to the GPU the model lives on
            data = scatter(data, [next(self.model.parameters()).device])[0]

        return data

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        results = self.model.inference(input)

        return results

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # batch size is fixed to 1
        pan_results = inputs[0]['pan_results']
        INSTANCE_OFFSET = 1000

        ids = np.unique(pan_results)[::-1]
        legal_indices = ids != self.model.num_classes  # filter out the VOID label
        ids = ids[legal_indices]
        labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)
        segms = (pan_results[None] == ids[:, None, None])
        masks = [it.astype(int) for it in segms]
        labels_txt = np.array(self.model.CLASSES)[labels].tolist()

        outputs = {
            OutputKeys.MASKS: masks,
            OutputKeys.LABELS: labels_txt,
            OutputKeys.SCORES: [0.999 for _ in range(len(labels_txt))]
        }
        return outputs
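preprocess() accepts a file path, a PIL image, or a numpy array; the ndarray branch is the one the new tests below do not exercise. A sketch of that path, assuming a 3-channel array in BGR order (as cv2.imread returns), which the code above flips to RGB:

import cv2

from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

pan_segmentor = pipeline(Tasks.image_segmentation,
                         model='damo/cv_swinL_panoptic-segmentation_cocopan')

bgr = cv2.imread('data/test/images/image_panoptic_segmentation.jpg')  # BGR ndarray
result = pan_segmentor(bgr)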
modelscope/utils/cv/image_utils.py

@@ -134,3 +134,22 @@ def show_video_tracking_result(video_in_path, bboxes, video_save_path):
         video_writer.write(frame)
     video_writer.release
     cap.release()
+
+
+def panoptic_seg_masks_to_image(masks):
+    draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3])
+    from mmdet.core.visualization.palette import get_palette
+    mask_palette = get_palette('coco', 133)
+
+    from mmdet.core.visualization.image import _get_bias_color
+    taken_colors = {(0, 0, 0)}  # reserve black so no segment matches the background
+    for i, mask in enumerate(masks):
+        color_mask = mask_palette[i]
+        while tuple(color_mask) in taken_colors:
+            color_mask = _get_bias_color(color_mask)
+        taken_colors.add(tuple(color_mask))
+
+        mask = mask.astype(bool)
+        draw_img[mask] = color_mask
+
+    return draw_img
tests/pipelines/test_image_panoptic_segmentation.py (Normal file, 40 lines)
@@ -0,0 +1,40 @@
import unittest

import cv2
import PIL

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.cv.image_utils import panoptic_seg_masks_to_image
from modelscope.utils.test_utils import test_level


class ImagePanopticSegmentationTest(unittest.TestCase):

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_image_panoptic_segmentation(self):
        input_location = 'data/test/images/image_panoptic_segmentation.jpg'
        model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan'
        pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id)
        result = pan_segmentor(input_location)

        draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS])
        cv2.imwrite('result.jpg', draw_img)
        print('test_image_panoptic_segmentation passed')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_image_panoptic_segmentation_from_PIL(self):
        input_location = 'data/test/images/image_panoptic_segmentation.jpg'
        model_id = 'damo/cv_swinL_panoptic-segmentation_cocopan'
        pan_segmentor = pipeline(Tasks.image_segmentation, model=model_id)
        PIL_array = PIL.Image.open(input_location)
        result = pan_segmentor(PIL_array)

        draw_img = panoptic_seg_masks_to_image(result[OutputKeys.MASKS])
        cv2.imwrite('result.jpg', draw_img)
        print('test_image_panoptic_segmentation_from_PIL passed')


if __name__ == '__main__':
    unittest.main()