mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
[to #42322933] support salient detection
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9722903
This commit is contained in:
3
data/test/images/image_salient_detection.jpg
Normal file
3
data/test/images/image_salient_detection.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:70ea0c06f9cfe3882253f7175221d47e394ab9c469076ab220e880b17dbcdd02
|
||||
size 48552
|
||||
@@ -86,6 +86,7 @@ class Pipelines(object):
|
||||
body_2d_keypoints = 'hrnetv2w32_body-2d-keypoints_image'
|
||||
human_detection = 'resnet18-human-detection'
|
||||
object_detection = 'vit-object-detection'
|
||||
salient_detection = 'u2net-salient-detection'
|
||||
image_classification = 'image-classification'
|
||||
face_detection = 'resnet-face-detection-scrfd10gkps'
|
||||
live_category = 'live-category'
|
||||
|
||||
@@ -5,4 +5,5 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
|
||||
image_colorization, image_denoise, image_instance_segmentation,
|
||||
image_portrait_enhancement, image_to_image_generation,
|
||||
image_to_image_translation, object_detection,
|
||||
product_retrieval_embedding, super_resolution, virual_tryon)
|
||||
product_retrieval_embedding, salient_detection,
|
||||
super_resolution, virual_tryon)
|
||||
|
||||
@@ -38,7 +38,7 @@ class DetectionModel(TorchModel):
|
||||
self.model, model_path, map_location='cpu')
|
||||
self.class_names = checkpoint['meta']['CLASSES']
|
||||
config.test_pipeline[0].type = 'LoadImageFromWebcam'
|
||||
self.test_pipeline = Compose(
|
||||
self.transform_input = Compose(
|
||||
replace_ImageToTensor(config.test_pipeline))
|
||||
self.model.cfg = config
|
||||
self.model.eval()
|
||||
@@ -56,7 +56,7 @@ class DetectionModel(TorchModel):
|
||||
|
||||
from mmcv.parallel import collate, scatter
|
||||
data = dict(img=image)
|
||||
data = self.test_pipeline(data)
|
||||
data = self.transform_input(data)
|
||||
data = collate([data], samples_per_gpu=1)
|
||||
data['img_metas'] = [
|
||||
img_metas.data[0] for img_metas in data['img_metas']
|
||||
|
||||
22
modelscope/models/cv/salient_detection/__init__.py
Normal file
22
modelscope/models/cv/salient_detection/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .salient_model import SalientDetection
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'salient_model': ['SalientDetection'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from .u2net import U2NET
|
||||
300
modelscope/models/cv/salient_detection/models/u2net.py
Normal file
300
modelscope/models/cv/salient_detection/models/u2net.py
Normal file
@@ -0,0 +1,300 @@
|
||||
# Implementation in this file is modifed from source code avaiable via https://github.com/xuebinqin/U-2-Net
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class REBNCONV(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, out_ch=3, dirate=1):
|
||||
super(REBNCONV, self).__init__()
|
||||
self.conv_s1 = nn.Conv2d(
|
||||
in_ch, out_ch, 3, padding=1 * dirate, dilation=1 * dirate)
|
||||
self.bn_s1 = nn.BatchNorm2d(out_ch)
|
||||
self.relu_s1 = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
hx = x
|
||||
xout = self.relu_s1(self.bn_s1(self.conv_s1(hx)))
|
||||
return xout
|
||||
|
||||
|
||||
def _upsample_like(src, tar):
|
||||
"""upsample tensor 'src' to have the same spatial size with tensor 'tar'."""
|
||||
src = F.upsample(src, size=tar.shape[2:], mode='bilinear')
|
||||
return src
|
||||
|
||||
|
||||
class RSU7(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU7, self).__init__()
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
hx5 = self.rebnconv5(hx)
|
||||
hx = self.pool5(hx5)
|
||||
hx6 = self.rebnconv6(hx)
|
||||
hx7 = self.rebnconv7(hx6)
|
||||
hx6d = self.rebnconv6d(torch.cat((hx7, hx6), 1))
|
||||
hx6dup = _upsample_like(hx6d, hx5)
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6dup, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
class RSU6(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU6, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx = self.pool4(hx4)
|
||||
hx5 = self.rebnconv5(hx)
|
||||
hx6 = self.rebnconv6(hx5)
|
||||
hx5d = self.rebnconv5d(torch.cat((hx6, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
class RSU5(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU5, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx = self.pool3(hx3)
|
||||
hx4 = self.rebnconv4(hx)
|
||||
hx5 = self.rebnconv5(hx4)
|
||||
hx4d = self.rebnconv4d(torch.cat((hx5, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
class RSU4(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU4, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx = self.pool1(hx1)
|
||||
hx2 = self.rebnconv2(hx)
|
||||
hx = self.pool2(hx2)
|
||||
hx3 = self.rebnconv3(hx)
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2dup, hx1), 1))
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
class RSU4F(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
|
||||
super(RSU4F, self).__init__()
|
||||
|
||||
self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)
|
||||
self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
|
||||
self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
|
||||
self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)
|
||||
self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)
|
||||
self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
|
||||
self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
|
||||
self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
hxin = self.rebnconvin(hx)
|
||||
hx1 = self.rebnconv1(hxin)
|
||||
hx2 = self.rebnconv2(hx1)
|
||||
hx3 = self.rebnconv3(hx2)
|
||||
hx4 = self.rebnconv4(hx3)
|
||||
hx3d = self.rebnconv3d(torch.cat((hx4, hx3), 1))
|
||||
hx2d = self.rebnconv2d(torch.cat((hx3d, hx2), 1))
|
||||
hx1d = self.rebnconv1d(torch.cat((hx2d, hx1), 1))
|
||||
return hx1d + hxin
|
||||
|
||||
|
||||
class U2NET(nn.Module):
|
||||
|
||||
def __init__(self, in_ch=3, out_ch=1):
|
||||
super(U2NET, self).__init__()
|
||||
|
||||
# encoder
|
||||
self.stage1 = RSU7(in_ch, 32, 64)
|
||||
self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.stage2 = RSU6(64, 32, 128)
|
||||
self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.stage3 = RSU5(128, 64, 256)
|
||||
self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.stage4 = RSU4(256, 128, 512)
|
||||
self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.stage5 = RSU4F(512, 256, 512)
|
||||
self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)
|
||||
self.stage6 = RSU4F(512, 256, 512)
|
||||
# decoder
|
||||
self.stage5d = RSU4F(1024, 256, 512)
|
||||
self.stage4d = RSU4(1024, 128, 256)
|
||||
self.stage3d = RSU5(512, 64, 128)
|
||||
self.stage2d = RSU6(256, 32, 64)
|
||||
self.stage1d = RSU7(128, 16, 64)
|
||||
self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
|
||||
self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
|
||||
self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
|
||||
self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
|
||||
self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)
|
||||
self.outconv = nn.Conv2d(6 * out_ch, out_ch, 1)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
hx = x
|
||||
hx1 = self.stage1(hx)
|
||||
hx = self.pool12(hx1)
|
||||
hx2 = self.stage2(hx)
|
||||
hx = self.pool23(hx2)
|
||||
hx3 = self.stage3(hx)
|
||||
hx = self.pool34(hx3)
|
||||
hx4 = self.stage4(hx)
|
||||
hx = self.pool45(hx4)
|
||||
hx5 = self.stage5(hx)
|
||||
hx = self.pool56(hx5)
|
||||
hx6 = self.stage6(hx)
|
||||
hx6up = _upsample_like(hx6, hx5)
|
||||
|
||||
hx5d = self.stage5d(torch.cat((hx6up, hx5), 1))
|
||||
hx5dup = _upsample_like(hx5d, hx4)
|
||||
hx4d = self.stage4d(torch.cat((hx5dup, hx4), 1))
|
||||
hx4dup = _upsample_like(hx4d, hx3)
|
||||
hx3d = self.stage3d(torch.cat((hx4dup, hx3), 1))
|
||||
hx3dup = _upsample_like(hx3d, hx2)
|
||||
hx2d = self.stage2d(torch.cat((hx3dup, hx2), 1))
|
||||
hx2dup = _upsample_like(hx2d, hx1)
|
||||
hx1d = self.stage1d(torch.cat((hx2dup, hx1), 1))
|
||||
d1 = self.side1(hx1d)
|
||||
d2 = self.side2(hx2d)
|
||||
d2 = _upsample_like(d2, d1)
|
||||
d3 = self.side3(hx3d)
|
||||
d3 = _upsample_like(d3, d1)
|
||||
d4 = self.side4(hx4d)
|
||||
d4 = _upsample_like(d4, d1)
|
||||
d5 = self.side5(hx5d)
|
||||
d5 = _upsample_like(d5, d1)
|
||||
d6 = self.side6(hx6)
|
||||
d6 = _upsample_like(d6, d1)
|
||||
d0 = self.outconv(torch.cat((d1, d2, d3, d4, d5, d6), 1))
|
||||
return torch.sigmoid(d0), torch.sigmoid(d1), torch.sigmoid(
|
||||
d2), torch.sigmoid(d3), torch.sigmoid(d4), torch.sigmoid(
|
||||
d5), torch.sigmoid(d6)
|
||||
63
modelscope/models/cv/salient_detection/salient_model.py
Normal file
63
modelscope/models/cv/salient_detection/salient_model.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import os.path as osp
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .models import U2NET
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.image_segmentation, module_name=Models.detection)
|
||||
class SalientDetection(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str, *args, **kwargs):
|
||||
"""str -- model file root."""
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
model_path = osp.join(model_dir, ModelFile.TORCH_MODEL_FILE)
|
||||
self.model = U2NET(3, 1)
|
||||
checkpoint = torch.load(model_path, map_location='cpu')
|
||||
self.transform_input = transforms.Compose([
|
||||
transforms.Resize((320, 320)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
])
|
||||
self.model.load_state_dict(checkpoint)
|
||||
self.model.eval()
|
||||
|
||||
def inference(self, data):
|
||||
"""data is tensor 3 * H * W ---> return tensor H * W ."""
|
||||
data = data.unsqueeze(0)
|
||||
if next(self.model.parameters()).is_cuda:
|
||||
data = data.to(
|
||||
torch.device([next(self.model.parameters()).device][0]))
|
||||
|
||||
with torch.no_grad():
|
||||
results = self.model(data)
|
||||
|
||||
if next(self.model.parameters()).is_cuda:
|
||||
return results[0][0, 0, :, :].cpu()
|
||||
return results[0][0, 0, :, :]
|
||||
|
||||
def preprocess(self, image):
|
||||
"""image is numpy."""
|
||||
data = self.transform_input(Image.fromarray(image))
|
||||
return data.float()
|
||||
|
||||
def postprocess(self, inputs):
|
||||
"""resize ."""
|
||||
data = inputs['data']
|
||||
w = inputs['img_w']
|
||||
h = inputs['img_h']
|
||||
data_norm = (data - torch.min(data)) / (
|
||||
torch.max(data) - torch.min(data))
|
||||
data_norm_np = (data_norm.numpy() * 255).astype('uint8')
|
||||
data_norm_rst = cv2.resize(data_norm_np, (w, h))
|
||||
|
||||
return data_norm_rst
|
||||
@@ -10,6 +10,7 @@ if TYPE_CHECKING:
|
||||
from .cmdssl_video_embedding_pipeline import CMDSSLVideoEmbeddingPipeline
|
||||
from .crowd_counting_pipeline import CrowdCountingPipeline
|
||||
from .image_detection_pipeline import ImageDetectionPipeline
|
||||
from .image_salient_detection_pipeline import ImageSalientDetectionPipeline
|
||||
from .face_detection_pipeline import FaceDetectionPipeline
|
||||
from .face_image_generation_pipeline import FaceImageGenerationPipeline
|
||||
from .face_recognition_pipeline import FaceRecognitionPipeline
|
||||
@@ -43,6 +44,7 @@ else:
|
||||
'cmdssl_video_embedding_pipeline': ['CMDSSLVideoEmbeddingPipeline'],
|
||||
'crowd_counting_pipeline': ['CrowdCountingPipeline'],
|
||||
'image_detection_pipeline': ['ImageDetectionPipeline'],
|
||||
'image_salient_detection_pipeline': ['ImageSalientDetectionPipeline'],
|
||||
'face_detection_pipeline': ['FaceDetectionPipeline'],
|
||||
'face_image_generation_pipeline': ['FaceImageGenerationPipeline'],
|
||||
'face_recognition_pipeline': ['FaceRecognitionPipeline'],
|
||||
|
||||
47
modelscope/pipelines/cv/image_salient_detection_pipeline.py
Normal file
47
modelscope/pipelines/cv/image_salient_detection_pipeline.py
Normal file
@@ -0,0 +1,47 @@
|
||||
from typing import Any, Dict
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.image_segmentation, module_name=Pipelines.salient_detection)
|
||||
class ImageSalientDetectionPipeline(Pipeline):
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, auto_collate=False, **kwargs)
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
|
||||
img = LoadImage.convert_to_ndarray(input)
|
||||
img_h, img_w, _ = img.shape
|
||||
img = self.model.preprocess(img)
|
||||
result = {'img': img, 'img_w': img_w, 'img_h': img_h}
|
||||
return result
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
outputs = self.model.inference(input['img'])
|
||||
result = {
|
||||
'data': outputs,
|
||||
'img_w': input['img_w'],
|
||||
'img_h': input['img_h']
|
||||
}
|
||||
return result
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
|
||||
data = self.model.postprocess(inputs)
|
||||
outputs = {
|
||||
OutputKeys.SCORES: None,
|
||||
OutputKeys.LABELS: None,
|
||||
OutputKeys.MASKS: data
|
||||
}
|
||||
return outputs
|
||||
24
tests/pipelines/test_salient_detection.py
Normal file
24
tests/pipelines/test_salient_detection.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import unittest
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class SalientDetectionTest(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_salient_detection(self):
|
||||
input_location = 'data/test/images/image_salient_detection.jpg'
|
||||
model_id = 'damo/cv_u2net_salient-detection'
|
||||
salient_detect = pipeline(Tasks.image_segmentation, model=model_id)
|
||||
result = salient_detect(input_location)
|
||||
import cv2
|
||||
# result[OutputKeys.MASKS] is salient map result,other keys are not used
|
||||
cv2.imwrite(input_location + '_salient.jpg', result[OutputKeys.MASKS])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user