[to #42322933] support salient detection

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9722903
This commit is contained in:
wendi.hwd
2022-08-16 09:15:53 +08:00
committed by yingda.chen
parent 8ce641fd0c
commit d3fac4f5be
11 changed files with 467 additions and 3 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:70ea0c06f9cfe3882253f7175221d47e394ab9c469076ab220e880b17dbcdd02
size 48552

View File

@@ -86,6 +86,7 @@ class Pipelines(object):
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
human_detection = 'resnet18-human-detection'
object_detection = 'vit-object-detection'
salient_detection = 'u2net-salient-detection'
image_classification = 'image-classification'
face_detection = 'resnet-face-detection-scrfd10gkps'
live_category = 'live-category'

View File

@@ -5,4 +5,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
image_colorization, image_denoise, image_instance_segmentation,
image_portrait_enhancement, image_to_image_generation,
image_to_image_translation, object_detection,
product_retrieval_embedding, super_resolution, virual_tryon)
product_retrieval_embedding, salient_detection,
super_resolution, virual_tryon)

View File

@@ -38,7 +38,7 @@ class DetectionModel(TorchModel):
self.model, model_path, map_location='cpu')
self.class_names = checkpoint['meta']['CLASSES']
config.test_pipeline[0].type = 'LoadImageFromWebcam'
self.test_pipeline = Compose(
self.transform_input = Compose(
replace_ImageToTensor(config.test_pipeline))
self.model.cfg = config
self.model.eval()
@@ -56,7 +56,7 @@ class DetectionModel(TorchModel):
from mmcv.parallel import collate, scatter
data = dict(img=image)
data = self.test_pipeline(data)
data = self.transform_input(data)
data = collate([data], samples_per_gpu=1)
data['img_metas'] = [
img_metas.data[0] for img_metas in data['img_metas']

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .salient_model import SalientDetection
else:
_import_structure = {
'salient_model': ['SalientDetection'],
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1 @@
from .u2net import U2NET

View File

@@ -0,0 +1,300 @@
# Implementation in this file is modifed from source code avaiable via https://github.com/xuebinqin/U-2-Net
import torch
import torch.nn as nn
import torch.nn.functional as F
class REBNCONV(nn.Module):
def __init__(self, in_ch=3, out_ch=3, dirate=1):
super(REBNCONV, self).__init__()
self.conv_s1 = nn.Conv2d(
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
self.bn_s1 = nn.BatchNorm2d(out_ch)
self.relu_s1 = nn.ReLU(inplace=True)
def forward(self, x):
hx = x
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
return xout
def _upsample_like(src, tar):
"""upsample tensor 'src' to have the same spatial size with tensor 'tar'."""
src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
return src
class RSU7(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU7, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx = self.pool5(hx5)
hx6 = self.rebnconv6(hx)
hx7 = self.rebnconv7(hx6)
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
hx6dup = _upsample_like(hx6d, hx5)
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU6(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU6, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx = self.pool4(hx4)
hx5 = self.rebnconv5(hx)
hx6 = self.rebnconv6(hx5)
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU5(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU5, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx = self.pool3(hx3)
hx4 = self.rebnconv4(hx)
hx5 = self.rebnconv5(hx4)
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU4(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx = self.pool1(hx1)
hx2 = self.rebnconv2(hx)
hx = self.pool2(hx2)
hx3 = self.rebnconv3(hx)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
return hx1d + hxin
class RSU4F(nn.Module):
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
super(RSU4F, self).__init__()
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
def forward(self, x):
hx = x
hxin = self.rebnconvin(hx)
hx1 = self.rebnconv1(hxin)
hx2 = self.rebnconv2(hx1)
hx3 = self.rebnconv3(hx2)
hx4 = self.rebnconv4(hx3)
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
return hx1d + hxin
class U2NET(nn.Module):
def __init__(self, in_ch=3, out_ch=1):
super(U2NET, self).__init__()
# encoder
self.stage1 = RSU7(in_ch, 32, 64)
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage2 = RSU6(64, 32, 128)
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage3 = RSU5(128, 64, 256)
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage4 = RSU4(256, 128, 512)
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage5 = RSU4F(512, 256, 512)
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
self.stage6 = RSU4F(512, 256, 512)
# decoder
self.stage5d = RSU4F(1024, 256, 512)
self.stage4d = RSU4(1024, 128, 256)
self.stage3d = RSU5(512, 64, 128)
self.stage2d = RSU6(256, 32, 64)
self.stage1d = RSU7(128, 16, 64)
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
def forward(self, x):
hx = x
hx1 = self.stage1(hx)
hx = self.pool12(hx1)
hx2 = self.stage2(hx)
hx = self.pool23(hx2)
hx3 = self.stage3(hx)
hx = self.pool34(hx3)
hx4 = self.stage4(hx)
hx = self.pool45(hx4)
hx5 = self.stage5(hx)
hx = self.pool56(hx5)
hx6 = self.stage6(hx)
hx6up = _upsample_like(hx6, hx5)
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
hx5dup = _upsample_like(hx5d, hx4)
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
hx4dup = _upsample_like(hx4d, hx3)
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
hx3dup = _upsample_like(hx3d, hx2)
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
hx2dup = _upsample_like(hx2d, hx1)
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
d1 = self.side1(hx1d)
d2 = self.side2(hx2d)
d2 = _upsample_like(d2, d1)
d3 = self.side3(hx3d)
d3 = _upsample_like(d3, d1)
d4 = self.side4(hx4d)
d4 = _upsample_like(d4, d1)
d5 = self.side5(hx5d)
d5 = _upsample_like(d5, d1)
d6 = self.side6(hx6)
d6 = _upsample_like(d6, d1)
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(
d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(
d5), torch.sigmoid(d6)

View File

@@ -0,0 +1,63 @@
import os.path as osp
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms
from modelscope.metainfo import Models
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from .models import U2NET
@MODELS.register_module(Tasks.image_segmentation, module_name=Models.detection)
class SalientDetection(TorchModel):
def __init__(self, model_dir: str, *args, **kwargs):
"""str -- model file root."""
super().__init__(model_dir, *args, **kwargs)
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
self.model = U2NET(3, 1)
checkpoint = torch.load(model_path, map_location='cpu')
self.transform_input = transforms.Compose([
transforms.Resize((320, 320)),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
self.model.load_state_dict(checkpoint)
self.model.eval()
def inference(self, data):
"""data is tensor 3 * H * W ---> return tensor H * W ."""
data = data.unsqueeze(0)
if next(self.model.parameters()).is_cuda:
data = data.to(
torch.device([next(self.model.parameters()).device][0]))
with torch.no_grad():
results = self.model(data)
if next(self.model.parameters()).is_cuda:
return results[0][0, 0, :, :].cpu()
return results[0][0, 0, :, :]
def preprocess(self, image):
"""image is numpy."""
data = self.transform_input(Image.fromarray(image))
return data.float()
def postprocess(self, inputs):
"""resize ."""
data = inputs['data']
w = inputs['img_w']
h = inputs['img_h']
data_norm = (data - torch.min(data)) / (
torch.max(data) - torch.min(data))
data_norm_np = (data_norm.numpy() * 255).astype('uint8')
data_norm_rst = cv2.resize(data_norm_np, (w, h))
return data_norm_rst

View File

@@ -10,6 +10,7 @@ if TYPE_CHECKING:
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
from .crowd_counting_pipeline import CrowdCountingPipeline
from .image_detection_pipeline import ImageDetectionPipeline
from .image_salient_detection_pipeline import ImageSalientDetectionPipeline
from .face_detection_pipeline import FaceDetectionPipeline
from .face_image_generation_pipeline import FaceImageGenerationPipeline
from .face_recognition_pipeline import FaceRecognitionPipeline
@@ -43,6 +44,7 @@ else:
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
'crowd_counting_pipeline': ['CrowdCountingPipeline'],
'image_detection_pipeline': ['ImageDetectionPipeline'],
'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'],
'face_detection_pipeline': ['FaceDetectionPipeline'],
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
'face_recognition_pipeline': ['FaceRecognitionPipeline'],

View File

@@ -0,0 +1,47 @@
from typing import Any, Dict
from modelscope.metainfo import Pipelines
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
@PIPELINES.register_module(
Tasks.image_segmentation, module_name=Pipelines.salient_detection)
class ImageSalientDetectionPipeline(Pipeline):
def __init__(self, model: str, **kwargs):
"""
model: model id on modelscope hub.
"""
super().__init__(model=model, auto_collate=False, **kwargs)
def preprocess(self, input: Input) -> Dict[str, Any]:
img = LoadImage.convert_to_ndarray(input)
img_h, img_w, _ = img.shape
img = self.model.preprocess(img)
result = {'img': img, 'img_w': img_w, 'img_h': img_h}
return result
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
outputs = self.model.inference(input['img'])
result = {
'data': outputs,
'img_w': input['img_w'],
'img_h': input['img_h']
}
return result
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
data = self.model.postprocess(inputs)
outputs = {
OutputKeys.SCORES: None,
OutputKeys.LABELS: None,
OutputKeys.MASKS: data
}
return outputs

View File

@@ -0,0 +1,24 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class SalientDetectionTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_salient_detection(self):
input_location = 'data/test/images/image_salient_detection.jpg'
model_id = 'damo/cv_u2net_salient-detection'
salient_detect = pipeline(Tasks.image_segmentation, model=model_id)
result = salient_detect(input_location)
import cv2
# result[OutputKeys.MASKS] is salient map result,other keys are not used
cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])
if __name__ == '__main__':
unittest.main()