add image skychange

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10947701
This commit is contained in:
hannah.yh
2022-12-21 17:40:46 +08:00
committed by wenmeng.zwm
parent 02edb1ab15
commit b36bb72869
20 changed files with 2388 additions and 0 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:260cd09f340b86007dd471cba742f82bae0fb5cfd4b8d87265bff5ad2c2c857f
size 652482

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:679c86d5a82c9c1c4866b5e16b98a2128a57e3ea60f77d56e5f0fe79ab7d746f
size 505993

View File

@@ -57,6 +57,7 @@ class Models(object):
face_emotion = 'face-emotion'
product_segmentation = 'product-segmentation'
image_body_reshaping = 'image-body-reshaping'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
video_object_segmentation = 'video-object-segmentation'
@@ -243,6 +244,7 @@ class Pipelines(object):
product_segmentation = 'product-segmentation'
image_body_reshaping = 'flow-based-body-reshaping'
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
video_object_segmentation = 'video-object-segmentation'
@@ -389,6 +391,7 @@ class Preprocessors(object):
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'
object_detection_scrfd = 'object-detection-scrfd'
image_sky_change_preprocessor = 'image-sky-change-preprocessor'
# nlp preprocessor
sen_sim_tokenizer = 'sen-sim-tokenizer'

View File

@@ -0,0 +1,22 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .skychange_model import ImageSkychange
    from .preprocessor import ImageSkyChangePreprocessor
else:
    # BUG FIX: this used to be two separate assignments to _import_structure,
    # so the second one overwrote the first and 'skychange_model' was never
    # registered for lazy import. Both submodules belong in one mapping.
    _import_structure = {
        'skychange_model': ['ImageSkychange'],
        'preprocessor': ['ImageSkyChangePreprocessor'],
    }

    import sys
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,245 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numbers
import pdb
from typing import Any, Dict, Union
import cv2
import json
import numpy as np
import torch
from torchvision import transforms
from modelscope.metainfo import Preprocessors
from modelscope.preprocessors import Preprocessor
from modelscope.preprocessors.builder import PREPROCESSORS
from modelscope.preprocessors.image import LoadImage
from modelscope.utils.constant import Fields, ModeKeys
_cv2_pad_to_str = {
'constant': cv2.BORDER_CONSTANT,
'edge': cv2.BORDER_REPLICATE,
'reflect': cv2.BORDER_REFLECT_101,
'symmetric': cv2.BORDER_REFLECT,
}
@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.image_sky_change_preprocessor)
class ImageSkyChangePreprocessor(Preprocessor):
    """Preprocessor for the image sky-change task.

    Loads a sky image and a scene image from file paths, converts them to
    checked 3-channel BGR uint8 arrays, and produces normalized CHW tensors
    resized/padded to the refine model's input size.
    """

    def __init__(self,
                 model_dir: str = None,
                 mode: str = ModeKeys.INFERENCE,
                 coarse_model_width=640,
                 coarse_model_height=640,
                 refine_model_width=1280,
                 refine_model_height=1280,
                 mean_vec=(0.485, 0.456, 0.406),
                 std_vec=(0.229, 0.224, 0.225),
                 *args,
                 **kwargs):
        """
        Args:
            model_dir (str): model directory to initialize some resource.
            mode: The mode for the preprocessor.
            coarse_model_width: required width of input tensor of coarse model.
            coarse_model_height: required height of input tensor of coarse model.
            refine_model_width: required width of input tensor of refine model.
            refine_model_height: required height of input tensor of refine model.
            mean_vec: mean of dataset (for transforms.Normalize), default is
                the ImageNet mean. Tuples are used instead of lists to avoid
                the mutable-default-argument pitfall; values are unchanged.
            std_vec: standard deviation of dataset (for transforms.Normalize),
                default is the ImageNet std.
        """
        super().__init__(mode)
        # Sizes are stored as [width, height] pairs.
        self.coarse_input_size = [coarse_model_width, coarse_model_height]
        self.refine_input_size = [refine_model_width, refine_model_height]
        self.normalize = transforms.Normalize(mean=mean_vec, std=std_vec)

    def __call__(self, data: Union[str, Dict], **kwargs) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (dict): data dict containing following info:
                sky_image, scene_image
                example:
                ```python
                {
                    "sky_image": "xxx.jpg" # sky_image path(str)
                    "scene_image": "xxx.jpg", # scene_image path(str)
                }
                ```
        Returns:
            Dict[str, Any]: the preprocessed data
                {
                    "sky_image": the preprocessed sky image(origin size)
                    "sky_image_refine": the preprocessed resized sky image
                    "scene_image": the preprocessed scene image(origin size)
                    "scene_image_refine": the preprocessed resized scene image
                    "img_metas": informations of preprocessed images, e.g. origin shape, pad information, resized shape.
                }
        """
        # Validate key presence for both images first, preserving the original
        # error order (a missing scene_image is reported even when sky_image
        # is also invalid).
        if 'sky_image' not in data.keys():
            raise Exception('sky_image not in input data')
        if 'scene_image' not in data.keys():
            raise Exception('scene_image not in input data')
        # The loading/validation logic was duplicated for the two images;
        # it is factored into _load_image.
        sky_image = self._load_image(data['sky_image'], 'sky_image',
                                     'sky image')
        scene_image = self._load_image(data['scene_image'], 'scene_image',
                                       'scene image')
        data = {}
        sky_image_refine, sky_img_metas = self.process_single_img(sky_image)
        scene_image_refine, scene_img_metas = self.process_single_img(
            scene_image)
        data['sky_image'] = sky_image
        data['sky_image_refine'] = sky_image_refine
        data['scene_image'] = scene_image
        data['scene_image_refine'] = scene_image_refine
        data['img_metas'] = {
            'sky_img_metas': sky_img_metas,
            'scene_img_metas': scene_img_metas,
            'input_size': {
                'coarse_input_size': self.coarse_input_size,
                'refine_input_size': self.refine_input_size
            }
        }
        return data

    def _load_image(self, value, key, desc):
        """Load one image path into a checked 3-channel BGR uint8 ndarray.

        Args:
            value: the raw entry from the input dict; must be a path string.
            key: dict key name, used in error messages (e.g. 'sky_image').
            desc: human-readable description used in error messages.

        Raises:
            Exception: with the same messages as the original inline code.
        """
        if not isinstance(value, str):
            raise Exception('{}(path of {}) is not valid'.format(key, desc))
        image = LoadImage.convert_to_ndarray(value)
        image = image.astype(np.uint8)  # RGB
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # BGR
        if image is None:
            raise Exception('{} is None'.format(key))
        return self.check_image(image)

    def process_single_img(self, img):
        """Resize/pad one BGR image to the refine input size and tensorize it.

        Returns:
            (tensor, img_metas): tensor is [1, 3, refine_h, refine_w];
            img_metas records origin shape, pad direction and refined shape.
        """
        img_metas = {}
        img_metas['ori_shape'] = img.shape[0:2]  # img: (origin_h, origin_w, 3)
        img, pad_direction = get_refine_input(img, self.refine_input_size)
        img = image_transform(
            img, self.normalize)  # torch.Size([3, refine_net_h, refine_net_w])
        img = img.unsqueeze(0)
        img_metas['pad_direction'] = pad_direction
        img_metas['refine_shape'] = img.shape[
            2:]  # torch.Size([1, 3, refine_net_h, refine_net_w])
        return img, img_metas

    def check_image(self, input_img):
        """Coerce grayscale/single-channel/RGBA inputs to 3-channel images.

        RGBA inputs are alpha-multiplied down to 3 channels (float result);
        3-channel inputs are returned unchanged.
        """
        whole_temp_shape = input_img.shape
        if len(whole_temp_shape) == 2:
            input_img = np.stack([input_img, input_img, input_img], axis=2)
        elif whole_temp_shape[2] == 1:
            input_img = np.concatenate([input_img, input_img, input_img],
                                       axis=2)
        elif whole_temp_shape[2] == 4:
            input_img = input_img[:, :,
                                  0:3] * 1.0 * input_img[:, :,
                                                         3:4] * 1.0 / 255.0
        return input_img
def get_refine_input(mat, refine_input_size):
    """Fit ``mat`` into the refine net input: aspect-preserving resize, then
    center padding up to ``refine_input_size`` (width, height).

    Returns:
        (image, pad_direction): the prepared image and the
        (left, top, right, bottom) padding that was applied.
    """
    resized = max_dim_match(mat, refine_input_size)
    padded, pad_direction = center_pad_image_withwh(resized, refine_input_size,
                                                    0)
    return padded, pad_direction
def max_dim_match(image, refine_model_size):
    """Resize ``image`` (keeping aspect ratio) so it fits inside the target.

    Args:
        image: HxWx3 ndarray.
        refine_model_size: (width, height) of the refine model input.

    Returns:
        The image resized so that its larger relative dimension matches the
        target; returned unchanged (same object) when it already matches.
    """
    height, width, _ = np.shape(image)
    target_w, target_h = refine_model_size
    if height != target_h or width != target_w:
        # Scale by the limiting dimension so the result fits inside the target.
        scale = min(float(target_w) / width, float(target_h) / height)
        out_h = int(height * scale + 0.5)
        out_w = int(width * scale + 0.5)
        image = cv2.resize(
            image, (out_w, out_h), interpolation=cv2.INTER_LINEAR)
    return image
def center_pad_image_withwh(image, crop_size, padvalue,
                            padding_mode='constant'):
    """Center-pad ``image`` up to ``crop_size`` (width, height).

    A no-op (same object returned) when the image is already at least as
    large as the target in both dimensions.

    Returns:
        (padded_image, pad_direction) where pad_direction is
        (left, top, right, bottom).
    """
    height, width = image.shape[0], image.shape[1]
    need_h = max(crop_size[1] - height, 0)
    need_w = max(crop_size[0] - width, 0)
    if need_h == 0 and need_w == 0:
        return image, (0, 0, 0, 0)
    # Split the deficit in two, rounding the left/top half up.
    left = int(need_w / 2 + 0.5)
    top = int(need_h / 2 + 0.5)
    pad_direction = (left, top, need_w - left, need_h - top)
    padded = pad(image, pad_direction, padvalue, padding_mode=padding_mode)
    return padded, pad_direction
def pad(img, padding, fill=0, padding_mode='constant'):
    """Pad a numpy image via cv2.copyMakeBorder.

    Args:
        img: HxW or HxWxC ndarray.
        padding: a single number (same padding on all four sides) or a
            4-sequence (left, top, right, bottom).
        fill: border value, used by the 'constant' mode.
        padding_mode: one of the keys of _cv2_pad_to_str.

    Returns:
        The padded ndarray. A single-channel HxWx1 input keeps its trailing
        axis (cv2 drops it, so it is restored here).

    Raises:
        TypeError: if img is not a 2-D/3-D ndarray or padding is malformed.
    """
    if not is_numpy_image(img):
        raise TypeError('img should be numpy ndarray. Got {}'.format(
            type(img)))
    # BUG FIX: a scalar padding used to pass the isinstance check and then
    # crash on len(padding); expand it to four identical sides instead.
    if isinstance(padding, numbers.Number):
        padding = (padding, padding, padding, padding)
    if not isinstance(padding, (tuple, list)) or len(padding) != 4:
        raise TypeError('Got inappropriate padding arg')
    pad_left, pad_top, pad_right, pad_bottom = padding
    # The three original branches only differed in whether the trailing
    # channel axis needed restoring; a single call plus a fix-up suffices.
    out = cv2.copyMakeBorder(
        img,
        top=pad_top,
        bottom=pad_bottom,
        left=pad_left,
        right=pad_right,
        borderType=_cv2_pad_to_str[padding_mode],
        value=fill,
    )
    if img.ndim == 3 and img.shape[2] == 1:
        out = out[:, :, np.newaxis]
    return out
def image_transform(img, normalize):
    """Convert a BGR uint8 HWC image to a normalized float32 CHW tensor.

    Args:
        img: HxWx3 BGR ndarray.
        normalize: callable applied to the final tensor
            (e.g. torchvision transforms.Normalize).
    """
    rgb = img[:, :, ::-1]  # BGR -> RGB
    chw = np.ascontiguousarray(rgb.transpose(2, 0, 1))  # h,w,c -> c,h,w
    scaled = chw.astype(np.float32) / 255
    return normalize(torch.from_numpy(scaled))
def is_numpy_image(img):
    """Return True iff ``img`` is a 2-D (grayscale) or 3-D (color) ndarray."""
    if not isinstance(img, np.ndarray):
        return False
    return img.ndim in (2, 3)

View File

@@ -0,0 +1,118 @@
# The implementation is adopted from ASPP, made publicly available under the MIT License
# at https://github.com/jfzhang95/pytorch-deeplab-xception
import torch
import torch.nn.functional as F
from torch import nn
BatchNorm2d = nn.BatchNorm2d
class ASPPModule(nn.Module):
    """One atrous-convolution branch of ASPP: conv -> BN -> ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation,
                 BatchNorm):
        super(ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()
        self._init_weight()

    def forward(self, x):
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        """Kaiming-init convs; BN weights to 1 and biases to 0.

        The original checked both the module-level alias and nn.BatchNorm2d;
        they are the same class, so a single isinstance check is equivalent.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
# this aspp is official version
# copy from :https://github.com/jfzhang95/pytorch-deeplab-xception
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head (DeepLab style).

    Four parallel atrous branches plus an image-level pooling branch are
    concatenated and projected back to ``outplanes`` channels.
    """

    def __init__(self, inplanes, outplanes, dilations, drop_rate=0.1):
        super(ASPP, self).__init__()
        # 1x1 branch, then three 3x3 branches with increasing dilation.
        self.aspp1 = ASPPModule(
            inplanes,
            outplanes,
            1,
            padding=0,
            dilation=dilations[0],
            BatchNorm=BatchNorm2d)
        self.aspp2 = ASPPModule(
            inplanes,
            outplanes,
            3,
            padding=dilations[1],
            dilation=dilations[1],
            BatchNorm=BatchNorm2d)
        self.aspp3 = ASPPModule(
            inplanes,
            outplanes,
            3,
            padding=dilations[2],
            dilation=dilations[2],
            BatchNorm=BatchNorm2d)
        self.aspp4 = ASPPModule(
            inplanes,
            outplanes,
            3,
            padding=dilations[3],
            dilation=dilations[3],
            BatchNorm=BatchNorm2d)
        # Image-level context branch.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False),
            BatchNorm2d(outplanes), nn.ReLU())
        self.conv1 = nn.Conv2d(outplanes * 5, outplanes, 1, bias=False)
        self.bn1 = BatchNorm2d(outplanes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_rate)
        self._init_weight()

    def forward(self, x):
        # Run the four atrous branches (all spatial-size preserving).
        branches = [self.aspp1(x), self.aspp2(x), self.aspp3(x), self.aspp4(x)]
        # Pooling branch is upsampled back to the branch resolution.
        pooled = self.global_avg_pool(x)
        pooled = F.interpolate(
            pooled,
            size=branches[3].size()[2:],
            mode='bilinear',
            align_corners=True)
        branches.append(pooled)
        merged = torch.cat(branches, dim=1)
        merged = self.relu(self.bn1(self.conv1(merged)))
        return self.dropout(merged)

    def _init_weight(self):
        # Kaiming-init convs; BN weights to 1 and biases to 0. The original
        # tested both the module alias and nn.BatchNorm2d (same class).
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()

View File

@@ -0,0 +1,620 @@
# The implementation is adopted from HRNet, made publicly available under the MIT License
# at https://github.com/HRNet/HRNet-Semantic-Segmentation
import logging
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
BatchNorm2d = nn.BatchNorm2d
BN_MOMENTUM = 0.1
logger = logging.getLogger(__name__)
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 Conv2d with 1-pixel padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=1,
        stride=stride,
        bias=False)
class BasicBlock(nn.Module):
    """ResNet basic residual block: two 3x3 convs plus an identity shortcut."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut is the (optionally downsampled) input.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce, 3x3, 1x1 expand (x4) + shortcut."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        # Shortcut is the (optionally downsampled) input.
        shortcut = x if self.downsample is None else self.downsample(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))
        out += shortcut
        return self.relu(out)
class HighResolutionModule(nn.Module):
    """One HRNet stage module: parallel branches at different resolutions
    followed by an all-to-all multi-resolution fusion.
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 fuse_method,
                 multi_scale_output=True):
        """
        Args:
            num_branches: number of parallel resolution branches.
            blocks: residual block class (BasicBlock or Bottleneck).
            num_blocks: per-branch block counts.
            num_inchannels: per-branch input channel counts. NOTE: mutated in
                place by _make_one_branch to account for block expansion.
            num_channels: per-branch base channel counts.
            fuse_method: fusion strategy name (code always sums).
            multi_scale_output: if False, fuse layers only produce the
                highest-resolution output.
        """
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
                             num_channels)
        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches
        self.multi_scale_output = multi_scale_output
        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels,
                        num_channels):
        # Validate that all per-branch config lists have num_branches entries;
        # log and raise on mismatch.
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)
        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)
        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch: a sequence of residual blocks, with a projection
        downsample on the first block when channels/stride change."""
        downsample = None
        if stride != 1 or self.num_inchannels[
                branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(
            block(self.num_inchannels[branch_index],
                  num_channels[branch_index], stride, downsample))
        # Record the expanded channel count for subsequent blocks/fuse layers.
        self.num_inchannels[
            branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index],
                      num_channels[branch_index]))
        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        branches = []
        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))
        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build fuse_layers[i][j], transforming branch j's output to branch
        i's resolution/channels: 1x1 conv + upsample (in forward) for j > i,
        None (identity) for j == i, strided 3x3 conv chain for j < i."""
        if self.num_branches == 1:
            return None
        num_branches = self.num_branches
        num_inchannels = self.num_inchannels  # tuple
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution source: channel projection only; the
                    # spatial upsampling happens in forward().
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False),
                            BatchNorm2d(
                                num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution source: (i - j) stride-2 3x3 convs;
                    # only the last one changes the channel count, and all but
                    # the last are followed by ReLU.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM),
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))
        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        # Output channel counts per branch (after block expansion).
        return self.num_inchannels

    def forward(self, x):
        """Run each branch on its input (x is a list, modified in place),
        then sum all branches at each output resolution."""
        if self.num_branches == 1:
            return [self.branches[0](x[0])]
        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])
        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    # Project channels, then bilinearly upsample to branch i.
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=(height_output, width_output),
                        mode='bilinear',
                        align_corners=True)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))
        return x_fuse
# HRNet stage configurations, mirroring the layout of the official HRNet
# config files: per stage, the number of modules and branches, the residual
# block type, per-branch block counts and channel widths, and the fuse method.
model_w18v1 = {
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        # NOTE(review): (1) is the int 1, not a 1-tuple (same for
        # NUM_CHANNELS below). STAGE1 values are consumed as scalars by
        # HrnetBackBone.__init__, so this works, but it differs from the
        # tuple style of the other stages — confirm intentional.
        'NUM_BLOCKS': (1),
        'NUM_CHANNELS': (32),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2),
        'NUM_CHANNELS': (16, 32),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE3': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2),
        'NUM_CHANNELS': (16, 32, 64),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE4': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2, 2),
        'NUM_CHANNELS': (16, 32, 64, 128),
        'FUSE_METHOD': 'SUM'
    },
    'FINAL_CONV_KERNEL': 1
}
model_w18v2 = {
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        # NOTE(review): (2)/(64) are ints, not tuples — see model_w18v1.
        'NUM_BLOCKS': (2),
        'NUM_CHANNELS': (64),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2),
        'NUM_CHANNELS': (18, 36),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE3': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2),
        'NUM_CHANNELS': (18, 36, 72),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE4': {
        'NUM_MODULES': 2,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2, 2),
        'NUM_CHANNELS': (18, 36, 72, 144),
        'FUSE_METHOD': 'SUM'
    },
    'FINAL_CONV_KERNEL': 1
}
model_w48 = {
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        # NOTE(review): (4)/(64) are ints, not tuples — see model_w18v1.
        'NUM_BLOCKS': (4),
        'NUM_CHANNELS': (64),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (4, 4),
        'NUM_CHANNELS': (48, 96),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (4, 4, 4),
        'NUM_CHANNELS': (48, 96, 192),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (4, 4, 4, 4),
        'NUM_CHANNELS': (48, 96, 192, 384),
        'FUSE_METHOD': 'SUM'
    },
    'FINAL_CONV_KERNEL': 1
}
# Registry of named HRNet variants and of block-type names to block classes.
model_version_dict = {}
model_version_dict['w48'] = model_w48
model_version_dict['w18v1'] = model_w18v1
model_version_dict['w18v2'] = model_w18v2
blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
class HrnetBackBone(nn.Module):
    """HRNet backbone: a 2-conv stem followed by four multi-resolution stages.

    _backbone_forward upsamples every branch to the highest resolution and
    concatenates them channel-wise; backbone_last_inp_channels records the
    concatenated channel count.
    """

    def __init__(self, **kwargs):
        """
        Args:
            **kwargs: must contain 'version', one of the keys of
                model_version_dict ('w18v1', 'w18v2', 'w48').
        """
        super(HrnetBackBone, self).__init__()
        assert 'version' in kwargs, 'hrnet not exist model version'
        extra = model_version_dict[kwargs['version']]
        # stem net: two stride-2 3x3 convs -> 1/4 input resolution, 64 channels
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        # Stage 1: a single residual layer (config values are scalars here).
        self.stage1_cfg = extra['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS']
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels
        # Stages 2-4: transition to the new branch set, then a fused stage.
        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([stage1_out_channel],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)
        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)
        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)
        # BUG FIX: np.int is a deprecated alias removed in NumPy >= 1.24
        # (constructing the model raised AttributeError there); the builtin
        # int() is the documented replacement and behaves identically here.
        self.backbone_last_inp_channels = int(np.sum(pre_stage_channels))

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Build per-branch transition modules between consecutive stages:
        a 3x3 conv when channel counts differ, None (identity) when they
        match, and a chain of stride-2 convs for newly created branches."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)
        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False),
                            BatchNorm2d(
                                num_channels_cur_layer[i],
                                momentum=BN_MOMENTUM), nn.ReLU(inplace=True)))
                else:
                    transition_layers.append(None)
            else:
                # New branch: downsample from the last previous branch.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[
                        i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False),
                            BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))
        return nn.ModuleList(transition_layers)

    def _make_layer(self, block, inplanes, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; the first one gets a projection
        downsample when stride/channels change."""
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
            )
        layers = []
        layers.append(block(inplanes, planes, stride, downsample))
        inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(inplanes, planes))
        return nn.Sequential(*layers)

    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multi_scale_output=True):
        """Build one stage as a sequence of HighResolutionModules; returns
        the stage and its per-branch output channel counts."""
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']
        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches, block, num_blocks,
                                     num_inchannels, num_channels, fuse_method,
                                     reset_multi_scale_output))
            num_inchannels = modules[-1].get_num_inchannels()
        return nn.Sequential(*modules), num_inchannels

    def _backbone_forward(self, x):
        """Run stem + stages 1-4, then upsample every branch output to the
        highest resolution and concatenate along channels."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.layer1(x)
        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)
        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)
        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)
        # Upsampling: bring branches 1-3 up to branch 0's resolution.
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(
            x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x2 = F.interpolate(
            x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x3 = F.interpolate(
            x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x = torch.cat([x[0], x1, x2, x3], 1)
        return x

    def init_weights(self, url, cache_file=''):
        """Load pretrained weights and copy the intersecting parameters.

        NOTE(review): `load_state` is neither defined nor imported in this
        module — confirm it is provided elsewhere before calling this method.
        """
        pretrained_dict = load_state(url, model_dir=cache_file)
        model_dict = self.state_dict()
        model_len = len(model_dict)
        pretrain_len = len(pretrained_dict)
        common_dict = {}
        valid_layer_num = 0
        for k, v in pretrained_dict.items():
            if k in model_dict:
                common_dict[k] = v
                valid_layer_num += 1
        print('*' * 50)
        # BUG FIX: 'Commmon' typo in the printed summary.
        print('Model Param Num:{} Pretrained Param Num:{} '
              'Common Num:{}'.format(model_len, pretrain_len,
                                     valid_layer_num))
        print('-' * 50)
        print('Model Extra Param Names:\n\t{}'.format(
            set(model_dict) - set(pretrained_dict)))
        print('-' * 50)
        print('Pretrained Extra Param Names:\n\t{}'.format(
            set(pretrained_dict) - set(model_dict)))
        print('*' * 50)
        model_dict.update(common_dict)
        self.load_state_dict(model_dict)

View File

@@ -0,0 +1,510 @@
# Part of the implementation is borrowed and modified from HRNet,
# publicly available under the MIT License at https://github.com/HRNet/HRNet-Semantic-Segmentation
from __future__ import absolute_import, division, print_function
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .BlockModules import ASPP
from .hrnet_backnone import BatchNorm2d, HrnetBackBone, blocks_dict
ALIGN_CORNERS = True
BN_MOMENTUM = 0.1
class ModuleHelper:
    """Static factory helpers used by the OCR attention blocks."""

    @staticmethod
    def BNReLU(num_features, bn_type=None, **kwargs):
        # bn_type is accepted for API compatibility but not used here.
        return nn.Sequential(BatchNorm2d(num_features, **kwargs), nn.ReLU())

    @staticmethod
    def BatchNorm2d(*args, **kwargs):
        # NOTE(review): returns the BatchNorm2d *class* itself, ignoring all
        # arguments — callers must instantiate the returned class. Confirm
        # this is the intended contract.
        return BatchNorm2d
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 Conv2d with 1-pixel padding."""
    return nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=1,
        stride=stride,
        bias=False)
class SpatialGatherModule(nn.Module):
    """
    Aggregate the context features according to the initial
    predicted probability distribution.
    Employ the soft-weighted method to aggregate the context.
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGatherModule, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        # probs: (B, K, H, W) class scores; feats: (B, C, H, W) features.
        batch_size = probs.size(0)
        num_classes = probs.size(1)
        probs = probs.view(batch_size, num_classes, -1)  # batch x k x hw
        pixels = feats.view(batch_size, feats.size(1), -1)
        pixels = pixels.permute(0, 2, 1)  # batch x hw x c
        # Softmax over spatial positions yields per-class pixel weights.
        weights = F.softmax(self.scale * probs, dim=2)
        context = torch.matmul(weights, pixels)  # batch x k x c
        return context.permute(0, 2, 1).unsqueeze(3)  # batch x c x k x 1
class ObjectAttentionBlock(nn.Module):
    '''
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : choose the scale to downsample the input feature maps (save memory cost)
        bn_type : specify the bn type
    Return:
        N X C X H X W
    '''

    def __init__(self, in_channels, key_channels, scale=1, bn_type=None):
        super(ObjectAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        # Optional spatial downsampling of the pixel features (forward only
        # applies it when scale > 1).
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        # Query transform for pixel features: two 1x1 conv + BN + ReLU stages.
        self.f_pixel = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        # Key transform for the object (proxy) features.
        self.f_object = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        # Value transform for the object (proxy) features.
        self.f_down = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        # Projects the attended context back to in_channels.
        self.f_up = nn.Sequential(
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type),
        )

    def forward(self, x, proxy):
        """Scaled dot-product attention from pixel queries (x) to object
        keys/values (proxy); output matches x's spatial size."""
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            x = self.pool(x)
        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)
        # Scaled dot-product similarity, softmax over the object axis.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)
        # add bg context ...
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            # Undo the pooling so the output matches the input resolution.
            context = F.interpolate(
                input=context,
                size=(h, w),
                mode='bilinear',
                align_corners=ALIGN_CORNERS)
        return context
class ObjectAttentionBlock2D(ObjectAttentionBlock):
    """2-D object attention block; behaviour is identical to the parent."""

    def __init__(self, in_channels, key_channels, scale=1, bn_type=None):
        super(ObjectAttentionBlock2D, self).__init__(
            in_channels, key_channels, scale, bn_type=bn_type)
class SpatialOCRModule(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation for each pixel.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 bn_type=None):
        super(SpatialOCRModule, self).__init__()
        self.object_context_block = ObjectAttentionBlock2D(
            in_channels, key_channels, scale, bn_type)
        # The attended context is concatenated with the pixel features, so
        # the fusion conv sees twice the input channels.
        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(
                2 * in_channels,
                out_channels,
                kernel_size=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
            nn.Dropout2d(dropout))

    def forward(self, feats, proxy_feats):
        attended = self.object_context_block(feats, proxy_feats)
        fused = torch.cat([attended, feats], 1)
        return self.conv_bn_dropout(fused)
class HrnetSuperAndOcr(HrnetBackBone):
    """HRNet backbone with a super-resolution stage and an OCR head.

    Two layouts are selected via ``kwargs['architecture']``:

    * ``'hrnet_super_ocr'``: super-resolution stage first, OCR head after.
    * anything else: OCR head first, super-resolution stage after
      (``is_ocr_first``).

    Required kwargs: ``architecture``, ``class_num``, ``ocr``
    (``mid_channels``, ``key_channels``, ``dropout_rate``, ``scale``) and
    ``super_param``.  Optional: ``tail_param`` (extra residual tail on the
    concatenated HRNet features) and ``aspp``.
    """

    def __init__(self, **kwargs):
        super(HrnetSuperAndOcr, self).__init__(**kwargs)
        if 'architecture' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist architecture param!')
        self.architecture = kwargs['architecture']
        if 'class_num' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist class_num param!')
        self.class_num = kwargs['class_num']
        if 'ocr' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist ocr param!')
        ocr_mid_channels = kwargs['ocr']['mid_channels']
        ocr_key_channels = kwargs['ocr']['key_channels']
        dropout_rate = kwargs['ocr']['dropout_rate']
        scale = kwargs['ocr']['scale']
        if 'super_param' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist super_param param!')
        self.super_dict = kwargs['super_param']
        self.is_export_onnx = False
        self.is_export_full_onnx = False
        # Idiom fix: `True if x else False` -> the boolean expression itself.
        self.is_contain_tail = 'tail_param' in kwargs
        if self.is_contain_tail:
            self.stage_tail_dict = kwargs['tail_param']
            num_channels = self.stage_tail_dict['NUM_CHANNELS'][0]
            block = blocks_dict[self.stage_tail_dict['BLOCK']]
            num_blocks = self.stage_tail_dict['NUM_BLOCKS'][0]
            self.stage_tail = self._make_layer(block,
                                               self.backbone_last_inp_channels,
                                               num_channels, num_blocks)
            last_inp_channels = block.expansion * num_channels
        else:
            last_inp_channels = self.backbone_last_inp_channels
        self.is_contain_aspp = 'aspp' in kwargs
        if self.architecture == 'hrnet_super_ocr':
            # Super-resolution stage first, OCR afterwards.
            self.is_ocr_first = False
            num_channels = [64, last_inp_channels]
            self.stage_super, super_stage_channels = self._make_stage(
                self.super_dict, num_channels)
            # Bug fix: np.int was deprecated in NumPy 1.20 and removed in
            # 1.24; the builtin int is the documented replacement.
            last_inp_channels = int(np.sum(super_stage_channels))
            if self.is_contain_aspp:
                last_inp_channels = self._build_aspp(kwargs['aspp'],
                                                     last_inp_channels)
            self._build_ocr_modules(last_inp_channels, ocr_mid_channels,
                                    ocr_key_channels, scale, dropout_rate)
            self.cls_head = self._build_head(ocr_mid_channels)
        else:
            # OCR first, super-resolution stage afterwards.
            self.is_ocr_first = True
            if self.is_contain_aspp:
                last_inp_channels = self._build_aspp(kwargs['aspp'],
                                                     last_inp_channels)
            self._build_ocr_modules(last_inp_channels, ocr_mid_channels,
                                    ocr_key_channels, scale, dropout_rate)
            num_channels = [64, ocr_mid_channels]
            self.stage_super, super_stage_channels = self._make_stage(
                self.super_dict, num_channels)
            last_inp_channels = int(np.sum(super_stage_channels))
            self.cls_head = self._build_head(last_inp_channels)

    def _build_aspp(self, aspp_param, inplanes):
        """Create self.aspp_layer; return its output channel count."""
        self.aspp_layer = ASPP(
            inplanes=inplanes,
            outplanes=aspp_param['outplanes'],
            dilations=aspp_param['dilations'],
            drop_rate=aspp_param['drop_rate'])
        return aspp_param['outplanes']

    def _build_head(self, channels):
        """1x1 conv-BN-ReLU followed by a 1x1 conv to class_num logits."""
        return nn.Sequential(
            nn.Conv2d(
                channels, channels, kernel_size=1, stride=1,
                padding=0), BatchNorm2d(channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(
                channels,
                self.class_num,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True))

    def _build_ocr_modules(self, in_channels, ocr_mid_channels,
                           ocr_key_channels, scale, dropout_rate):
        """Create aux head and the OCR gather/distribute heads.

        The same construction was previously duplicated verbatim in both
        architecture branches; registration order is preserved
        (aux_head, conv3x3_ocr, ocr_gather_head, ocr_distri_head).
        """
        # Auxiliary head: coarse segmentation used as the soft object
        # regions for the OCR gather step.
        self.aux_head = self._build_head(in_channels)
        self.conv3x3_ocr = nn.Sequential(
            nn.Conv2d(
                in_channels,
                ocr_mid_channels,
                kernel_size=3,
                stride=1,
                padding=1),
            BatchNorm2d(ocr_mid_channels),
            nn.ReLU(inplace=True),
        )
        self.ocr_gather_head = SpatialGatherModule(self.class_num)
        self.ocr_distri_head = SpatialOCRModule(
            in_channels=ocr_mid_channels,
            key_channels=ocr_key_channels,
            out_channels=ocr_mid_channels,
            scale=scale,
            dropout=dropout_rate,
        )

    def forward(self, x):
        """Run backbone + super stage + OCR head.

        Returns:
            [out_aux, out] (auxiliary and final logits) in normal mode, or a
            single float argmax class map when exporting to ONNX.
        """
        if self.is_export_onnx:
            # ONNX export feeds NHWC images; convert to NCHW first.
            x = x.permute(0, 3, 1, 2)
            raw_h, raw_w = x.size(2), x.size(3)
        if self.is_export_full_onnx:
            raw_h, raw_w = x.size(2), x.size(3)
        x = self.conv1(x)
        x = self.bn1(x)  # 5, 64, 320, 320
        x_stem = self.relu(x)
        x = self.conv2(x_stem)
        x = self.bn2(x)
        x = self.relu(x)  # 5, 64, 160, 160
        x = self.layer1(x)  # 5, 256=64*4, 160, 160
        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)  # [[5, 18, 160, 160],[5, 36, 80, 80]]
        y_list = self.stage2(x_list)
        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)
        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)
        # Upsampling: bring every branch to the highest resolution, concat.
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(
            x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x2 = F.interpolate(
            x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x3 = F.interpolate(
            x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        feats = torch.cat([x[0], x1, x2, x3], 1)
        if self.is_contain_tail:
            feats = self.stage_tail(feats)
        if self.is_ocr_first:
            if self.is_contain_aspp:
                feats = self.aspp_layer(feats)
            # compute contrast feature
            out_aux = self.aux_head(feats)
            feats = self.conv3x3_ocr(feats)
            context = self.ocr_gather_head(feats, out_aux)
            feats = self.ocr_distri_head(feats, context)
            feats = [x_stem, feats]  # 320*320 2X
            x_super = self.stage_super(feats)
            xsuper_h, xsuper_w = x_super[0].size(2), x_super[0].size(3)
            x_super1 = F.interpolate(
                x_super[1],
                size=(xsuper_h, xsuper_w),
                mode='bilinear',
                align_corners=True)
            x_super = torch.cat([x_super[0], x_super1], 1)
            out = self.cls_head(x_super)
        else:
            x_super = [x_stem, feats]  # 320*320 2X, 160*160 4X
            x_super = self.stage_super(x_super)
            xsuper_h, xsuper_w = x_super[0].size(2), x_super[0].size(3)
            x_super1 = F.interpolate(
                x_super[1],
                size=(xsuper_h, xsuper_w),
                mode='bilinear',
                align_corners=True)
            x_super = torch.cat([x_super[0], x_super1], 1)
            if self.is_contain_aspp:
                x_super = self.aspp_layer(x_super)
            out_aux = self.aux_head(x_super)
            feats = self.conv3x3_ocr(x_super)
            context = self.ocr_gather_head(feats, out_aux)
            feats = self.ocr_distri_head(feats, context)
            out = self.cls_head(feats)
        if self.is_export_onnx or self.is_export_full_onnx:
            # Export path: per-pixel argmax class map at the input size.
            x_class = F.interpolate(
                out, size=(raw_h, raw_w), mode='bilinear', align_corners=True)
            x_class = torch.softmax(x_class, dim=1)
            _, x_class = torch.max(x_class, dim=1, keepdim=True)
            x_class = x_class.float()
            return x_class
        else:
            out_aux_seg = [
                out_aux, out
            ]  # out_aux: 5, 2, 160, 160(HRNet origin res); out: 5, 2, 320, 320(HRNet res+tail+aspp+ocr)
            return out_aux_seg
def get_seg_model(cfg, **kwargs):
    """Build an HrnetSuperAndOcr model from `cfg` and load pretrained weights.

    NOTE(review): HrnetSuperAndOcr.__init__ accepts keyword arguments only,
    so passing `cfg` positionally looks like it would raise TypeError —
    presumably `HrnetSuperAndOcr(**cfg, **kwargs)` (or merging cfg into
    kwargs) was intended; confirm against callers.
    """
    model = HrnetSuperAndOcr(cfg, **kwargs)
    # Load pretrained weights referenced by the config.
    model.init_weights(cfg.MODEL.PRETRAINED)
    return model

View File

@@ -0,0 +1,229 @@
# Copyright 2021-2022 The Alibaba Vision Team Authors. All rights reserved.
import torch
import torch.nn as nn
from .BlockModules import ASPP
class Conv2DBatchNormRelu(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and/or inplace ReLU.

    The constructor arguments mirror nn.Conv2d; `with_bn` and `with_relu`
    toggle the trailing normalization / activation layers.
    """

    def __init__(self,
                 in_channels,
                 n_filters,
                 k_size,
                 stride,
                 padding,
                 bias=True,
                 dilation=1,
                 with_bn=True,
                 with_relu=True):
        super(Conv2DBatchNormRelu, self).__init__()
        conv = nn.Conv2d(
            int(in_channels),
            int(n_filters),
            kernel_size=k_size,
            padding=padding,
            stride=stride,
            bias=bias,
            dilation=dilation,
        )
        # Assemble conv(-bn)(-relu) in a list instead of a 4-way branch;
        # the resulting nn.Sequential (and its state_dict keys) is identical.
        layers = [conv]
        if with_bn:
            layers.append(nn.BatchNorm2d(int(n_filters)))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))
        self.cbr_unit = nn.Sequential(*layers)

    def forward(self, inputs):
        """Apply the conv(-bn)(-relu) stack."""
        return self.cbr_unit(inputs)
class SegnetDown2(nn.Module):
    """Two conv-bn-relu layers followed by 2x2 max-pooling with indices."""

    def __init__(self, in_size, out_size):
        super(SegnetDown2, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(
            in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        # return_indices=True so the decoder can unpool at the same positions.
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        """Return (pooled features, pooling indices, pre-pool shape)."""
        features = self.conv2(self.conv1(inputs))
        pre_pool_shape = features.size()
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, pre_pool_shape
class SegnetDown3(nn.Module):
    """Three conv-bn-relu layers followed by 2x2 max-pooling with indices."""

    def __init__(self, in_size, out_size):
        super(SegnetDown3, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(
            in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.conv3 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        # return_indices=True so the decoder can unpool at the same positions.
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        """Return (pooled features, pooling indices, pre-pool shape)."""
        features = self.conv3(self.conv2(self.conv1(inputs)))
        pre_pool_shape = features.size()
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, pre_pool_shape
class SegnetUp1(nn.Module):
    """Max-unpool followed by a single 5x5 conv-bn (no ReLU)."""

    def __init__(self, in_size, out_size):
        super(SegnetUp1, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv = Conv2DBatchNormRelu(
            in_size, out_size, k_size=5, stride=1, padding=2, with_relu=False)

    def forward(self, inputs, indices, output_shape):
        """Unpool `inputs` to `output_shape` using `indices`, then convolve."""
        upsampled = self.unpool(
            input=inputs, indices=indices, output_size=output_shape)
        return self.conv(upsampled)
class Unet(nn.Module):
    """SegNet-style encoder/decoder used as the sky-matting refinement net.

    The encoder mirrors VGG16's conv blocks (SegnetDown2/Down3) and the
    decoder unpools with the stored max-pool indices (SegnetUp1).  An
    optional ASPP branch taps the 128-channel encoder features and is fused
    back in at the matching decoder level.  Output is a sigmoid matte.

    Args:
        n_classes: output channels of the last decoder layer.
        in_channels: input channels (default 4: RGB + coarse mask).
        is_unpooling: stored but not otherwise read in this class.
        pretrain: if True, copy VGG16 conv weights into the encoder.
        **kwargs: may contain an 'aspp' dict (outplanes/dilations/drop_rate).
    """

    def __init__(self,
                 n_classes=2,
                 in_channels=4,
                 is_unpooling=True,
                 pretrain=True,
                 **kwargs):
        super(Unet, self).__init__()
        print('Load Unet')
        self.in_channels = in_channels
        self.is_unpooling = is_unpooling
        self.pretrain = pretrain
        self.is_contain_aspp = True if 'aspp' in kwargs else False
        if self.is_contain_aspp:
            aspp_param = kwargs['aspp']
            # ASPP reads the 128-channel output of down2.
            self.aspp_layer = ASPP(
                inplanes=128,
                outplanes=aspp_param['outplanes'],
                dilations=aspp_param['dilations'],
                drop_rate=aspp_param['drop_rate'])
            self.aspp_channels = aspp_param['outplanes']
        self.down1 = SegnetDown2(self.in_channels, 64)
        self.down2 = SegnetDown2(64, 128)
        self.down3 = SegnetDown3(128, 256)
        self.down4 = SegnetDown3(256, 512)
        self.down5 = SegnetDown3(512, 512)
        self.up5 = SegnetUp1(512, 512)
        self.up4 = SegnetUp1(512, 256)
        self.up3 = SegnetUp1(256, 128)
        if self.is_contain_aspp:
            # Fuse decoder features with the ASPP branch back to 128 channels.
            self.conv_1x1_aspp = Conv2DBatchNormRelu(
                128 + self.aspp_channels,
                128,
                k_size=1,
                stride=1,
                padding=0,
                with_relu=False)
        self.up2 = SegnetUp1(128, 64)
        self.up1 = SegnetUp1(64, n_classes)
        self.sigmoid = nn.Sigmoid()
        if self.pretrain:
            import torchvision.models as models
            # NOTE(review): models.vgg16() without a pretrained/weights
            # argument constructs *randomly initialized* weights, so
            # init_vgg16_params copies random values — confirm pretrained
            # VGG weights were intended here.
            vgg16 = models.vgg16()
            self.init_vgg16_params(vgg16)

    def forward(self, inputs):  # [1, 4, 1346, 1152] [2, 4, 1280, 1280]
        """Return a sigmoid matte for `inputs` (shape [N, in_channels, H, W])."""
        # inputs: [N, 4, 320, 320]
        # outputs, indices, unpooled_shape
        down1, indices_1, unpool_shape1 = self.down1(
            inputs)  # [1, 64, 673, 576] [2, 64, 640, 640]
        down2, indices_2, unpool_shape2 = self.down2(
            down1)  # [1, 128, 336, 288] [2, 128, 320, 320]
        down3, indices_3, unpool_shape3 = self.down3(
            down2)  # [1, 256, 168, 144] [2, 256, 160, 160]
        torch.cuda.empty_cache()
        if self.is_contain_aspp:  # batchsize can not be 1
            aspp_output = self.aspp_layer(down2)
        down4, indices_4, unpool_shape4 = self.down4(
            down3)  # [1, 512, 84, 72] [2, 512, 80, 80]
        down5, indices_5, unpool_shape5 = self.down5(
            down4)  # [1, 512, 42, 36] [2, 512, 80, 80]
        torch.cuda.empty_cache()
        up5 = self.up5(down5, indices_5,
                       unpool_shape5)  # [1, 512, 84, 72] [2, 512, 80, 80]
        up4 = self.up4(up5, indices_4,
                       unpool_shape4)  # [1, 256, 168, 144] [2, 256, 160, 160]
        torch.cuda.empty_cache()
        up3 = self.up3(
            up4, indices_3,
            unpool_shape3)  # [1, 128, 336, 288] [2, 128, 320, 320]
        if self.is_contain_aspp:
            up3 = torch.cat([up3, aspp_output], 1)  # [2, 256, 320, 320]
            up3 = self.conv_1x1_aspp(up3)  # [2, 128, 320, 320]
        up2 = self.up2(
            up3, indices_2,
            unpool_shape2)  # [1, 64, 673, 576] indices_2: [2, 128, 320, 320]
        up1 = self.up1(up2, indices_1, unpool_shape1)  # [1, 1, 1346, 1152]
        # NOTE(review): squeeze(dim=1) only drops the channel dim when
        # n_classes == 1; with the default n_classes=2 the shape is kept
        # as [N, 2, H, W] — confirm callers construct Unet with n_classes=1.
        x = torch.squeeze(up1, dim=1)  # [N, 1, 320, 320] -> [N, 320, 320]
        x = self.sigmoid(x)
        return x  # [2, 1280, 1280]

    def init_vgg16_params(self, vgg16):
        """Copy matching VGG16 conv weights/biases into the encoder blocks."""
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]
        features = list(vgg16.features.children())
        vgg_layers = []
        for _layer in features:
            if isinstance(_layer, nn.Conv2d):
                vgg_layers.append(_layer)
        merged_layers = []
        for idx, conv_block in enumerate(blocks):
            # down1/down2 have two convs, down3..down5 have three — mirroring
            # VGG16's block structure.
            if idx < 2:
                units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            else:
                units = [
                    conv_block.conv1.cbr_unit,
                    conv_block.conv2.cbr_unit,
                    conv_block.conv3.cbr_unit,
                ]
            for _unit in units:
                for _layer in _unit:
                    if isinstance(_layer, nn.Conv2d):
                        merged_layers.append(_layer)
        assert len(vgg_layers) == len(merged_layers)
        for l1, l2 in zip(vgg_layers, merged_layers):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                # Only copy where shapes match (the first conv differs when
                # in_channels != 3 and is left untouched).
                if l1.weight.size() == l2.weight.size() and l1.bias.size(
                ) == l2.bias.size():
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data

View File

@@ -0,0 +1,310 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numbers
import os
import pdb
from collections import deque
import cv2
import json
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
# Enable cuDNN for the segmentation/matting networks used below.
torch.backends.cudnn.enabled = True
# Accepted input-image bounds (pixels) and maximum aspect ratio.
IMAGE_MAX_DIM = 3000
IMAGE_MIN_DIM = 50
IMAGE_MAX_RATIO = 10.0
# Scale applied to the scene mask before the blur/dilate/erode passes.
IMAGE_BLENDER_MASK_RESIZE_SCALE = 10.0
# Longest side of the downscaled mask used for the inner-rectangle search.
IMAGE_BLENDER_INNER_RECT_MAX_DIM = 256
# Base kernel size for the morphological close on the sky mask.
IMAGE_BLENDER_DILATE_KERNEL_SIZE = 7
# Mask values above this threshold are treated as valid sky pixels.
IMAGE_BLENDER_VALID_MASK_THRESHOLD = 100
# Minimum sky area (pixels) required before blending is attempted.
IMAGE_BLENDER_MIN_VALID_SKY_AREA = 100
IMAGE_BLENDER_MIN_RESIZE_DIM = 10
IMAGE_BLENDER_BLUR_KERNEL_SIZE = 5
def extract_sky_image(in_sky_image, in_sky_mask):
    """Crop the largest rectangular all-sky region out of a sky image.

    The mask is downscaled so its longer side is at most
    IMAGE_BLENDER_INNER_RECT_MAX_DIM, morphologically closed to remove
    pinholes, searched for the largest inner rectangle of sky pixels, and
    the rectangle is mapped back to full resolution to crop the image.

    Args:
        in_sky_image: sky image as an H x W x C numpy array.
        in_sky_mask: H x W mask; values above
            IMAGE_BLENDER_VALID_MASK_THRESHOLD count as sky.

    Returns:
        A copy of the sky image cropped to the detected rectangle.

    Raises:
        Exception: when the valid sky region is too small.
    """
    scale = 1.0
    resize_mask = in_sky_mask.copy()
    rows, cols = resize_mask.shape[0:2]
    # src size: (512, 640), target size: (256,256), then scale to size (256, 320)
    if (rows > IMAGE_BLENDER_INNER_RECT_MAX_DIM
            or cols > IMAGE_BLENDER_INNER_RECT_MAX_DIM):
        height_scale = IMAGE_BLENDER_INNER_RECT_MAX_DIM / float(rows)
        width_scale = IMAGE_BLENDER_INNER_RECT_MAX_DIM / float(cols)
        scale = height_scale if height_scale > width_scale else width_scale
        new_size = (max(int(cols * scale), 1), max(int(rows * scale),
                                                   1))  # w, h
        # Bug fix: cv2.resize's third positional argument is `dst`, not the
        # interpolation flag, so the flag was silently ignored before; pass
        # it as the `interpolation` keyword.
        resize_mask = cv2.resize(
            resize_mask, new_size, interpolation=cv2.INTER_LINEAR)
    kernel_size = max(3, int(scale * IMAGE_BLENDER_DILATE_KERNEL_SIZE + 0.5))
    element = cv2.getStructuringElement(cv2.MORPH_RECT,
                                        (kernel_size, kernel_size))
    # Close small holes so the inner-rectangle search is not fragmented.
    resize_mask = cv2.morphologyEx(resize_mask, cv2.MORPH_CLOSE, element)
    max_inner_rect, area = get_max_inner_rect(
        resize_mask, IMAGE_BLENDER_VALID_MASK_THRESHOLD, True)
    if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA:
        raise Exception(
            '[extractSkyImage]failed!! Valid sky region is too small')
    scale = 1.0 / scale
    # max_inner_rect: left top(x,y), right bottome(x,y); raw_inner_rect:left top x,y,w(of bbox),h(of bbox)
    raw_inner_rect = scale_rect(max_inner_rect, in_sky_mask, scale)
    out_sky_image = in_sky_image[raw_inner_rect[1]:raw_inner_rect[1]
                                 + raw_inner_rect[3] + 1,
                                 raw_inner_rect[0]:raw_inner_rect[0]
                                 + raw_inner_rect[2] + 1, ].copy()
    return out_sky_image
def blend(scene_image, scene_mask, sky_image, sky_mask, inBlendLevelNum=10):
    """Replace the sky of `scene_image` with sky taken from `sky_image`.

    Args:
        scene_image: scene image as a torch tensor (H x W x C).
        scene_mask: numpy sky mask for the scene image (same H, W).
        sky_image: torch tensor image providing the new sky.
        sky_mask: numpy sky mask for the sky image (same H, W).
        inBlendLevelNum: number of multi-band blending levels.

    Returns:
        The blended image as a numpy array.

    Raises:
        Exception: when an image and its mask disagree in spatial size.
    """
    # Move the tensors to host memory as numpy arrays.
    if torch.cuda.is_available():
        scene_image = scene_image.cpu().numpy()
        sky_image = sky_image.cpu().numpy()
    else:
        scene_image = scene_image.numpy()
        sky_image = sky_image.numpy()
    # Each image must match its mask spatially before blending.
    if sky_image.shape[0:2] != sky_mask.shape[0:2]:
        raise Exception(
            '[blend]failed!! sky_image shape not equal with sky_image_mask shape'
        )
    if scene_image.shape[0:2] != scene_mask.shape[0:2]:
        raise Exception(
            '[blend]failed!! scene_image shape not equal with scene_image_mask shape'
        )
    # Crop the largest clean sky rectangle, then merge it into the scene.
    valid_sky_image = extract_sky_image(sky_image, sky_mask)
    return blend_merge(scene_image, scene_mask, valid_sky_image,
                       inBlendLevelNum)
def get_max_inner_rect(in_image_mask, in_alpha_threshold, is_bigger_valid):
    """Largest axis-aligned rectangle of 'valid' pixels in a 2-D mask.

    A pixel is valid when its value is > `in_alpha_threshold` (or <= when
    `is_bigger_valid` is False).  Uses the classic largest-rectangle-in-
    histogram sweep with a monotonic stack, row by row.

    Returns:
        ((x, y, w, h), area) of the best rectangle found.
    """
    best_area = 0
    rows, cols = in_image_mask.shape[0:2]
    top, left, bottom, right = 0, 0, 0, 0
    # heights[c] = run of consecutive valid pixels ending at the current row;
    # the extra sentinel column flushes the stack at the end of each row.
    heights = [0] * (cols + 1)
    for r in range(rows):
        stack = deque()
        for c in range(cols + 1):
            if c < cols:
                if is_bigger_valid:
                    valid = in_image_mask[r, c] > in_alpha_threshold
                else:
                    valid = in_image_mask[r, c] <= in_alpha_threshold
                heights[c] = heights[c] + 1 if valid else 0
            # Pop columns whose histogram bar ends here and score them.
            while stack and heights[stack[-1]] >= heights[c]:
                bar_h = heights[stack.pop()]
                bar_w = c if not stack else c - stack[-1] - 1
                if bar_h * bar_w > best_area:
                    best_area = bar_h * bar_w
                    bottom = r
                    top = bottom - bar_h + 1
                    right = c - 1
                    left = right - bar_w + 1
            stack.append(c)
    out_rect = (left, top, right - left + 1, bottom - top + 1)
    return out_rect, best_area
def scale_rect(in_rect, in_image_size, in_scale):
    """Scale rectangle coordinates and clamp them to the image bounds.

    The first two entries of `in_rect` are treated as the top-left corner
    and the last two as the bottom-right corner; all four are multiplied by
    `in_scale` (rounded to nearest int) and the bottom-right is clipped to
    the shape of `in_image_size`.

    Returns:
        (tl_x, tl_y, width, height) of the scaled, clamped rectangle.
    """
    img_h, img_w = in_image_size.shape[0:2]
    tl_x = int(in_rect[0] * in_scale + 0.5)
    tl_y = int(in_rect[1] * in_scale + 0.5)
    br_x = min(int(in_rect[2] * in_scale + 0.5), img_w)
    br_y = min(int(in_rect[3] * in_scale + 0.5), img_h)
    return (tl_x, tl_y, br_x - tl_x, br_y - tl_y)
def get_fast_valid_rect(in_mask, in_threshold=0):
    """Bounding rect (x, y, w, h) of all mask pixels above `in_threshold`."""
    binary = (in_mask > in_threshold).astype(np.uint8)
    nonzero = cv2.findNonZero(binary)
    return cv2.boundingRect(nonzero)  # x, y, w, h
def min_size_match(in_image, in_min_size, type=cv2.INTER_LINEAR):
    """Resize `in_image` so it covers `in_min_size` while keeping aspect.

    The larger of the two per-axis factors is used, so both output sides
    end up >= the requested minimum.

    Args:
        in_image: H x W [x C] numpy image.
        in_min_size: (width, height) the result must at least reach.
        type: OpenCV interpolation flag (parameter kept as `type` for
            backward-compatible keyword calls; note it shadows the builtin).

    Returns:
        The resized image (a new array).
    """
    width, height = in_min_size
    src_h, src_w = in_image.shape[0:2]
    height_scale = height / src_h
    width_scale = width / src_w
    scale = height_scale if height_scale > width_scale else width_scale
    new_size = (
        max(int(src_w * scale + 0.5), 1),
        max(int(src_h * scale + 0.5), 1),
    )
    # Bug fix: the flag was previously passed positionally as
    # cv2.resize(img, size, 0, 0, type), which lands in the dst/fx/fy slots
    # and leaves interpolation at its default; it must be the
    # `interpolation` keyword.
    return cv2.resize(in_image, new_size, interpolation=type)
def center_crop(in_image, in_size):
    """Crop a (width, height) window from the center of `in_image`."""
    crop_w, crop_h = in_size
    img_h, img_w = in_image.shape[0:2]
    top = (img_h - crop_h) // 2
    left = (img_w - crop_w) // 2
    cropped = in_image.copy()
    return cropped[top:top + crop_h, left:left + crop_w]
def safe_roi_pad(in_pad_image, in_rect, out_base_image):
    """Paste `in_pad_image` into `out_base_image` at rect (x, y, w, h).

    Validates that the rect is positive, matches the patch size, and lies
    entirely inside the base image.  Mutates `out_base_image` in place.

    Raises:
        Exception: when any of the validations fails.
    """
    x, y, w, h = in_rect
    if x < 0 or y < 0 or w <= 0 or h <= 0:
        raise Exception('[safe_roi_pad] Failed!! x,y,w,h of rect are illegal')
    if w != in_pad_image.shape[1] or h != in_pad_image.shape[0]:
        raise Exception('[safe_roi_pad] Failed!!')
    if x + w > out_base_image.shape[1] or y + h > out_base_image.shape[0]:
        raise Exception('[safe_roi_pad] Failed!!')
    out_base_image[y:y + h, x:x + w] = in_pad_image
def merge_image(in_base_image, in_merge_image, in_merge_mask, in_point):
    """Alpha-blend `in_merge_image` onto `in_base_image` at `in_point`.

    `in_merge_mask` (0..255) provides the per-pixel alpha.  The blended ROI
    is written back into `in_base_image`, which is also returned.

    Raises:
        Exception: when patch/mask shapes disagree or the ROI falls outside
        the base image.
    """
    if in_merge_image.shape[0:2] != in_merge_mask.shape[0:2]:
        raise Exception(
            '[merge_image] Failed!! in_merge_image.shape != in_merge_mask.shape!!'
        )
    x, y = in_point
    patch_rows, patch_cols = in_merge_image.shape[0:2]
    base_rows, base_cols = in_base_image.shape[0:2]
    if x + patch_cols > base_cols or y + patch_rows > base_rows:
        raise Exception(
            '[merge_image] Failed!! merge_image:image rect not in image')
    # Work on float copies of the ROI and the alpha, broadcast to 3 channels.
    roi = np.float32(in_base_image[y:y + patch_rows, x:x + patch_cols])
    alpha = np.repeat(in_merge_mask.copy()[:, :, np.newaxis], 3, axis=2)
    alpha = alpha / 255.0
    blended = (1 - alpha) * roi + alpha * in_merge_image.copy()
    blended = np.clip(blended, 0, 255).astype('uint8')
    # Write the blended patch back into the base image in place.
    safe_roi_pad(blended, (x, y, patch_cols, patch_rows), in_base_image)
    return in_base_image
def blend_merge(in_scene_image,
                in_scene_mask,
                in_valid_sky_image,
                inBlendLevelNum=5):
    """Paste the extracted sky patch over the scene's sky and blend seams.

    Steps: locate the scene's sky bounding box, fit/crop the sky patch to
    it, derive soft sky/scene masks from the blurred + dilated/eroded scene
    mask, alpha-merge the patch, then run OpenCV multi-band blending
    between the original scene and the merged image.

    Raises:
        Exception: when the scene's valid sky region is too small.
    """
    scene_sky_rect = get_fast_valid_rect(in_scene_mask, 1)
    area = scene_sky_rect[2] * scene_sky_rect[3]
    if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA:
        raise Exception(
            '[blend_merge] Failed!! Scene Image Valid sky region is too small')
    # Fit the sky patch to the sky bounding box (cover, then center-crop).
    valid_sky_image = min_size_match(in_valid_sky_image, scene_sky_rect[2:])
    valid_sky_image = center_crop(valid_sky_image, scene_sky_rect[2:])
    # resizeSceneMask
    sky_size = (
        max(
            int(in_scene_mask.shape[1] * IMAGE_BLENDER_MASK_RESIZE_SCALE
                + 0.5),
            IMAGE_BLENDER_MIN_RESIZE_DIM,
        ),
        max(
            int(in_scene_mask.shape[0] * IMAGE_BLENDER_MASK_RESIZE_SCALE
                + 0.5),
            IMAGE_BLENDER_MIN_RESIZE_DIM,
        ),
    )
    # NOTE(review): cv2.resize's third positional argument is `dst`, not the
    # interpolation flag, so INTER_LINEAR here is effectively ignored (the
    # default happens to be INTER_LINEAR anyway) — confirm and pass
    # `interpolation=` explicitly.
    resize_scene_mask = cv2.resize(in_scene_mask, sky_size, cv2.INTER_LINEAR)
    resize_scene_mask = cv2.blur(
        resize_scene_mask,
        (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE),
    )
    element = cv2.getStructuringElement(
        cv2.MORPH_RECT,
        (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE))
    sky_mask = cv2.dilate(resize_scene_mask, element)  # enlarge sky region
    scene_mask = cv2.erode(resize_scene_mask, element)  # enlarge scene region
    scene_mask = 255 - scene_mask
    # Bring both soft masks back to the original scene resolution (w, h).
    sky_mask = cv2.resize(sky_mask, in_scene_mask.shape[0:2][::-1])
    scene_mask = cv2.resize(scene_mask, in_scene_mask.shape[0:2][::-1])
    x, y, w, h = scene_sky_rect
    valid_sky_mask = sky_mask[y:y + h, x:x + w]
    pano_sky_image = in_scene_image.copy()
    # Alpha-merge the sky patch into a copy of the scene.
    pano_sky_image = merge_image(pano_sky_image, valid_sky_image,
                                 valid_sky_mask, scene_sky_rect[0:2])
    blend_images = []
    blend_images.append(in_scene_image)
    blend_images.append(pano_sky_image)
    blend_masks = []
    blend_masks.append(scene_mask.astype(np.uint8))
    blend_masks.append(sky_mask.astype(np.uint8))
    panorama_rect = (0, 0, in_scene_image.shape[1], in_scene_image.shape[0])
    # Multi-band (Laplacian pyramid) blending of scene vs. merged image.
    blender = cv2.detail_MultiBandBlender(1, inBlendLevelNum)
    blender.prepare(panorama_rect)
    for i in range(0, len(blend_images)):
        blender.feed(blend_images[i], blend_masks[i], (0, 0))
    # NOTE(review): np.ones takes (rows, cols); (shape[1], shape[0]) is
    # (w, h), so for non-square images this mask's axes look swapped —
    # confirm against detail_MultiBandBlender.blend's expectations.
    pano_mask = (
        np.ones(
            (in_scene_image.shape[1], in_scene_image.shape[0]), dtype='uint8')
        * 255)
    out_blend_image = np.zeros_like(in_scene_image)
    result = blender.blend(out_blend_image, pano_mask)
    return result[0]

View File

@@ -0,0 +1,199 @@
import math
import os
import pdb
import time
from collections import OrderedDict
from typing import Any, Dict, List, Union
import cv2
import json
import torch
import torch.nn.functional as F
from modelscope.metainfo import Models
from modelscope.models import Model
from modelscope.models.base import TorchModel
from modelscope.models.builder import MODELS
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
from .ptsemseg.hrnet_super_and_ocr import HrnetSuperAndOcr
from .ptsemseg.unet import Unet
from .skychange import blend
logger = get_logger()
@MODELS.register_module(
    Tasks.image_skychange, module_name=Models.image_skychange)
class ImageSkychange(TorchModel):
    """Sky-replacement model.

    A coarse HRNet-OCR segmenter predicts a sky mask which is refined by a
    U-Net matting network; the refined masks of the sky and scene images
    drive multi-band blending (see `blend`).
    """

    def __init__(self, model_dir, refine_cfg, coarse_cfg, *args, **kwargs):
        """
        Args:
            model_dir (str): model directory to initialize some resource.
            refine_cfg: configuration of refine model.
            coarse_cfg: configuration of coarse model.
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            logger.info('Use GPU: {}'.format(self.device))
        else:
            self.device = torch.device('cpu')
            logger.info('Use CPU: {}'.format(self.device))
        coarse_model_path = '{}/{}'.format(model_dir,
                                           ModelFile.TORCH_MODEL_FILE)
        refine_model_path = '{}/{}'.format(model_dir,
                                           'unet_sky_matting_final_model.pkl')
        logger.info(
            '####################### load refine models ################################'
        )
        # U-Net matting refiner (RGB + coarse mask in, soft matte out).
        self.refine_model = Unet(**refine_cfg['Model'])
        self.load_model(self.refine_model, refine_model_path)
        self.refine_model.eval()
        logger.info(
            '####################### load refine models done ############################'
        )
        logger.info(
            '####################### load coarse models ################################'
        )
        # HRNet-OCR coarse sky segmenter.
        self.coarse_model = HrnetSuperAndOcr(**coarse_cfg['Model'])
        self.load_model(self.coarse_model, coarse_model_path)
        self.coarse_model.eval()
        logger.info(
            '####################### load coarse models done ############################'
        )

    def load_model(self, seg_model, input_model_path):
        """Load a checkpoint into `seg_model`, moving it to GPU if available.

        Raises:
            Exception: when the checkpoint file does not exist.
        """
        if not os.path.isfile(input_model_path):
            # Bug fix: corrected the misspelled error message
            # ('dose not exits' -> 'does not exist').
            logger.error(
                '[checkModelPath]:model path does not exist!!! model Path:'
                + input_model_path)
            raise Exception('[checkModelPath]:model path does not exist!')
        if torch.cuda.is_available():
            checkpoint = torch.load(input_model_path)
            model_state = self.convert_state_dict(checkpoint['model_state'])
            seg_model.load_state_dict(model_state)
            seg_model.to(self.device)
        else:
            checkpoint = torch.load(input_model_path, map_location='cpu')
            model_state = self.convert_state_dict(checkpoint['model_state'])
            seg_model.load_state_dict(model_state)

    def convert_state_dict(self, state_dict):
        """Converts a state dict saved from a dataParallel module to normal
        module state_dict inplace
        :param state_dict is the loaded DataParallel model_state
        """
        if not next(iter(state_dict)).startswith('module.'):
            return state_dict  # abort if dict is not a DataParallel model_state
        new_state_dict = OrderedDict()
        split_index = 0
        # Determine how long the DataParallel prefix is ('module.model.' vs
        # 'module.').  NOTE(review): the break only triggers on a plain
        # 'module.' key, so a checkpoint whose keys all start with
        # 'module.model' scans every key; the resulting prefix is still the
        # intended 13.
        for cur_key, cur_value in state_dict.items():
            if cur_key.startswith('module.model'):
                split_index = 13
            elif cur_key.startswith('module'):
                split_index = 7
                break
        for k, v in state_dict.items():
            name = k[split_index:]  # remove `module.`
            new_state_dict[name] = v
        return new_state_dict

    def forward(
        self,
        sky_image: torch.Tensor,
        sky_image_refine: torch.Tensor,
        scene_image: torch.Tensor,
        scene_image_refine: torch.Tensor,
        img_metas: Dict[str, Any],
    ):
        """
        Args:
            sky_image (`torch.Tensor`): batched image tensor, shape is [1, 3, h', w'].
            sky_image_refine (`torch.Tensor`): batched image tensor, shape is [1, 3, refine_net_h, refine_net_w].
            scene_image (`torch.Tensor`): batched image tensor, shape is [1, 3, h, w].
            scene_image_refine (`torch.Tensor`): batched image tensor, shape is [1, 3, refine_net_h, refine_net_w].
            img_metas (`Dict[str, Any]`): image meta info.
        Return:
            `IMAGE: shape is [h, w, 3] (0~255)`
        """
        start = time.time()
        sky_img_metas, scene_img_metas, input_size = img_metas[
            'sky_img_metas'], img_metas['scene_img_metas'], img_metas[
                'input_size']
        # Predict refined sky mattes for both images.
        sky_mask = self.inference_mask(sky_image_refine, sky_img_metas,
                                       input_size)
        scene_mask = self.inference_mask(scene_image_refine, scene_img_metas,
                                         input_size)
        end = time.time()
        logger.info(
            'Time of inferencing mask of sky and scene images:{}'.format(
                end - start))
        start = time.time()
        # The blender expects 0..255 alpha masks.
        scene_mask = scene_mask * 255
        sky_mask = sky_mask * 255
        res = blend(scene_image, scene_mask, sky_image, sky_mask)
        end = time.time()
        logger.info('Time of blending: {}'.format(end - start))
        return res

    @torch.no_grad()
    def inference_mask(self, img, img_metas, input_size):
        """Coarse-segment then refine a sky matte for one padded image.

        Returns the matte as a numpy array resized to the original
        (pre-padding) image size.
        """
        self.eval()
        raw_h, raw_w = img_metas['ori_shape']
        pad_direction = img_metas['pad_direction']
        coarse_input_size = input_size['coarse_input_size']
        refine_input_size = input_size['refine_input_size']
        h, w = img_metas['refine_shape']
        resize_images = F.interpolate(
            img, coarse_input_size, mode='bilinear', align_corners=True)
        # get coarse result
        pred_scores = self.coarse_model(resize_images)
        if isinstance(pred_scores, (tuple, list)):
            # HrnetSuperAndOcr returns [aux_logits, logits]; keep the final.
            pred_scores = pred_scores[-1]
        score = F.interpolate(
            input=pred_scores,
            size=refine_input_size,
            mode='bilinear',
            align_corners=True,
        )
        _, coarse_pred = torch.max(score, dim=1)  # [B, h, w]
        coarse_pred = coarse_pred.unsqueeze(1).type(img.dtype)
        # Refiner input: RGB + coarse mask channel.
        img = torch.cat([img, coarse_pred], dim=1)  # [B, c=4, h, w]
        del resize_images
        del pred_scores
        del score
        del coarse_pred
        torch.cuda.empty_cache()
        cur_scores = self.refine_model(img)
        del img
        torch.cuda.empty_cache()
        cur_scores = torch.clip(cur_scores, 0, 1)
        cur_scores = cur_scores.detach().cpu().numpy()[0]
        # resize if cur_scores shape are not compatible with origin image shape
        ph, pw = cur_scores.shape
        if ph != h or pw != w:
            # Bug fix: cur_scores is a 2-D numpy array here, so the previous
            # F.interpolate call (which needs a 4-D tensor and rejects
            # mode='nearest' with align_corners) would raise; resize with
            # OpenCV instead (cv2 takes (width, height)).
            cur_scores = cv2.resize(
                cur_scores, (w, h), interpolation=cv2.INTER_NEAREST)
        # unpad to get valid area and resize to origin size
        valid_cur_pred = cur_scores[pad_direction[1]:refine_input_size[0]
                                    - pad_direction[3],
                                    pad_direction[0]:refine_input_size[1]
                                    - pad_direction[2], ]
        valid_cur_pred = cv2.resize(valid_cur_pred, (raw_w, raw_h))
        del cur_scores
        torch.cuda.empty_cache()
        # Use the module logger instead of print for consistency.
        logger.info('get refine mask done')
        return valid_cur_pred

View File

@@ -836,6 +836,11 @@ TASK_OUTPUTS = {
# }
Tasks.product_segmentation: [OutputKeys.MASKS],
# image_skychange result for a single sample
# {
# "output_img": np.ndarray with shape [height, width, 3]
# }
Tasks.image_skychange: [OutputKeys.OUTPUT_IMG],
# {
# 'scores': [0.1, 0.2, 0.3, ...]
# }

View File

@@ -101,6 +101,10 @@ TASK_INPUTS = {
'img': InputType.IMAGE,
'mask': InputType.IMAGE,
},
Tasks.image_skychange: {
'sky_image': InputType.IMAGE,
'scene_image': InputType.IMAGE,
},
# image generation task result for a single image
Tasks.image_to_image_generation:

View File

@@ -223,6 +223,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_swin-t_referring_video-object-segmentation'),
Tasks.video_summarization: (Pipelines.video_summarization,
'damo/cv_googlenet_pgl-video-summarization'),
Tasks.image_skychange: (Pipelines.image_skychange,
'damo/cv_hrnetocr_skychange'),
Tasks.translation_evaluation:
(Pipelines.translation_evaluation,
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),

View File

@@ -67,6 +67,7 @@ if TYPE_CHECKING:
from .hand_static_pipeline import HandStaticPipeline
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
from .image_skychange_pipeline import ImageSkychangePipeline
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
else:
@@ -155,6 +156,7 @@ else:
'language_guided_video_summarization_pipeline': [
'LanguageGuidedVideoSummarizationPipeline'
],
'image_skychange_pipeline': ['ImageSkychangePipeline'],
'video_object_segmentation_pipeline': [
'VideoObjectSegmentationPipeline'
],

View File

@@ -0,0 +1,63 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import pdb
import time
from typing import Any, Dict, Union
import cv2
import numpy as np
import PIL
from modelscope.metainfo import Pipelines
from modelscope.models.cv.image_skychange import ImageSkyChangePreprocessor
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Model, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
@PIPELINES.register_module(
    Tasks.image_skychange, module_name=Pipelines.image_skychange)
class ImageSkychangePipeline(Pipeline):
    """Image sky change pipeline.

    Given two images (`sky_image` and `scene_image`), the pipeline replaces
    the sky style of `sky_image` with the sky style of `scene_image`.

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline
    >>> skychange = pipeline('image-skychange', 'damo/cv_hrnetocr_skychange')
    >>> skychange({
            'sky_image': 'sky_image.jpg',  # sky_image path (str)
            'scene_image': 'scene_image.jpg',  # scene_image path (str)
        })
    {
        "output_img": [H * W * 3] 0~255, writable via cv2.imwrite.
    }
    ```
    """

    def __init__(self, model: str, **kwargs):
        """Create an image sky change pipeline for image editing.

        Args:
            model (`str` or `Model`): model_id on modelscope hub.
            preprocessor(`Preprocessor`, *optional*, defaults to None):
                `ImageSkyChangePreprocessor`.
        """
        super().__init__(model=model, **kwargs)
        if not isinstance(self.model, Model):
            logger.error('model object is not initialized.')
            raise Exception('model object is not initialized.')
        # Fall back to the default preprocessor when none was supplied.
        if self.preprocessor is None:
            self.preprocessor = ImageSkyChangePreprocessor()
        logger.info('load model done')

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the sky-change model on the preprocessed inputs."""
        return {OutputKeys.OUTPUT_IMG: self.model.forward(**input)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        """No post-processing is needed; pass the result through."""
        return inputs

View File

@@ -67,6 +67,7 @@ class CVTasks(object):
image_denoising = 'image-denoising'
image_portrait_enhancement = 'image-portrait-enhancement'
image_inpainting = 'image-inpainting'
image_skychange = 'image-skychange'
# image generation
image_to_image_translation = 'image-to-image-translation'

View File

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os  # NOTE(review): imported but unused — candidate for removal
import os.path as osp
import unittest

import cv2
import torch  # NOTE(review): imported but unused — candidate for removal

import modelscope
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.base import Pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
print(modelscope.version.__release_datetime__)
class ImageSkychangeTest(unittest.TestCase):
    """Integration tests for the image-skychange pipeline."""

    def setUp(self) -> None:
        # Model id on the ModelScope hub and the two local fixture images.
        self.model = 'damo/cv_hrnetocr_skychange'
        self.sky_image = 'data/test/images/sky_image.jpg'
        self.scene_image = 'data/test/images/scene_image.jpg'
        self.input = {
            'sky_image': self.sky_image,
            'scene_image': self.scene_image,
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        # Run the pipeline; guard clause instead of nested `if`.
        result = pipeline(input)
        if result is None:
            return
        cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
        print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        # Explicit model id.
        skychange_pipeline = pipeline(Tasks.image_skychange, model=self.model)
        self.pipeline_inference(skychange_pipeline, self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        # Task-default model resolved by the pipeline factory.
        skychange_pipeline = pipeline(Tasks.image_skychange)
        self.pipeline_inference(skychange_pipeline, self.input)
# Allow running this test module directly: `python test_image_skychange.py`.
if __name__ == '__main__':
    unittest.main()

View File

@@ -41,6 +41,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run
- test_image_matting.py
- test_skin_retouching.py
- test_table_recognition.py
- test_image_skychange.py
envs:
default: # default env, case not in other env will in default, pytorch.