mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
add image skychange
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10947701
This commit is contained in:
3
data/test/images/scene_image.jpg
Normal file
3
data/test/images/scene_image.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:260cd09f340b86007dd471cba742f82bae0fb5cfd4b8d87265bff5ad2c2c857f
|
||||
size 652482
|
||||
3
data/test/images/sky_image.jpg
Normal file
3
data/test/images/sky_image.jpg
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:679c86d5a82c9c1c4866b5e16b98a2128a57e3ea60f77d56e5f0fe79ab7d746f
|
||||
size 505993
|
||||
@@ -57,6 +57,7 @@ class Models(object):
|
||||
face_emotion = 'face-emotion'
|
||||
product_segmentation = 'product-segmentation'
|
||||
image_body_reshaping = 'image-body-reshaping'
|
||||
image_skychange = 'image-skychange'
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
|
||||
@@ -243,6 +244,7 @@ class Pipelines(object):
|
||||
product_segmentation = 'product-segmentation'
|
||||
image_body_reshaping = 'flow-based-body-reshaping'
|
||||
referring_video_object_segmentation = 'referring-video-object-segmentation'
|
||||
image_skychange = 'image-skychange'
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
|
||||
@@ -389,6 +391,7 @@ class Preprocessors(object):
|
||||
movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor'
|
||||
image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor'
|
||||
object_detection_scrfd = 'object-detection-scrfd'
|
||||
image_sky_change_preprocessor = 'image-sky-change-preprocessor'
|
||||
|
||||
# nlp preprocessor
|
||||
sen_sim_tokenizer = 'sen-sim-tokenizer'
|
||||
|
||||
22
modelscope/models/cv/image_skychange/__init__.py
Normal file
22
modelscope/models/cv/image_skychange/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .skychange_model import ImageSkychange
    from .preprocessor import ImageSkyChangePreprocessor

else:
    # BUG FIX: this used to be two separate assignments, so the second dict
    # overwrote the first and 'ImageSkychange' was never registered for lazy
    # import. Both submodules must live in a single mapping.
    _import_structure = {
        'skychange_model': ['ImageSkychange'],
        'preprocessor': ['ImageSkyChangePreprocessor'],
    }

    import sys

    # Replace this module with a lazy proxy that imports submodules on first
    # attribute access.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
245
modelscope/models/cv/image_skychange/preprocessor.py
Normal file
245
modelscope/models/cv/image_skychange/preprocessor.py
Normal file
@@ -0,0 +1,245 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numbers
|
||||
import pdb
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.preprocessors.builder import PREPROCESSORS
|
||||
from modelscope.preprocessors.image import LoadImage
|
||||
from modelscope.utils.constant import Fields, ModeKeys
|
||||
|
||||
# Map torchvision-style padding-mode names to their OpenCV border flags.
_cv2_pad_to_str = dict(
    constant=cv2.BORDER_CONSTANT,
    edge=cv2.BORDER_REPLICATE,
    reflect=cv2.BORDER_REFLECT_101,
    symmetric=cv2.BORDER_REFLECT,
)
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.image_sky_change_preprocessor)
class ImageSkyChangePreprocessor(Preprocessor):
    """Preprocessor for the image sky-change task.

    Loads one sky image and one scene image from paths, converts each to a
    checked 3-channel BGR array, then resizes/pads and normalizes each one to
    the refine-model input size.
    """

    def __init__(self,
                 model_dir: str = None,
                 mode: str = ModeKeys.INFERENCE,
                 coarse_model_width=640,
                 coarse_model_height=640,
                 refine_model_width=1280,
                 refine_model_height=1280,
                 mean_vec=[0.485, 0.456, 0.406],
                 std_vec=[0.229, 0.224, 0.225],
                 *args,
                 **kwargs):
        """
        Args:
            model_dir (str): model directory to initialize some resource.
            mode: The mode for the preprocessor.
            coarse_model_width: required width of input tensor of coarse model.
            coarse_model_height: required height of input tensor of coarse model.
            refine_model_width: required width of input tensor of refine model.
            refine_model_height: required height of input tensor of refine model.
            mean_vec: mean of dataset (for transforms.Normalize); default is
                the ImageNet mean.
            std_vec: standard deviation of dataset (for transforms.Normalize);
                default is the ImageNet std.
        """
        super().__init__(mode)

        # Input sizes are stored as [width, height].
        self.coarse_input_size = [coarse_model_width, coarse_model_height]
        self.refine_input_size = [refine_model_width, refine_model_height]
        self.normalize = transforms.Normalize(mean=mean_vec, std=std_vec)

    def __call__(self, data: Union[str, Dict], **kwargs) -> Dict[str, Any]:
        """Process the raw input data.

        Args:
            data (dict): must contain 'sky_image' and 'scene_image', each the
                path (str) of an image, e.g.::

                    {
                        'sky_image': 'xxx.jpg',
                        'scene_image': 'xxx.jpg',
                    }

        Returns:
            Dict[str, Any]: the preprocessed data:
                'sky_image'/'scene_image': the checked original-size images,
                'sky_image_refine'/'scene_image_refine': the resized, padded
                    and normalized input tensors,
                'img_metas': per-image shape/pad metadata plus the configured
                    coarse/refine input sizes.
        """
        # Validate both keys up-front so a missing-key error takes priority
        # over a bad-type error, exactly as before.
        if 'sky_image' not in data.keys():
            raise Exception('sky_image not in input data')
        if 'scene_image' not in data.keys():
            raise Exception('scene_image not in input data')
        # The loading/validation code was previously duplicated verbatim for
        # sky and scene; it is now shared via _load_image.
        sky_image = self._load_image(data['sky_image'], 'sky')
        scene_image = self._load_image(data['scene_image'], 'scene')

        data = {}
        sky_image_refine, sky_img_metas = self.process_single_img(sky_image)
        scene_image_refine, scene_img_metas = self.process_single_img(
            scene_image)
        data['sky_image'] = sky_image
        data['sky_image_refine'] = sky_image_refine
        data['scene_image'] = scene_image
        data['scene_image_refine'] = scene_image_refine
        data['img_metas'] = {
            'sky_img_metas': sky_img_metas,
            'scene_img_metas': scene_img_metas,
            'input_size': {
                'coarse_input_size': self.coarse_input_size,
                'refine_input_size': self.refine_input_size
            }
        }
        return data

    def _load_image(self, value, name):
        """Load one image given its path; return a checked BGR ndarray.

        Args:
            value: the image path; must be a str.
            name: 'sky' or 'scene', used only to build error messages.
        """
        if not isinstance(value, str):
            raise Exception(
                '%s_image(path of %s image) is not valid' % (name, name))
        image = LoadImage.convert_to_ndarray(value)
        image = image.astype(np.uint8)  # RGB
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # BGR
        if image is None:
            raise Exception('%s_image is None' % name)
        return self.check_image(image)

    def process_single_img(self, img):
        """Resize/pad one BGR image to the refine input size and normalize it.

        Returns the 4-D network input tensor and a metadata dict holding the
        original shape, the applied (left, top, right, bottom) padding and the
        resized spatial shape.
        """
        img_metas = {}
        img_metas['ori_shape'] = img.shape[0:2]  # img: (origin_h, origin_w, 3)
        img, pad_direction = get_refine_input(img, self.refine_input_size)
        img = image_transform(
            img, self.normalize)  # torch.Size([3, refine_net_h, refine_net_w])
        img = img.unsqueeze(0)
        img_metas['pad_direction'] = pad_direction
        img_metas['refine_shape'] = img.shape[
            2:]  # torch.Size([1, 3, refine_net_h, refine_net_w])
        return img, img_metas

    def check_image(self, input_img):
        """Force a 3-channel image.

        Grayscale and single-channel inputs are replicated across 3 channels;
        RGBA inputs are alpha-blended down to 3 channels.
        """
        shape = input_img.shape
        if len(shape) == 2:
            input_img = np.stack([input_img, input_img, input_img], axis=2)
        elif shape[2] == 1:
            input_img = np.concatenate([input_img, input_img, input_img],
                                       axis=2)
        elif shape[2] == 4:
            # NOTE(review): this path yields a float array while the others
            # stay uint8 — presumably image_transform tolerates both; confirm.
            input_img = input_img[:, :, 0:3] * 1.0 * input_img[:, :,
                                                               3:4] * 1.0 / 255.0
        return input_img
|
||||
|
||||
|
||||
def get_refine_input(mat, refine_input_size):
    """Resize *mat* to fit inside refine_input_size, then center-pad it.

    Returns the padded image together with the (left, top, right, bottom)
    padding amounts that were applied.
    """
    # Aspect-preserving resize so the image fits the refine net input.
    resized = max_dim_match(mat, refine_input_size)
    # Center-pad up to the exact refine net input size with zeros.
    return center_pad_image_withwh(resized, refine_input_size, 0)
|
||||
|
||||
|
||||
def max_dim_match(image, refine_model_size):
    """Aspect-preserving resize so *image* fits inside refine_model_size.

    refine_model_size is (width, height). An image already at the target
    size is returned unchanged.
    """
    h, w, _ = np.shape(image)
    target_w, target_h = refine_model_size
    if h == target_h and w == target_w:
        return image
    # Use the smaller of the two scale factors so both dims fit.
    scale = min(float(target_w) / w, float(target_h) / h)
    new_size = (int(w * scale + 0.5), int(h * scale + 0.5))
    return cv2.resize(image, new_size, interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
|
||||
def center_pad_image_withwh(image,
                            crop_size,
                            padvalue,
                            padding_mode='constant'):
    """Center-pad *image* up to crop_size (width, height).

    Returns the padded image and the (left, top, right, bottom) pad widths;
    an image already large enough is returned unchanged with zero padding.
    """
    h, w = image.shape[0], image.shape[1]
    pad_h = max(crop_size[1] - h, 0)
    pad_w = max(crop_size[0] - w, 0)
    if pad_h == 0 and pad_w == 0:
        return image, (0, 0, 0, 0)
    # Split each padding amount roughly in half, rounding the leading side up.
    left = int(pad_w / 2 + 0.5)
    top = int(pad_h / 2 + 0.5)
    pad_direction = (left, top, pad_w - left, pad_h - top)
    padded = pad(image, pad_direction, padvalue, padding_mode=padding_mode)
    return padded, pad_direction
|
||||
|
||||
|
||||
def pad(img, padding, fill=0, padding_mode='constant'):
    """Pad a numpy image with cv2.copyMakeBorder.

    Args:
        img: HxW or HxWxC numpy image.
        padding: a single number (same padding on all four sides) or a
            4-sequence of (left, top, right, bottom).
        fill: border value for 'constant' mode.
        padding_mode: one of the _cv2_pad_to_str keys
            ('constant', 'edge', 'reflect', 'symmetric').

    Raises:
        TypeError: if img is not a 2-D/3-D ndarray or padding is malformed.
    """
    if not is_numpy_image(img):
        raise TypeError('img should be numpy ndarray. Got {}'.format(
            type(img)))
    # BUG FIX: the old check accepted numbers.Number but then called
    # len(padding) on it, raising an unintended TypeError. Normalize a scalar
    # to the 4-tuple form first.
    if isinstance(padding, numbers.Number):
        padding = (padding, padding, padding, padding)
    if not isinstance(padding, (tuple, list)) or len(padding) != 4:
        raise TypeError('Got inappropriate padding arg')

    pad_left, pad_top, pad_right, pad_bottom = padding

    # The 2-D and 3-channel cases previously had byte-identical
    # copyMakeBorder calls; a single call covers both.
    bordered = cv2.copyMakeBorder(
        img,
        top=pad_top,
        bottom=pad_bottom,
        left=pad_left,
        right=pad_right,
        borderType=_cv2_pad_to_str[padding_mode],
        value=fill,
    )
    # cv2 drops a trailing singleton channel axis; restore it so the output
    # rank matches the input rank.
    if img.ndim == 3 and img.shape[2] == 1:
        bordered = bordered[:, :, np.newaxis]
    return bordered
|
||||
|
||||
|
||||
def image_transform(img, normalize):
    """Convert a BGR HxWxC uint8 image into a normalized float CHW tensor.

    Channels are flipped BGR->RGB, values scaled into [0, 1], and *normalize*
    (e.g. a torchvision Normalize) is applied to the resulting tensor.
    """
    rgb = img[:, :, ::-1]           # BGR --> RGB (PIL channel order)
    chw = rgb.transpose((2, 0, 1))  # h,w,c --> c,h,w
    scaled = chw.astype(np.float32) / 255
    # .copy() makes the array contiguous for torch.from_numpy.
    return normalize(torch.from_numpy(scaled.copy()))
|
||||
|
||||
|
||||
def is_numpy_image(img):
    """Return True when *img* is a 2-D (gray) or 3-D (color) numpy array."""
    if not isinstance(img, np.ndarray):
        return False
    return img.ndim in (2, 3)
|
||||
118
modelscope/models/cv/image_skychange/ptsemseg/BlockModules.py
Normal file
118
modelscope/models/cv/image_skychange/ptsemseg/BlockModules.py
Normal file
@@ -0,0 +1,118 @@
|
||||
# The implementation is adopted from ASPP made publicly available under the MIT License License
|
||||
# at https://github.com/jfzhang95/pytorch-deeplab-xception
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
BatchNorm2d = nn.BatchNorm2d
|
||||
|
||||
|
||||
class ASPPModule(nn.Module):
    """One atrous-convolution branch of ASPP: conv -> BN -> ReLU."""

    def __init__(self, inplanes, planes, kernel_size, padding, dilation,
                 BatchNorm):
        super(ASPPModule, self).__init__()
        self.atrous_conv = nn.Conv2d(
            inplanes,
            planes,
            kernel_size=kernel_size,
            stride=1,
            padding=padding,
            dilation=dilation,
            bias=False)
        self.bn = BatchNorm(planes)
        self.relu = nn.ReLU()

        self._init_weight()

    def forward(self, x):
        """Apply atrous conv, batch norm and ReLU in sequence."""
        return self.relu(self.bn(self.atrous_conv(x)))

    def _init_weight(self):
        """Kaiming-init convs; unit gamma / zero beta for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                torch.nn.init.kaiming_normal_(module.weight)
            elif isinstance(module, (BatchNorm2d, nn.BatchNorm2d)):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
|
||||
|
||||
|
||||
# This ASPP follows the official implementation:
# https://github.com/jfzhang95/pytorch-deeplab-xception
class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling head.

    Four parallel atrous branches plus a pooled image-level branch are
    concatenated and fused by a 1x1 conv.
    """

    def __init__(self, inplanes, outplanes, dilations, drop_rate=0.1):
        super(ASPP, self).__init__()

        # Branch 1 is a plain 1x1 conv; branches 2-4 are 3x3 atrous convs.
        # Attribute names aspp1..aspp4 are kept so state_dict keys match.
        kernel_sizes = (1, 3, 3, 3)
        paddings = (0, dilations[1], dilations[2], dilations[3])
        for idx in range(4):
            setattr(
                self, 'aspp%d' % (idx + 1),
                ASPPModule(
                    inplanes,
                    outplanes,
                    kernel_sizes[idx],
                    padding=paddings[idx],
                    dilation=dilations[idx],
                    BatchNorm=BatchNorm2d))

        # Image-level context branch.
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False),
            BatchNorm2d(outplanes), nn.ReLU())
        self.conv1 = nn.Conv2d(outplanes * 5, outplanes, 1, bias=False)
        self.bn1 = BatchNorm2d(outplanes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(drop_rate)
        self._init_weight()

    def forward(self, x):
        """Fuse the four atrous branches with upsampled global context."""
        branch_outs = [
            self.aspp1(x),
            self.aspp2(x),
            self.aspp3(x),
            self.aspp4(x)
        ]
        pooled = self.global_avg_pool(x)
        pooled = F.interpolate(
            pooled,
            size=branch_outs[-1].size()[2:],
            mode='bilinear',
            align_corners=True)
        fused = torch.cat(branch_outs + [pooled], dim=1)

        fused = self.relu(self.bn1(self.conv1(fused)))
        return self.dropout(fused)

    def _init_weight(self):
        """Kaiming-init convs; unit gamma / zero beta for batch norms."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, (BatchNorm2d, nn.BatchNorm2d)):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
|
||||
620
modelscope/models/cv/image_skychange/ptsemseg/hrnet_backnone.py
Normal file
620
modelscope/models/cv/image_skychange/ptsemseg/hrnet_backnone.py
Normal file
@@ -0,0 +1,620 @@
|
||||
# The implementation is adopted from HRNet, made publicly available under the MIT License License
|
||||
# at https://github.com/HRNet/HRNet-Semantic-Segmentation
|
||||
import logging
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
BatchNorm2d = nn.BatchNorm2d
|
||||
|
||||
BN_MOMENTUM = 0.1
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution that keeps spatial size (padding=1)."""
    conv = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=3,
        padding=1,
        stride=stride,
        bias=False)
    return conv
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
    """ResNet basic block: two 3x3 convs with a residual connection."""
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """y = relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
    """ResNet bottleneck block: 1x1 reduce -> 3x3 -> 1x1 expand (x4)."""
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            planes,
            planes,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False)
        self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM)
        self.conv3 = nn.Conv2d(
            planes, planes * self.expansion, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Three conv/BN stages, then add the (optionally projected) shortcut."""
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        out = self.bn3(self.conv3(out))

        out += shortcut
        return self.relu(out)
|
||||
|
||||
|
||||
class HighResolutionModule(nn.Module):
    """One HRNet stage module.

    Runs `num_branches` parallel resolution branches and then fuses every
    branch into every output branch (upsample/downsample + sum).
    """

    def __init__(self,
                 num_branches,
                 blocks,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 fuse_method,
                 multi_scale_output=True):
        """
        Args:
            num_branches: number of parallel resolution branches.
            blocks: residual block class (BasicBlock or Bottleneck).
            num_blocks: per-branch block counts (len == num_branches).
            num_inchannels: per-branch input channels (mutated in place as
                branches are built).
            num_channels: per-branch base channel counts.
            fuse_method: stored but not branched on here (always sum-fused).
            multi_scale_output: when False, only the highest-resolution
                output is produced by the fuse layers.
        """
        super(HighResolutionModule, self).__init__()
        self._check_branches(num_branches, blocks, num_blocks, num_inchannels,
                             num_channels)

        self.num_inchannels = num_inchannels
        self.fuse_method = fuse_method
        self.num_branches = num_branches

        self.multi_scale_output = multi_scale_output

        self.branches = self._make_branches(num_branches, blocks, num_blocks,
                                            num_channels)
        self.fuse_layers = self._make_fuse_layers()
        self.relu = nn.ReLU(inplace=True)

    def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels,
                        num_channels):
        """Validate that the per-branch config lists match num_branches."""
        if num_branches != len(num_blocks):
            error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
                num_branches, len(num_blocks))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_channels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
                num_branches, len(num_channels))
            logger.error(error_msg)
            raise ValueError(error_msg)

        if num_branches != len(num_inchannels):
            error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
                num_branches, len(num_inchannels))
            logger.error(error_msg)
            raise ValueError(error_msg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build the residual-block stack for one branch.

        Side effect: updates self.num_inchannels[branch_index] to the
        branch's output channel count.
        """
        downsample = None
        # A projection shortcut is needed when stride or channel count
        # changes across the first block.
        if stride != 1 or self.num_inchannels[
                branch_index] != num_channels[branch_index] * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    self.num_inchannels[branch_index],
                    num_channels[branch_index] * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                BatchNorm2d(
                    num_channels[branch_index] * block.expansion,
                    momentum=BN_MOMENTUM),
            )

        layers = []
        layers.append(
            block(self.num_inchannels[branch_index],
                  num_channels[branch_index], stride, downsample))
        self.num_inchannels[
            branch_index] = num_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(
                block(self.num_inchannels[branch_index],
                      num_channels[branch_index]))

        return nn.Sequential(*layers)

    def _make_branches(self, num_branches, block, num_blocks, num_channels):
        """Build one residual stack per branch."""
        branches = []

        for i in range(num_branches):
            branches.append(
                self._make_one_branch(i, block, num_blocks, num_channels))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Build the cross-resolution fusion layers.

        fuse_layers[i][j] converts branch j's output to branch i's
        resolution/channels: 1x1 conv (+ later upsampling) for j > i,
        None (identity) for j == i, a chain of stride-2 3x3 convs for j < i.
        """
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        num_inchannels = self.num_inchannels  # tuple
        fuse_layers = []
        for i in range(num_branches if self.multi_scale_output else 1):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution source: channel-align with 1x1 conv;
                    # spatial upsampling happens in forward().
                    fuse_layer.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_inchannels[j],
                                num_inchannels[i],
                                1,
                                1,
                                0,
                                bias=False),
                            BatchNorm2d(
                                num_inchannels[i], momentum=BN_MOMENTUM)))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Higher-resolution source: downsample with (i - j)
                    # stride-2 3x3 convs; only the last changes channels and
                    # omits the trailing ReLU.
                    conv3x3s = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            num_outchannels_conv3x3 = num_inchannels[i]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM)))
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                            conv3x3s.append(
                                nn.Sequential(
                                    nn.Conv2d(
                                        num_inchannels[j],
                                        num_outchannels_conv3x3,
                                        3,
                                        2,
                                        1,
                                        bias=False),
                                    BatchNorm2d(
                                        num_outchannels_conv3x3,
                                        momentum=BN_MOMENTUM),
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return per-branch output channel counts (for chaining stages)."""
        return self.num_inchannels

    def forward(self, x):
        """Run branches on the list of inputs, then fuse across resolutions.

        Args:
            x: list of tensors, one per branch (mutated in place).

        Returns:
            list of fused tensors, one per output branch.
        """
        if self.num_branches == 1:
            return [self.branches[0](x[0])]

        for i in range(self.num_branches):
            x[i] = self.branches[i](x[i])

        x_fuse = []
        for i in range(len(self.fuse_layers)):
            y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
            for j in range(1, self.num_branches):
                if i == j:
                    y = y + x[j]
                elif j > i:
                    # Lower-resolution source: conv-align then bilinearly
                    # upsample to branch i's spatial size.
                    width_output = x[i].shape[-1]
                    height_output = x[i].shape[-2]
                    y = y + F.interpolate(
                        self.fuse_layers[i][j](x[j]),
                        size=(height_output, width_output),
                        mode='bilinear',
                        align_corners=True)
                else:
                    y = y + self.fuse_layers[i][j](x[j])
            x_fuse.append(self.relu(y))

        return x_fuse
|
||||
|
||||
|
||||
# Stage configurations for the supported HRNet variants.
# NOTE(review): in every STAGE1 dict, NUM_BLOCKS/NUM_CHANNELS are written as
# '(n)', which is a plain int, not a 1-tuple; the stage-1 code path in
# HrnetBackBone.__init__ consumes them as scalars, so do not "fix" them into
# tuples.
model_w18v1 = {
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        'NUM_BLOCKS': (1),
        'NUM_CHANNELS': (32),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2),
        'NUM_CHANNELS': (16, 32),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE3': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2),
        'NUM_CHANNELS': (16, 32, 64),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE4': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2, 2),
        'NUM_CHANNELS': (16, 32, 64, 128),
        'FUSE_METHOD': 'SUM'
    },
    'FINAL_CONV_KERNEL': 1
}

model_w18v2 = {
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        'NUM_BLOCKS': (2),
        'NUM_CHANNELS': (64),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2),
        'NUM_CHANNELS': (18, 36),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE3': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2),
        'NUM_CHANNELS': (18, 36, 72),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE4': {
        'NUM_MODULES': 2,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (2, 2, 2, 2),
        'NUM_CHANNELS': (18, 36, 72, 144),
        'FUSE_METHOD': 'SUM'
    },
    'FINAL_CONV_KERNEL': 1
}

model_w48 = {
    'STAGE1': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 1,
        'BLOCK': 'BOTTLENECK',
        'NUM_BLOCKS': (4),
        'NUM_CHANNELS': (64),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE2': {
        'NUM_MODULES': 1,
        'NUM_BRANCHES': 2,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (4, 4),
        'NUM_CHANNELS': (48, 96),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE3': {
        'NUM_MODULES': 4,
        'NUM_BRANCHES': 3,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (4, 4, 4),
        'NUM_CHANNELS': (48, 96, 192),
        'FUSE_METHOD': 'SUM'
    },
    'STAGE4': {
        'NUM_MODULES': 3,
        'NUM_BRANCHES': 4,
        'BLOCK': 'BASIC',
        'NUM_BLOCKS': (4, 4, 4, 4),
        'NUM_CHANNELS': (48, 96, 192, 384),
        'FUSE_METHOD': 'SUM'
    },
    'FINAL_CONV_KERNEL': 1
}

# Lookup tables: version name -> stage config, block name -> block class.
model_version_dict = {}
model_version_dict['w48'] = model_w48
model_version_dict['w18v1'] = model_w18v1
model_version_dict['w18v2'] = model_w18v2

blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
|
||||
|
||||
|
||||
class HrnetBackBone(nn.Module):
    """HRNet backbone producing a concatenated multi-resolution feature map."""

    def __init__(self, **kwargs):
        """Build the HRNet backbone.

        Args:
            **kwargs: must contain 'version', one of the model_version_dict
                keys ('w48', 'w18v1', 'w18v2'), selecting the stage configs.
        """
        super(HrnetBackBone, self).__init__()

        assert 'version' in kwargs, 'hrnet not exist model version'
        extra = model_version_dict[kwargs['version']]

        # stem net: two stride-2 3x3 convs -> 1/4 input resolution
        self.conv1 = nn.Conv2d(
            3, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.conv2 = nn.Conv2d(
            64, 64, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM)
        self.relu = nn.ReLU(inplace=True)

        # stage 1: single residual layer (STAGE1 config values are scalars)
        self.stage1_cfg = extra['STAGE1']
        num_channels = self.stage1_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage1_cfg['BLOCK']]
        num_blocks = self.stage1_cfg['NUM_BLOCKS']
        self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
        stage1_out_channel = block.expansion * num_channels

        # stages 2-4: parallel multi-resolution branches with fusion;
        # each transition adapts the previous stage's channel list.
        self.stage2_cfg = extra['STAGE2']
        num_channels = self.stage2_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage2_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition1 = self._make_transition_layer([stage1_out_channel],
                                                       num_channels)
        self.stage2, pre_stage_channels = self._make_stage(
            self.stage2_cfg, num_channels)

        self.stage3_cfg = extra['STAGE3']
        num_channels = self.stage3_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage3_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition2 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage3, pre_stage_channels = self._make_stage(
            self.stage3_cfg, num_channels)

        self.stage4_cfg = extra['STAGE4']
        num_channels = self.stage4_cfg['NUM_CHANNELS']
        block = blocks_dict[self.stage4_cfg['BLOCK']]
        num_channels = [
            num_channels[i] * block.expansion
            for i in range(len(num_channels))
        ]
        self.transition3 = self._make_transition_layer(pre_stage_channels,
                                                       num_channels)
        self.stage4, pre_stage_channels = self._make_stage(
            self.stage4_cfg, num_channels, multi_scale_output=True)

        # BUG FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
        # it was an alias of the builtin int, so this is value-identical.
        self.backbone_last_inp_channels = int(np.sum(pre_stage_channels))
|
||||
|
||||
    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Build per-branch transition layers between two stages.

        For each current branch i: reuse the matching previous branch
        (None when channels already agree, a 3x3 conv+BN+ReLU when they
        differ), or create a chain of stride-2 convs from the last previous
        branch when branch i is new (lower resolution).
        """
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    # Channel mismatch: adapt with a stride-1 3x3 conv.
                    transition_layers.append(
                        nn.Sequential(
                            nn.Conv2d(
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                3,
                                1,
                                1,
                                bias=False),
                            BatchNorm2d(
                                num_channels_cur_layer[i],
                                momentum=BN_MOMENTUM), nn.ReLU(inplace=True)))
                else:
                    # Same channels: identity transition.
                    transition_layers.append(None)
            else:
                # New (lower-resolution) branch: downsample from the last
                # previous branch with stride-2 3x3 convs; only the final
                # conv changes the channel count.
                conv3x3s = []
                for j in range(i + 1 - num_branches_pre):
                    inchannels = num_channels_pre_layer[-1]
                    outchannels = num_channels_cur_layer[
                        i] if j == i - num_branches_pre else inchannels
                    conv3x3s.append(
                        nn.Sequential(
                            nn.Conv2d(
                                inchannels, outchannels, 3, 2, 1, bias=False),
                            BatchNorm2d(outchannels, momentum=BN_MOMENTUM),
                            nn.ReLU(inplace=True)))
                transition_layers.append(nn.Sequential(*conv3x3s))

        return nn.ModuleList(transition_layers)
|
||||
|
||||
def _make_layer(self, block, inplanes, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
inplanes,
|
||||
planes * block.expansion,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(inplanes, planes, stride, downsample))
|
||||
inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multi_scale_output=True):
        """Build one HRNet stage as a chain of HighResolutionModules.

        Returns the stage (nn.Sequential of modules) and the per-branch
        output channel counts to feed into the next transition.
        """
        num_modules = layer_config['NUM_MODULES']
        num_branches = layer_config['NUM_BRANCHES']
        num_blocks = layer_config['NUM_BLOCKS']
        num_channels = layer_config['NUM_CHANNELS']
        block = blocks_dict[layer_config['BLOCK']]
        fuse_method = layer_config['FUSE_METHOD']

        modules = []
        for i in range(num_modules):
            # multi_scale_output is only used last module
            if not multi_scale_output and i == num_modules - 1:
                reset_multi_scale_output = False
            else:
                reset_multi_scale_output = True
            modules.append(
                HighResolutionModule(num_branches, block, num_blocks,
                                     num_inchannels, num_channels, fuse_method,
                                     reset_multi_scale_output))
            # Each module may change the branch channel counts; chain them.
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels
|
||||
|
||||
def _backbone_forward(self, x):
    """Run the HRNet backbone and fuse all branches at the finest scale.

    Args:
        x: input image batch (N, C, H, W).

    Returns:
        Feature map of shape (N, sum(branch_channels), H/4, W/4): the four
        stage-4 branch outputs upsampled to the highest resolution and
        concatenated along the channel axis.
    """
    # Stem: two stride-2 conv-BN-ReLU blocks (overall 4x downsampling).
    x = self.conv1(x)
    x = self.bn1(x)
    x = self.relu(x)
    x = self.conv2(x)
    x = self.bn2(x)
    x = self.relu(x)
    x = self.layer1(x)

    # Stage 2: split into parallel branches via transition1.
    x_list = []
    for i in range(self.stage2_cfg['NUM_BRANCHES']):
        if self.transition1[i] is not None:
            x_list.append(self.transition1[i](x))
        else:
            x_list.append(x)
    y_list = self.stage2(x_list)

    # Stage 3: a new, lower-resolution branch is derived from the last
    # stage-2 output; existing branches pass through unchanged.
    x_list = []
    for i in range(self.stage3_cfg['NUM_BRANCHES']):
        if self.transition2[i] is not None:
            x_list.append(self.transition2[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    y_list = self.stage3(x_list)

    # Stage 4: same pattern, adding the fourth (coarsest) branch.
    x_list = []
    for i in range(self.stage4_cfg['NUM_BRANCHES']):
        if self.transition3[i] is not None:
            x_list.append(self.transition3[i](y_list[-1]))
        else:
            x_list.append(y_list[i])
    x = self.stage4(x_list)

    # Upsampling
    # Bring the three coarser branches up to branch-0 resolution.
    x0_h, x0_w = x[0].size(2), x[0].size(3)
    x1 = F.interpolate(
        x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
    x2 = F.interpolate(
        x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
    x3 = F.interpolate(
        x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)

    # Channel-wise concatenation of all four resolutions.
    x = torch.cat([x[0], x1, x2, x3], 1)
    return x
|
||||
|
||||
def init_weights(self, url, cache_file=''):
    """Load pretrained weights, keeping only parameters whose names match.

    Parameters present in the checkpoint but absent from this model (and
    vice versa) are reported and skipped, so partially-matching checkpoints
    can still initialise the shared layers.

    Args:
        url: checkpoint URL or path understood by ``load_state``.
        cache_file: local cache directory forwarded as ``model_dir``.
    """
    pretrained_dict = load_state(url, model_dir=cache_file)
    model_dict = self.state_dict()

    # Keep only parameters that exist in both the checkpoint and the model.
    common_dict = {
        k: v
        for k, v in pretrained_dict.items() if k in model_dict
    }

    print('*' * 50)
    print('Model Param Num:{} Pretrained Param Num:{} '
          'Common Num:{}'.format(
              len(model_dict), len(pretrained_dict), len(common_dict)))
    print('-' * 50)
    print('Model Extra Param Names:\n\t{}'.format(
        set(model_dict) - set(pretrained_dict)))
    print('-' * 50)
    print('Pretrained Extra Param Names:\n\t{}'.format(
        set(pretrained_dict) - set(model_dict)))
    print('*' * 50)

    # Overwrite matching entries and reload the merged state dict.
    model_dict.update(common_dict)
    self.load_state_dict(model_dict)
|
||||
@@ -0,0 +1,510 @@
|
||||
# Part of the implementation is borrowed and modified from HRNet,
|
||||
# publicly available under the MIT License License at https://github.com/HRNet/HRNet-Semantic-Segmentation
|
||||
from __future__ import absolute_import, division, print_function
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .BlockModules import ASPP
|
||||
from .hrnet_backnone import BatchNorm2d, HrnetBackBone, blocks_dict
|
||||
|
||||
ALIGN_CORNERS = True
|
||||
BN_MOMENTUM = 0.1
|
||||
|
||||
|
||||
class ModuleHelper:
    """Static factory helpers for normalisation layers used by the OCR head."""

    @staticmethod
    def BNReLU(num_features, bn_type=None, **kwargs):
        """Return BatchNorm followed by ReLU.

        ``bn_type`` is accepted for interface compatibility with the
        upstream HRNet/OCR code but is not used here.
        """
        layers = [BatchNorm2d(num_features, **kwargs), nn.ReLU()]
        return nn.Sequential(*layers)

    @staticmethod
    def BatchNorm2d(*args, **kwargs):
        # Returns the norm *class* itself (not an instance); any arguments
        # are intentionally ignored.
        return BatchNorm2d
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1 (size-preserving
    at stride 1)."""
    conv_kwargs = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_kwargs)
|
||||
|
||||
|
||||
class SpatialGatherModule(nn.Module):
    """Aggregate pixel features into one soft representation per class.

    Per-pixel class scores are softmax-normalised over the spatial
    dimension and used as weights to pool the feature map, producing the
    OCR "object region representation" (one context vector per class).
    """

    def __init__(self, cls_num=0, scale=1):
        super(SpatialGatherModule, self).__init__()
        self.cls_num = cls_num
        self.scale = scale

    def forward(self, feats, probs):
        batch, num_cls = probs.size(0), probs.size(1)
        # Flatten spatial dims: weights (B, K, HW), pixels (B, HW, C).
        weights = F.softmax(self.scale * probs.view(batch, num_cls, -1), dim=2)
        pixels = feats.view(batch, feats.size(1), -1).permute(0, 2, 1)
        # (B, K, C): spatially-weighted average of pixel features per class.
        context = torch.matmul(weights, pixels)
        # Back to conv layout: (B, C, K, 1).
        return context.permute(0, 2, 1).unsqueeze(3)
|
||||
|
||||
|
||||
class ObjectAttentionBlock(nn.Module):
    '''
    The basic implementation for object context block
    Input:
        N X C X H X W
    Parameters:
        in_channels : the dimension of the input feature map
        key_channels : the dimension after the key/query transform
        scale : choose the scale to downsample the input feature maps (save memory cost)
        bn_type : specify the bn type
    Return:
        N X C X H X W
    '''

    def __init__(self, in_channels, key_channels, scale=1, bn_type=None):
        super(ObjectAttentionBlock, self).__init__()
        self.scale = scale
        self.in_channels = in_channels
        self.key_channels = key_channels
        # Optional spatial downsampling of the pixel features (memory saver).
        self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
        # Query transform: pixel features -> key_channels (two 1x1 convs).
        self.f_pixel = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        # Key transform: object-region (proxy) features -> key_channels.
        self.f_object = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        # Value transform: proxy features -> key_channels (single 1x1 conv).
        self.f_down = nn.Sequential(
            nn.Conv2d(
                in_channels=self.in_channels,
                out_channels=self.key_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type),
        )
        # Output projection: attention context back to in_channels.
        self.f_up = nn.Sequential(
            nn.Conv2d(
                in_channels=self.key_channels,
                out_channels=self.in_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type),
        )

    def forward(self, x, proxy):
        """Attend from pixels (x) to object regions (proxy).

        Args:
            x: pixel features (N, C, H, W).
            proxy: per-class context from SpatialGatherModule (N, C, K, 1).

        Returns:
            Object-contextualised pixel features (N, C, H, W).
        """
        batch_size, h, w = x.size(0), x.size(2), x.size(3)
        if self.scale > 1:
            # Downsample queries to save memory; restored at the end.
            x = self.pool(x)

        # query: (N, HW, key), key: (N, key, K), value: (N, K, key).
        query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
        query = query.permute(0, 2, 1)
        key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
        value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
        value = value.permute(0, 2, 1)

        # Scaled dot-product attention over the K object regions.
        sim_map = torch.matmul(query, key)
        sim_map = (self.key_channels**-.5) * sim_map
        sim_map = F.softmax(sim_map, dim=-1)

        # add bg context ...
        context = torch.matmul(sim_map, value)
        context = context.permute(0, 2, 1).contiguous()
        context = context.view(batch_size, self.key_channels, *x.size()[2:])
        context = self.f_up(context)
        if self.scale > 1:
            # Restore the original spatial resolution.
            context = F.interpolate(
                input=context,
                size=(h, w),
                mode='bilinear',
                align_corners=ALIGN_CORNERS)

        return context
|
||||
|
||||
|
||||
class ObjectAttentionBlock2D(ObjectAttentionBlock):
    """2-D specialisation of ObjectAttentionBlock.

    Adds no behaviour of its own; the separate name mirrors the upstream
    OCR reference implementation, where 2-D and 3-D variants coexist.
    """

    def __init__(self, in_channels, key_channels, scale=1, bn_type=None):
        super(ObjectAttentionBlock2D, self).__init__(
            in_channels, key_channels, scale, bn_type=bn_type)
|
||||
|
||||
|
||||
class SpatialOCRModule(nn.Module):
    """
    Implementation of the OCR module:
    We aggregate the global object representation to update the representation for each pixel.
    """

    def __init__(self,
                 in_channels,
                 key_channels,
                 out_channels,
                 scale=1,
                 dropout=0.1,
                 bn_type=None):
        """
        Args:
            in_channels: channel count of the pixel feature map.
            key_channels: internal attention (key/query) dimension.
            out_channels: channel count of the fused output.
            scale: downsampling factor inside the attention block.
            dropout: dropout rate applied after the fusing 1x1 conv.
            bn_type: forwarded to the normalisation helpers.
        """
        super(SpatialOCRModule, self).__init__()
        self.object_context_block = ObjectAttentionBlock2D(
            in_channels, key_channels, scale, bn_type)
        # The context is concatenated with the raw features before fusing.
        _in_channels = 2 * in_channels

        self.conv_bn_dropout = nn.Sequential(
            nn.Conv2d(
                _in_channels,
                out_channels,
                kernel_size=1,
                padding=0,
                bias=False),
            ModuleHelper.BNReLU(out_channels, bn_type=bn_type),
            nn.Dropout2d(dropout))

    def forward(self, feats, proxy_feats):
        """Fuse pixel features with their attended object context.

        Args:
            feats: pixel features (N, C, H, W).
            proxy_feats: per-class context (N, C, K, 1) from
                SpatialGatherModule.

        Returns:
            Fused features (N, out_channels, H, W).
        """
        context = self.object_context_block(feats, proxy_feats)

        output = self.conv_bn_dropout(torch.cat([context, feats], 1))

        return output
|
||||
|
||||
|
||||
class HrnetSuperAndOcr(HrnetBackBone):
    """HRNet backbone extended with an OCR head and a super-resolution stage.

    Two architectures are supported, selected by kwargs['architecture']:
      * 'hrnet_super_ocr': super-resolution stage first, then the OCR head
        (``is_ocr_first = False``).
      * anything else: OCR head first, then the super-resolution stage
        (``is_ocr_first = True``).

    Required kwargs: 'architecture', 'class_num', 'ocr' (dict with
    'mid_channels', 'key_channels', 'dropout_rate', 'scale') and
    'super_param' (stage config for _make_stage).  Optional kwargs:
    'tail_param' (extra residual layer after the backbone) and 'aspp'
    (ASPP context module config).
    """

    def __init__(self, **kwargs):
        super(HrnetSuperAndOcr, self).__init__(**kwargs)
        if 'architecture' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist architecture param!')
        self.architecture = kwargs['architecture']

        if 'class_num' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist class_num param!')
        self.class_num = kwargs['class_num']

        if 'ocr' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist ocr param!')
        ocr_mid_channels = kwargs['ocr']['mid_channels']
        ocr_key_channels = kwargs['ocr']['key_channels']
        dropout_rate = kwargs['ocr']['dropout_rate']
        scale = kwargs['ocr']['scale']

        if 'super_param' not in kwargs:
            raise Exception('HrnetSuperAndOcr not exist super_param param!')

        self.super_dict = kwargs['super_param']

        # ONNX-export switches; toggled externally before tracing.
        self.is_export_onnx = False
        self.is_export_full_onnx = False

        # Optional residual "tail" layer appended to the backbone output.
        self.is_contain_tail = 'tail_param' in kwargs
        if self.is_contain_tail:
            self.stage_tail_dict = kwargs['tail_param']
            num_channels = self.stage_tail_dict['NUM_CHANNELS'][0]
            block = blocks_dict[self.stage_tail_dict['BLOCK']]
            num_blocks = self.stage_tail_dict['NUM_BLOCKS'][0]
            self.stage_tail = self._make_layer(block,
                                               self.backbone_last_inp_channels,
                                               num_channels, num_blocks)
            last_inp_channels = block.expansion * num_channels
        else:
            last_inp_channels = self.backbone_last_inp_channels

        self.is_contain_aspp = 'aspp' in kwargs

        if self.architecture == 'hrnet_super_ocr':
            # Super-resolution stage first, OCR head second.
            self.is_ocr_first = False
            # Branch 0 takes the 64-channel stem output (x_stem in forward).
            num_channels = [64, last_inp_channels]
            self.stage_super, super_stage_channels = self._make_stage(
                self.super_dict, num_channels)
            # np.int was removed in NumPy 1.24; use the builtin int instead.
            last_inp_channels = int(np.sum(super_stage_channels))

            if self.is_contain_aspp:
                aspp_param = kwargs['aspp']
                self.aspp_layer = ASPP(
                    inplanes=last_inp_channels,
                    outplanes=aspp_param['outplanes'],
                    dilations=aspp_param['dilations'],
                    drop_rate=aspp_param['drop_rate'])
                last_inp_channels = aspp_param['outplanes']

            # Auxiliary segmentation head; its logits also drive the
            # SpatialGatherModule.
            self.aux_head = nn.Sequential(
                nn.Conv2d(
                    last_inp_channels,
                    last_inp_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0), BatchNorm2d(last_inp_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    last_inp_channels,
                    self.class_num,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True))

            self.conv3x3_ocr = nn.Sequential(
                nn.Conv2d(
                    last_inp_channels,
                    ocr_mid_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                BatchNorm2d(ocr_mid_channels),
                nn.ReLU(inplace=True),
            )
            self.ocr_gather_head = SpatialGatherModule(self.class_num)

            self.ocr_distri_head = SpatialOCRModule(
                in_channels=ocr_mid_channels,
                key_channels=ocr_key_channels,
                out_channels=ocr_mid_channels,
                scale=scale,
                dropout=dropout_rate,
            )

            # Final classifier on the OCR-refined features.
            self.cls_head = nn.Sequential(
                nn.Conv2d(
                    ocr_mid_channels,
                    ocr_mid_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0), BatchNorm2d(ocr_mid_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    ocr_mid_channels,
                    self.class_num,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True))
        else:
            # OCR head first, super-resolution stage second.
            self.is_ocr_first = True

            if self.is_contain_aspp:
                aspp_param = kwargs['aspp']
                self.aspp_layer = ASPP(
                    inplanes=last_inp_channels,
                    outplanes=aspp_param['outplanes'],
                    dilations=aspp_param['dilations'],
                    drop_rate=aspp_param['drop_rate'])
                last_inp_channels = aspp_param['outplanes']

            self.aux_head = nn.Sequential(
                nn.Conv2d(
                    last_inp_channels,
                    last_inp_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0), BatchNorm2d(last_inp_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    last_inp_channels,
                    self.class_num,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True))

            self.conv3x3_ocr = nn.Sequential(
                nn.Conv2d(
                    last_inp_channels,
                    ocr_mid_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1),
                BatchNorm2d(ocr_mid_channels),
                nn.ReLU(inplace=True),
            )
            self.ocr_gather_head = SpatialGatherModule(self.class_num)

            self.ocr_distri_head = SpatialOCRModule(
                in_channels=ocr_mid_channels,
                key_channels=ocr_key_channels,
                out_channels=ocr_mid_channels,
                scale=scale,
                dropout=dropout_rate,
            )

            num_channels = [64, ocr_mid_channels]
            self.stage_super, super_stage_channels = self._make_stage(
                self.super_dict, num_channels)
            # np.int was removed in NumPy 1.24; use the builtin int instead.
            last_inp_channels = int(np.sum(super_stage_channels))

            # Final classifier on the super-resolved, concatenated features.
            self.cls_head = nn.Sequential(
                nn.Conv2d(
                    last_inp_channels,
                    last_inp_channels,
                    kernel_size=1,
                    stride=1,
                    padding=0), BatchNorm2d(last_inp_channels),
                nn.ReLU(inplace=True),
                nn.Conv2d(
                    last_inp_channels,
                    self.class_num,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True))

    def forward(self, x):
        """Run the full network.

        Returns [out_aux, out] (auxiliary and main logits) in normal mode;
        in ONNX-export mode, returns the argmax class map upsampled to the
        input resolution instead.
        """
        if self.is_export_onnx:
            # Exported graph receives NHWC input; convert to NCHW.
            x = x.permute(0, 3, 1, 2)
            raw_h, raw_w = x.size(2), x.size(3)
        if self.is_export_full_onnx:
            raw_h, raw_w = x.size(2), x.size(3)

        # Stem (2x stride twice); x_stem (2x) is reused by the super stage.
        x = self.conv1(x)
        x = self.bn1(x)
        x_stem = self.relu(x)
        x = self.conv2(x_stem)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.layer1(x)

        # HRNet stages 2-4 with their branch transitions.
        x_list = []
        for i in range(self.stage2_cfg['NUM_BRANCHES']):
            if self.transition1[i] is not None:
                x_list.append(self.transition1[i](x))
            else:
                x_list.append(x)
        y_list = self.stage2(x_list)

        x_list = []
        for i in range(self.stage3_cfg['NUM_BRANCHES']):
            if self.transition2[i] is not None:
                x_list.append(self.transition2[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        y_list = self.stage3(x_list)

        x_list = []
        for i in range(self.stage4_cfg['NUM_BRANCHES']):
            if self.transition3[i] is not None:
                x_list.append(self.transition3[i](y_list[-1]))
            else:
                x_list.append(y_list[i])
        x = self.stage4(x_list)

        # Upsample the coarser branches and concatenate at 4x resolution.
        x0_h, x0_w = x[0].size(2), x[0].size(3)
        x1 = F.interpolate(
            x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x2 = F.interpolate(
            x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True)
        x3 = F.interpolate(
            x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True)

        feats = torch.cat([x[0], x1, x2, x3], 1)

        if self.is_contain_tail:
            feats = self.stage_tail(feats)

        if self.is_ocr_first:
            # OCR refinement on backbone features, then super-resolution.
            if self.is_contain_aspp:
                feats = self.aspp_layer(feats)
            out_aux = self.aux_head(feats)

            feats = self.conv3x3_ocr(feats)
            context = self.ocr_gather_head(feats, out_aux)
            feats = self.ocr_distri_head(feats, context)

            # Fuse the 2x stem features with the 4x refined features.
            feats = [x_stem, feats]
            x_super = self.stage_super(feats)

            xsuper_h, xsuper_w = x_super[0].size(2), x_super[0].size(3)
            x_super1 = F.interpolate(
                x_super[1],
                size=(xsuper_h, xsuper_w),
                mode='bilinear',
                align_corners=True)
            x_super = torch.cat([x_super[0], x_super1], 1)
            out = self.cls_head(x_super)

        else:
            # Super-resolution first, then OCR refinement.
            x_super = [x_stem, feats]
            x_super = self.stage_super(x_super)

            xsuper_h, xsuper_w = x_super[0].size(2), x_super[0].size(3)
            x_super1 = F.interpolate(
                x_super[1],
                size=(xsuper_h, xsuper_w),
                mode='bilinear',
                align_corners=True)
            x_super = torch.cat([x_super[0], x_super1], 1)

            if self.is_contain_aspp:
                x_super = self.aspp_layer(x_super)
            out_aux = self.aux_head(x_super)

            feats = self.conv3x3_ocr(x_super)
            context = self.ocr_gather_head(feats, out_aux)
            feats = self.ocr_distri_head(feats, context)

            out = self.cls_head(feats)

        if self.is_export_onnx or self.is_export_full_onnx:
            # Export path: emit the per-pixel argmax class map at input size.
            x_class = F.interpolate(
                out, size=(raw_h, raw_w), mode='bilinear', align_corners=True)
            x_class = torch.softmax(x_class, dim=1)
            _, x_class = torch.max(x_class, dim=1, keepdim=True)
            x_class = x_class.float()
            return x_class
        else:
            out_aux_seg = [
                out_aux, out
            ]  # out_aux: 5, 2, 160, 160(HRNet origin res); out: 5, 2, 320, 320(HRNet res+tail+aspp+ocr)
            return out_aux_seg
|
||||
|
||||
|
||||
def get_seg_model(cfg, **kwargs):
    """Construct a HrnetSuperAndOcr model and load its pretrained weights.

    NOTE(review): HrnetSuperAndOcr.__init__ accepts keyword arguments only,
    so passing ``cfg`` positionally looks like it would raise a TypeError at
    construction time -- confirm how this factory is actually invoked (cfg
    may be expected to be merged into kwargs by the caller).
    """
    model = HrnetSuperAndOcr(cfg, **kwargs)
    model.init_weights(cfg.MODEL.PRETRAINED)

    return model
|
||||
229
modelscope/models/cv/image_skychange/ptsemseg/unet.py
Normal file
229
modelscope/models/cv/image_skychange/ptsemseg/unet.py
Normal file
@@ -0,0 +1,229 @@
|
||||
# Copyright 2021-2022 The Alibaba Vision Team Authors. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .BlockModules import ASPP
|
||||
|
||||
|
||||
class Conv2DBatchNormRelu(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and/or ReLU.

    The constituent layers are assembled into a single ``cbr_unit``
    nn.Sequential; ``with_bn`` / ``with_relu`` toggle the trailing layers.
    """

    def __init__(self,
                 in_channels,
                 n_filters,
                 k_size,
                 stride,
                 padding,
                 bias=True,
                 dilation=1,
                 with_bn=True,
                 with_relu=True):
        super(Conv2DBatchNormRelu, self).__init__()

        # Build the conv first, then append the optional norm/activation.
        layers = [
            nn.Conv2d(
                int(in_channels),
                int(n_filters),
                kernel_size=k_size,
                padding=padding,
                stride=stride,
                bias=bias,
                dilation=dilation,
            )
        ]
        if with_bn:
            layers.append(nn.BatchNorm2d(int(n_filters)))
        if with_relu:
            layers.append(nn.ReLU(inplace=True))
        self.cbr_unit = nn.Sequential(*layers)

    def forward(self, inputs):
        return self.cbr_unit(inputs)
|
||||
|
||||
|
||||
class SegnetDown2(nn.Module):
    """Two conv-BN-ReLU layers followed by a 2x2 max-pool.

    Returns the pooled features together with the pooling indices and the
    pre-pool shape, so a matching SegnetUp1 can unpool later.
    """

    def __init__(self, in_size, out_size):
        super(SegnetDown2, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(
            in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        features = self.conv2(self.conv1(inputs))
        pre_pool_shape = features.size()
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, pre_pool_shape
|
||||
|
||||
|
||||
class SegnetDown3(nn.Module):
    """Three conv-BN-ReLU layers followed by a 2x2 max-pool.

    Returns the pooled features together with the pooling indices and the
    pre-pool shape, so a matching SegnetUp1 can unpool later.
    """

    def __init__(self, in_size, out_size):
        super(SegnetDown3, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(
            in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.conv3 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        features = self.conv3(self.conv2(self.conv1(inputs)))
        pre_pool_shape = features.size()
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, pre_pool_shape
|
||||
|
||||
|
||||
class SegnetUp1(nn.Module):
    """Max-unpool (via stored pooling indices) plus one 5x5 conv+BN.

    The conv has no ReLU (``with_relu=False``), matching the SegNet-style
    decoder used by Unet below.
    """

    def __init__(self, in_size, out_size):
        super(SegnetUp1, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv = Conv2DBatchNormRelu(
            in_size, out_size, k_size=5, stride=1, padding=2, with_relu=False)

    def forward(self, inputs, indices, output_shape):
        upsampled = self.unpool(
            input=inputs, indices=indices, output_size=output_shape)
        return self.conv(upsampled)
|
||||
|
||||
|
||||
class Unet(nn.Module):
    """SegNet-style encoder/decoder used for sky matting.

    The encoder follows VGG-16 (and can be initialised from torchvision's
    VGG-16 weights); the decoder unpools with the stored max-pool indices.
    An optional ASPP branch on the second encoder level is fused back in
    before the last two decoder stages.  The output is a sigmoid map.
    """

    def __init__(self,
                 n_classes=2,
                 in_channels=4,
                 is_unpooling=True,
                 pretrain=True,
                 **kwargs):
        """
        Args:
            n_classes: channels of the final decoder layer.
            in_channels: input channels (4: presumably RGB + guidance/trimap
                channel -- TODO confirm against the preprocessor).
            is_unpooling: stored but not branched on in this class.
            pretrain: when True, copy VGG-16 conv weights into the encoder.
            kwargs: optional 'aspp' dict (outplanes, dilations, drop_rate).
        """
        super(Unet, self).__init__()
        print('Load Unet')
        self.in_channels = in_channels
        self.is_unpooling = is_unpooling
        self.pretrain = pretrain
        self.is_contain_aspp = True if 'aspp' in kwargs else False

        if self.is_contain_aspp:
            aspp_param = kwargs['aspp']
            # ASPP consumes the 128-channel level-2 encoder features.
            self.aspp_layer = ASPP(
                inplanes=128,
                outplanes=aspp_param['outplanes'],
                dilations=aspp_param['dilations'],
                drop_rate=aspp_param['drop_rate'])
            self.aspp_channels = aspp_param['outplanes']

        # VGG-16-shaped encoder: 2-2-3-3-3 conv blocks with max-pooling.
        self.down1 = SegnetDown2(self.in_channels, 64)
        self.down2 = SegnetDown2(64, 128)
        self.down3 = SegnetDown3(128, 256)
        self.down4 = SegnetDown3(256, 512)
        self.down5 = SegnetDown3(512, 512)

        # Decoder: one unpool+conv per level.
        self.up5 = SegnetUp1(512, 512)
        self.up4 = SegnetUp1(512, 256)
        self.up3 = SegnetUp1(256, 128)

        if self.is_contain_aspp:
            # Fuses [up3, aspp] concat back down to 128 channels.
            self.conv_1x1_aspp = Conv2DBatchNormRelu(
                128 + self.aspp_channels,
                128,
                k_size=1,
                stride=1,
                padding=0,
                with_relu=False)

        self.up2 = SegnetUp1(128, 64)
        self.up1 = SegnetUp1(64, n_classes)
        self.sigmoid = nn.Sigmoid()

        if self.pretrain:
            # Copies matching conv weights from torchvision's VGG-16
            # (randomly initialised here; no download is triggered).
            import torchvision.models as models
            vgg16 = models.vgg16()
            self.init_vgg16_params(vgg16)

    def forward(self, inputs):  # [1, 4, 1346, 1152] [2, 4, 1280, 1280]
        """Encode, optionally run ASPP, decode, and apply sigmoid.

        Args:
            inputs: (N, in_channels, H, W); H and W must be divisible by 32
                so the five pool/unpool levels round-trip exactly.

        Returns:
            Sigmoid output with the channel dim squeezed when it is 1.
            NOTE(review): with the default n_classes=2, squeeze(dim=1) is a
            no-op and the result keeps 2 channels -- confirm callers expect
            this.
        """
        down1, indices_1, unpool_shape1 = self.down1(
            inputs)  # [1, 64, 673, 576] [2, 64, 640, 640]
        down2, indices_2, unpool_shape2 = self.down2(
            down1)  # [1, 128, 336, 288] [2, 128, 320, 320]
        down3, indices_3, unpool_shape3 = self.down3(
            down2)  # [1, 256, 168, 144] [2, 256, 160, 160]
        # Free cached GPU blocks between the large encoder activations.
        torch.cuda.empty_cache()
        if self.is_contain_aspp:  # batchsize can not be 1
            aspp_output = self.aspp_layer(down2)

        down4, indices_4, unpool_shape4 = self.down4(
            down3)  # [1, 512, 84, 72] [2, 512, 80, 80]
        down5, indices_5, unpool_shape5 = self.down5(
            down4)  # [1, 512, 42, 36] [2, 512, 80, 80]
        torch.cuda.empty_cache()
        up5 = self.up5(down5, indices_5,
                       unpool_shape5)  # [1, 512, 84, 72] [2, 512, 80, 80]
        up4 = self.up4(up5, indices_4,
                       unpool_shape4)  # [1, 256, 168, 144] [2, 256, 160, 160]
        torch.cuda.empty_cache()
        up3 = self.up3(
            up4, indices_3,
            unpool_shape3)  # [1, 128, 336, 288] [2, 128, 320, 320]
        if self.is_contain_aspp:
            # Inject the ASPP context at the matching resolution.
            up3 = torch.cat([up3, aspp_output], 1)  # [2, 256, 320, 320]
            up3 = self.conv_1x1_aspp(up3)  # [2, 128, 320, 320]

        up2 = self.up2(
            up3, indices_2,
            unpool_shape2)  # [1, 64, 673, 576] indices_2: [2, 128, 320, 320]
        up1 = self.up1(up2, indices_1, unpool_shape1)  # [1, 1, 1346, 1152]

        x = torch.squeeze(up1, dim=1)  # [N, 1, 320, 320] -> [N, 320, 320]
        x = self.sigmoid(x)

        return x  # [2, 1280, 1280]

    def init_vgg16_params(self, vgg16):
        """Copy VGG-16 conv weights/biases into the encoder, layer by layer.

        Only layers whose weight AND bias shapes match are copied (e.g. the
        very first conv differs when in_channels != 3 and is skipped).
        """
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]

        features = list(vgg16.features.children())

        # All conv layers of VGG-16, in order.
        vgg_layers = []
        for _layer in features:
            if isinstance(_layer, nn.Conv2d):
                vgg_layers.append(_layer)

        # All encoder conv layers, in the same order (2-2-3-3-3 pattern).
        merged_layers = []
        for idx, conv_block in enumerate(blocks):
            if idx < 2:
                units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            else:
                units = [
                    conv_block.conv1.cbr_unit,
                    conv_block.conv2.cbr_unit,
                    conv_block.conv3.cbr_unit,
                ]
            for _unit in units:
                for _layer in _unit:
                    if isinstance(_layer, nn.Conv2d):
                        merged_layers.append(_layer)

        assert len(vgg_layers) == len(merged_layers)

        for l1, l2 in zip(vgg_layers, merged_layers):
            if isinstance(l1, nn.Conv2d) and isinstance(l2, nn.Conv2d):
                if l1.weight.size() == l2.weight.size() and l1.bias.size(
                ) == l2.bias.size():
                    l2.weight.data = l1.weight.data
                    l2.bias.data = l1.bias.data
|
||||
310
modelscope/models/cv/image_skychange/skychange.py
Normal file
310
modelscope/models/cv/image_skychange/skychange.py
Normal file
@@ -0,0 +1,310 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numbers
|
||||
import os
|
||||
import pdb
|
||||
from collections import deque
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
torch.backends.cudnn.enabled = True
|
||||
|
||||
IMAGE_MAX_DIM = 3000
|
||||
IMAGE_MIN_DIM = 50
|
||||
IMAGE_MAX_RATIO = 10.0
|
||||
IMAGE_BLENDER_MASK_RESIZE_SCALE = 10.0
|
||||
IMAGE_BLENDER_INNER_RECT_MAX_DIM = 256
|
||||
IMAGE_BLENDER_DILATE_KERNEL_SIZE = 7
|
||||
IMAGE_BLENDER_VALID_MASK_THRESHOLD = 100
|
||||
IMAGE_BLENDER_MIN_VALID_SKY_AREA = 100
|
||||
IMAGE_BLENDER_MIN_RESIZE_DIM = 10
|
||||
IMAGE_BLENDER_BLUR_KERNEL_SIZE = 5
|
||||
|
||||
|
||||
def extract_sky_image(in_sky_image, in_sky_mask):
    """Crop the largest usable sky rectangle out of a sky reference image.

    The sky mask is optionally downscaled so its smaller side fits
    IMAGE_BLENDER_INNER_RECT_MAX_DIM, morphologically closed to remove
    pinholes, and scanned for its largest all-sky inner rectangle; that
    rectangle is then mapped back to the original resolution and cropped
    from the image.

    Args:
        in_sky_image: sky source image (H x W x C ndarray).
        in_sky_mask: single-channel sky mask aligned with in_sky_image.

    Returns:
        A copy of the cropped sky region.

    Raises:
        Exception: if the valid sky area is smaller than
            IMAGE_BLENDER_MIN_VALID_SKY_AREA.
    """
    scale = 1.0
    resize_mask = in_sky_mask.copy()

    rows, cols = resize_mask.shape[0:2]
    # src size: (512, 640), target size: (256,256), then scale to size (256, 320)
    if (rows > IMAGE_BLENDER_INNER_RECT_MAX_DIM
            or cols > IMAGE_BLENDER_INNER_RECT_MAX_DIM):
        height_scale = IMAGE_BLENDER_INNER_RECT_MAX_DIM / float(rows)
        width_scale = IMAGE_BLENDER_INNER_RECT_MAX_DIM / float(cols)
        scale = height_scale if height_scale > width_scale else width_scale
        new_size = (max(int(cols * scale), 1), max(int(rows * scale),
                                                   1))  # w, h
        # Fix: interpolation must be passed by keyword -- the third
        # positional argument of cv2.resize is `dst`, not `interpolation`.
        resize_mask = cv2.resize(
            resize_mask, new_size, interpolation=cv2.INTER_LINEAR)

    # Close small holes so they don't split the inner-rectangle search;
    # the kernel shrinks proportionally with the mask.
    kernel_size = max(3, int(scale * IMAGE_BLENDER_DILATE_KERNEL_SIZE + 0.5))

    element = cv2.getStructuringElement(cv2.MORPH_RECT,
                                        (kernel_size, kernel_size))
    resize_mask = cv2.morphologyEx(resize_mask, cv2.MORPH_CLOSE, element)

    max_inner_rect, area = get_max_inner_rect(
        resize_mask, IMAGE_BLENDER_VALID_MASK_THRESHOLD, True)

    if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA:
        raise Exception(
            '[extractSkyImage]failed!! Valid sky region is too small')

    # Map the rectangle back to the full-resolution mask.
    scale = 1.0 / scale
    # max_inner_rect: left top(x,y), right bottome(x,y); raw_inner_rect:left top x,y,w(of bbox),h(of bbox)
    raw_inner_rect = scale_rect(max_inner_rect, in_sky_mask, scale)
    out_sky_image = in_sky_image[raw_inner_rect[1]:raw_inner_rect[1]
                                 + raw_inner_rect[3] + 1,
                                 raw_inner_rect[0]:raw_inner_rect[0]
                                 + raw_inner_rect[2] + 1, ].copy()
    return out_sky_image
|
||||
|
||||
|
||||
def blend(scene_image, scene_mask, sky_image, sky_mask, inBlendLevelNum=10):
    """Replace the sky region of scene_image with content from sky_image.

    Args:
        scene_image: torch tensor (H x W x C) of the target scene.
        scene_mask: sky mask aligned with scene_image (same H x W).
        sky_image: torch tensor (H x W x C) providing the new sky.
        sky_mask: sky mask aligned with sky_image.
        inBlendLevelNum: pyramid level count forwarded to blend_merge.

    Returns:
        The blended image as produced by blend_merge.

    Raises:
        Exception: if an image and its mask differ in spatial size.
    """
    # Move model outputs to host memory for the NumPy/OpenCV pipeline.
    # NOTE(review): assumes tensors live on the GPU whenever CUDA is
    # available -- confirm against the pipeline that builds them.
    if torch.cuda.is_available():
        scene_image = scene_image.cpu().numpy()
        sky_image = sky_image.cpu().numpy()
    else:
        scene_image = scene_image.numpy()
        sky_image = sky_image.numpy()
    sky_image_h, sky_image_w = sky_image.shape[0:2]
    sky_mask_h, sky_mask_w = sky_mask.shape[0:2]

    scene_image_h, scene_image_w = scene_image.shape[0:2]
    scene_mask_h, scene_mask_w = scene_mask.shape[0:2]

    if sky_image_h != sky_mask_h or sky_image_w != sky_mask_w:
        raise Exception(
            '[blend]failed!! sky_image shape not equal with sky_image_mask shape'
        )

    if scene_image_h != scene_mask_h or scene_image_w != scene_mask_w:
        raise Exception(
            '[blend]failed!! scene_image shape not equal with scene_image_mask shape'
        )

    # Crop the largest clean sky rectangle, then pyramid-blend it in.
    valid_sky_image = extract_sky_image(sky_image, sky_mask)
    out_blend_image = blend_merge(scene_image, scene_mask, valid_sky_image,
                                  inBlendLevelNum)
    return out_blend_image
|
||||
|
||||
|
||||
def get_max_inner_rect(in_image_mask, in_alpha_threshold, is_bigger_valid):
    """Find the largest axis-aligned rectangle of valid pixels in a mask.

    A pixel is valid when its value is strictly greater than
    in_alpha_threshold (if is_bigger_valid) or less-or-equal otherwise.
    Uses the classic largest-rectangle-in-histogram stack scan, row by row,
    in O(rows * cols).

    Args:
        in_image_mask: 2-D (or HxWxC; only the first two dims are used)
            mask array.
        in_alpha_threshold: validity threshold.
        is_bigger_valid: direction of the threshold comparison.

    Returns:
        ((x, y, width, height), area) of the best rectangle; area is 0
        (and the rect degenerates to (0, 0, 1, 1)) when no pixel is valid.
    """
    num_rows, num_cols = in_image_mask.shape[0:2]
    best_area = 0
    top, left, bottom, right = 0, 0, 0, 0
    # heights[c] = run length of consecutive valid pixels ending at the
    # current row in column c; one extra sentinel column flushes the stack.
    heights = [0] * (num_cols + 1)

    for r in range(num_rows):
        stack = deque()
        for c in range(num_cols + 1):
            if c < num_cols:
                pixel = in_image_mask[r, c]
                if is_bigger_valid:
                    valid = pixel > in_alpha_threshold
                else:
                    valid = pixel <= in_alpha_threshold
                heights[c] = heights[c] + 1 if valid else 0

            # Pop bars taller than the incoming one; each pop fixes the
            # maximal rectangle whose height is that bar.
            while stack and heights[stack[-1]] >= heights[c]:
                bar = stack.pop()
                rect_h = heights[bar]
                rect_w = c if not stack else c - stack[-1] - 1
                if rect_h * rect_w > best_area:
                    best_area = rect_h * rect_w
                    bottom = r
                    top = bottom - rect_h + 1
                    right = c - 1
                    left = right - rect_w + 1
            stack.append(c)

    out_rect = (left, top, right - left + 1, bottom - top + 1)
    return out_rect, best_area
|
||||
|
||||
|
||||
def scale_rect(in_rect, in_image_size, in_scale):
    """Scale rectangle coordinates by in_scale, clamping to the image size.

    in_rect[0:2] are treated as the top-left corner and in_rect[2:4] as the
    bottom-right corner; the bottom-right is clamped to the width/height of
    in_image_size.  Returns (x, y, clamped_br_x - x, clamped_br_y - y).

    NOTE(review): get_max_inner_rect returns (x, y, w, h) while this helper
    reads in_rect[2:4] as absolute corner coordinates -- confirm the
    intended rect convention with the caller.
    """
    img_h, img_w = in_image_size.shape[0:2]
    left = int(in_rect[0] * in_scale + 0.5)
    top = int(in_rect[1] * in_scale + 0.5)
    right = min(int(in_rect[2] * in_scale + 0.5), img_w)
    bottom = min(int(in_rect[3] * in_scale + 0.5), img_h)
    return (left, top, right - left, bottom - top)
|
||||
|
||||
|
||||
def get_fast_valid_rect(in_mask, in_threshold=0):
    """Return the bounding rect (x, y, w, h) of mask pixels > in_threshold."""
    # mask: np.array [0~1]
    in_mask = in_mask > in_threshold
    locations = cv2.findNonZero(in_mask.astype(np.uint8))
    output_rect = cv2.boundingRect(locations)  # x,y,w,h
    return output_rect
|
||||
|
||||
|
||||
def min_size_match(in_image, in_min_size, type=cv2.INTER_LINEAR):
    """Resize an image, preserving aspect ratio, so that both output
    dimensions are at least the requested minimum size.

    Args:
        in_image: source image (H x W [x C]).
        in_min_size: ``(width, height)`` lower bound for the output size.
        type: OpenCV interpolation flag (the name shadows the builtin but
            is kept for backward compatibility with keyword callers).

    Returns:
        The resized image; the larger of the two per-axis scale factors is
        used so both dimensions reach the minimum (up to rounding).
    """
    min_w, min_h = in_min_size
    src_h, src_w = in_image.shape[0:2]
    height_scale = min_h / src_h
    width_scale = min_w / src_w
    # The larger scale guarantees BOTH dimensions meet the minimum.
    scale = max(height_scale, width_scale)
    new_size = (
        max(int(src_w * scale + 0.5), 1),
        max(int(src_h * scale + 0.5), 1),
    )

    # BUG FIX: the previous call `cv2.resize(img, size, 0, 0, type)` passed
    # the interpolation flag positionally into the `fy` parameter (and 0
    # into `dst`), so the requested interpolation was never applied. Pass
    # it by keyword. The preliminary .copy() was also dropped: cv2.resize
    # always allocates a fresh output array.
    return cv2.resize(in_image, new_size, interpolation=type)
|
||||
|
||||
|
||||
def center_crop(in_image, in_size):
    """Crop the central ``(width, height)`` window out of an image.

    Args:
        in_image: source image (H x W [x C]); must be at least as large as
            the requested crop in both dimensions.
        in_size: ``(width, height)`` of the crop window.

    Returns:
        The centered crop (taken from a copy of the input, so the caller's
        array is never aliased).
    """
    target_w, target_h = in_size
    src_h, src_w = in_image.shape[0:2]

    # Top-left corner that centers the window (floor division, matching
    # the original behaviour for odd differences).
    top = (src_h - target_h) // 2
    left = (src_w - target_w) // 2

    return in_image.copy()[top:top + target_h, left:left + target_w]
|
||||
|
||||
|
||||
def safe_roi_pad(in_pad_image, in_rect, out_base_image):
    """Paste ``in_pad_image`` into ``out_base_image`` at ``in_rect``, in place.

    Args:
        in_pad_image: patch to write; must exactly match the rect size.
        in_rect: ``(x, y, w, h)`` destination rectangle.
        out_base_image: destination image, modified in place.

    Raises:
        Exception: if the rect is degenerate/negative, does not match the
            patch size, or does not lie fully inside the base image.
    """
    x, y, w, h = in_rect

    # Reject negative origins and non-positive sizes outright.
    if x < 0 or y < 0 or w <= 0 or h <= 0:
        raise Exception('[safe_roi_pad] Failed!! x,y,w,h of rect are illegal')

    # The patch must exactly fill the rectangle.
    if w != in_pad_image.shape[1] or h != in_pad_image.shape[0]:
        raise Exception('[safe_roi_pad] Failed!!')

    # The rectangle must lie fully inside the base image.
    if x + w > out_base_image.shape[1] or y + h > out_base_image.shape[0]:
        raise Exception('[safe_roi_pad] Failed!!')

    out_base_image[y:y + h, x:x + w] = in_pad_image
|
||||
|
||||
|
||||
def merge_image(in_base_image, in_merge_image, in_merge_mask, in_point):
    """Alpha-blend a patch onto a base image at a given position, in place.

    Args:
        in_base_image: uint8 destination image; the blended region is
            written back into it.
        in_merge_image: patch to blend in; same height/width as the mask.
        in_merge_mask: single-channel alpha mask, values in [0, 255].
        in_point: ``(x, y)`` top-left position of the patch in the base.

    Returns:
        ``in_base_image`` with the blended region written back.

    Raises:
        Exception: if the patch and mask shapes disagree, or the patch
            does not fit inside the base image at ``in_point``.
    """
    if in_merge_image.shape[0:2] != in_merge_mask.shape[0:2]:
        raise Exception(
            '[merge_image] Failed!! in_merge_image.shape != in_merge_mask.shape!!'
        )

    px, py = in_point
    patch_h, patch_w = in_merge_image.shape[0:2]
    base_h, base_w = in_base_image.shape[0:2]

    if px + patch_w > base_w or py + patch_h > base_h:
        raise Exception(
            '[merge_image] Failed!! merge_image:image rect not in image')

    # Blend in float to avoid uint8 overflow; alpha normalized to [0, 1]
    # and broadcast to 3 channels.
    roi = np.float32(in_base_image[py:py + patch_h, px:px + patch_w])
    alpha = np.repeat(
        in_merge_mask.copy()[:, :, np.newaxis], 3, axis=2) / 255.0
    blended = (1 - alpha) * roi + alpha * in_merge_image.copy()
    blended = np.clip(blended, 0, 255).astype('uint8')

    # Write the blended patch back into the base image (validated paste).
    safe_roi_pad(blended, (px, py, patch_w, patch_h), in_base_image)
    return in_base_image
|
||||
|
||||
|
||||
def blend_merge(in_scene_image,
                in_scene_mask,
                in_valid_sky_image,
                inBlendLevelNum=5):
    """Blend a replacement sky into a scene image with multi-band blending.

    Args:
        in_scene_image: uint8 scene image (H x W x 3).
        in_scene_mask: sky mask of the scene, values in [0, 255].
        in_valid_sky_image: sky texture to paste into the scene's sky area.
        inBlendLevelNum: number of pyramid levels for the blender.

    Returns:
        The blended image, same size as ``in_scene_image``.

    Raises:
        Exception: if the detected sky region is too small to blend.
    """
    scene_sky_rect = get_fast_valid_rect(in_scene_mask, 1)
    area = scene_sky_rect[2] * scene_sky_rect[3]

    if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA:
        raise Exception(
            '[blend_merge] Failed!! Scene Image Valid sky region is too small')

    # Fit the sky texture to the scene's sky bounding box.
    valid_sky_image = min_size_match(in_valid_sky_image, scene_sky_rect[2:])
    valid_sky_image = center_crop(valid_sky_image, scene_sky_rect[2:])

    # resizeSceneMask: work on a downscaled mask for blur/morphology.
    sky_size = (
        max(
            int(in_scene_mask.shape[1] * IMAGE_BLENDER_MASK_RESIZE_SCALE
                + 0.5),
            IMAGE_BLENDER_MIN_RESIZE_DIM,
        ),
        max(
            int(in_scene_mask.shape[0] * IMAGE_BLENDER_MASK_RESIZE_SCALE
                + 0.5),
            IMAGE_BLENDER_MIN_RESIZE_DIM,
        ),
    )

    # BUG FIX: the interpolation flag must be passed by keyword -- passed
    # positionally it lands in cv2.resize's `dst` parameter.
    resize_scene_mask = cv2.resize(
        in_scene_mask, sky_size, interpolation=cv2.INTER_LINEAR)
    resize_scene_mask = cv2.blur(
        resize_scene_mask,
        (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE),
    )

    element = cv2.getStructuringElement(
        cv2.MORPH_RECT,
        (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE))
    sky_mask = cv2.dilate(resize_scene_mask, element)  # enlarge sky region
    scene_mask = cv2.erode(resize_scene_mask, element)  # enlarge scene region
    scene_mask = 255 - scene_mask

    # Back to full resolution (cv2 sizes are (width, height)).
    sky_mask = cv2.resize(sky_mask, in_scene_mask.shape[0:2][::-1])
    scene_mask = cv2.resize(scene_mask, in_scene_mask.shape[0:2][::-1])

    x, y, w, h = scene_sky_rect
    valid_sky_mask = sky_mask[y:y + h, x:x + w]

    # Scene copy with the new sky alpha-blended into the sky rectangle.
    pano_sky_image = merge_image(in_scene_image.copy(), valid_sky_image,
                                 valid_sky_mask, scene_sky_rect[0:2])

    blend_images = [in_scene_image, pano_sky_image]
    blend_masks = [scene_mask.astype(np.uint8), sky_mask.astype(np.uint8)]

    panorama_rect = (0, 0, in_scene_image.shape[1], in_scene_image.shape[0])

    blender = cv2.detail_MultiBandBlender(1, inBlendLevelNum)
    blender.prepare(panorama_rect)

    for image, mask in zip(blend_images, blend_masks):
        blender.feed(image, mask, (0, 0))

    # BUG FIX: the blend mask was allocated as (width, height); numpy
    # arrays are indexed (rows, cols), so it must be (height, width).
    pano_mask = np.full(
        (in_scene_image.shape[0], in_scene_image.shape[1]),
        255,
        dtype='uint8')
    out_blend_image = np.zeros_like(in_scene_image)
    result = blender.blend(out_blend_image, pano_mask)
    return result[0]
|
||||
199
modelscope/models/cv/image_skychange/skychange_model.py
Normal file
199
modelscope/models/cv/image_skychange/skychange_model.py
Normal file
@@ -0,0 +1,199 @@
|
||||
import math
|
||||
import os
|
||||
import pdb
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
import cv2
|
||||
import json
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.base import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .ptsemseg.hrnet_super_and_ocr import HrnetSuperAndOcr
|
||||
from .ptsemseg.unet import Unet
|
||||
from .skychange import blend
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.image_skychange, module_name=Models.image_skychange)
class ImageSkychange(TorchModel):
    """Two-stage sky replacement model.

    A coarse HRNet+OCR segmentation network predicts an initial sky mask,
    which is refined by a U-Net matting network; the two per-image masks
    are then used to blend the sky of one image into the other.
    """

    def __init__(self, model_dir, refine_cfg, coarse_cfg, *args, **kwargs):
        """
        Args:
            model_dir (str): model directory to initialize some resource.
            refine_cfg: configuration of refine model.
            coarse_cfg: configuration of coarse model.
        """
        super().__init__(model_dir=model_dir, *args, **kwargs)
        # Prefer GPU when available; both networks are moved there later.
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
            logger.info('Use GPU: {}'.format(self.device))
        else:
            self.device = torch.device('cpu')
            logger.info('Use CPU: {}'.format(self.device))

        coarse_model_path = '{}/{}'.format(model_dir,
                                           ModelFile.TORCH_MODEL_FILE)
        refine_model_path = '{}/{}'.format(model_dir,
                                           'unet_sky_matting_final_model.pkl')

        logger.info(
            '####################### load refine models ################################'
        )
        self.refine_model = Unet(**refine_cfg['Model'])
        self.load_model(self.refine_model, refine_model_path)
        self.refine_model.eval()
        logger.info(
            '####################### load refine models done ############################'
        )

        logger.info(
            '####################### load coarse models ################################'
        )
        self.coarse_model = HrnetSuperAndOcr(**coarse_cfg['Model'])
        self.load_model(self.coarse_model, coarse_model_path)
        self.coarse_model.eval()
        logger.info(
            '####################### load coarse models done ############################'
        )

    def load_model(self, seg_model, input_model_path):
        """Load a checkpoint's ``model_state`` into ``seg_model``.

        Moves the model to the GPU when one is available; on CPU the
        checkpoint is remapped with ``map_location='cpu'``.

        Raises:
            Exception: if ``input_model_path`` is not an existing file.
        """
        if not os.path.isfile(input_model_path):
            logger.error(
                '[checkModelPath]:model path dose not exits!!! model Path:'
                + input_model_path)
            raise Exception('[checkModelPath]:model path dose not exits!')

        # Single load path: only the map_location and the final device
        # move differ between CPU and GPU (previously duplicated code).
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(input_model_path, map_location=map_location)
        model_state = self.convert_state_dict(checkpoint['model_state'])
        seg_model.load_state_dict(model_state)
        if torch.cuda.is_available():
            seg_model.to(self.device)

    def convert_state_dict(self, state_dict):
        """Converts a state dict saved from a dataParallel module to normal
        module state_dict inplace
        :param state_dict is the loaded DataParallel model_state
        """
        first_key = next(iter(state_dict))
        if not first_key.startswith('module.'):
            return state_dict  # abort if dict is not a DataParallel model_state

        # Prefix length decided from the first key only; the checkpoint is
        # assumed to share a single prefix ('module.model.' or 'module.').
        split_index = 13 if first_key.startswith('module.model') else 7

        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            new_state_dict[k[split_index:]] = v  # remove `module.` prefix
        return new_state_dict

    def forward(
        self,
        sky_image: torch.Tensor,
        sky_image_refine: torch.Tensor,
        scene_image: torch.Tensor,
        scene_image_refine: torch.Tensor,
        img_metas: Dict[str, Any],
    ):
        """
        Args:
            sky_image (`torch.Tensor`): batched image tensor, shape is [1, 3, h', w'].
            sky_image_refine (`torch.Tensor`): batched image tensor, shape is [1, 3, refine_net_h, refine_net_w].
            scene_image (`torch.Tensor`): batched image tensor, shape is [1, 3, h, w].
            scene_image_refine (`torch.Tensor`): batched image tensor, shape is [1, 3, refine_net_h, refine_net_w].
            img_metas (`Dict[str, Any]`): image meta info.
        Return:
            `IMAGE: shape is [h, w, 3] (0~255)`
        """
        start = time.time()
        sky_img_metas, scene_img_metas, input_size = img_metas[
            'sky_img_metas'], img_metas['scene_img_metas'], img_metas[
                'input_size']
        # Predict a sky mask for each of the two images.
        sky_mask = self.inference_mask(sky_image_refine, sky_img_metas,
                                       input_size)
        scene_mask = self.inference_mask(scene_image_refine, scene_img_metas,
                                         input_size)
        end = time.time()
        logger.info(
            'Time of inferencing mask of sky and scene images:{}'.format(
                end - start))
        start = time.time()
        # Masks come back in [0, 1]; the blender expects [0, 255].
        scene_mask = scene_mask * 255
        sky_mask = sky_mask * 255
        res = blend(scene_image, scene_mask, sky_image, sky_mask)
        end = time.time()
        logger.info('Time of blending: {}'.format(end - start))
        return res

    @torch.no_grad()
    def inference_mask(self, img, img_metas, input_size):
        """Predict a sky alpha mask for one padded/resized image.

        Args:
            img: batched image tensor [1, 3, refine_h, refine_w].
            img_metas: meta info with 'ori_shape', 'refine_shape' and
                'pad_direction' (left, top, right, bottom padding).
            input_size: dict with 'coarse_input_size' and
                'refine_input_size'.

        Returns:
            np.ndarray mask of shape (ori_h, ori_w), values in [0, 1].
        """
        self.eval()
        raw_h, raw_w = img_metas['ori_shape']
        pad_direction = img_metas['pad_direction']
        coarse_input_size = input_size['coarse_input_size']
        refine_input_size = input_size['refine_input_size']
        h, w = img_metas['refine_shape']

        # Stage 1: coarse segmentation at the coarse network resolution.
        resize_images = F.interpolate(
            img, coarse_input_size, mode='bilinear', align_corners=True)
        pred_scores = self.coarse_model(resize_images)
        if isinstance(pred_scores, (tuple, list)):
            pred_scores = pred_scores[-1]
        score = F.interpolate(
            input=pred_scores,
            size=refine_input_size,
            mode='bilinear',
            align_corners=True,
        )
        _, coarse_pred = torch.max(score, dim=1)  # [B, h, w]
        coarse_pred = coarse_pred.unsqueeze(1).type(img.dtype)
        # Stage 2 input: image plus the coarse mask as a 4th channel.
        img = torch.cat([img, coarse_pred], dim=1)  # [B, c=4, h, w]
        # Free intermediates before running the refine network.
        del resize_images
        del pred_scores
        del score
        del coarse_pred
        torch.cuda.empty_cache()
        cur_scores = self.refine_model(img)
        del img
        torch.cuda.empty_cache()
        cur_scores = torch.clip(cur_scores, 0, 1)
        cur_scores = cur_scores.detach().cpu().numpy()[0]

        # resize if cur_scores shape are not compatible with origin image
        # shape. Assumes the refine output squeezes to (h, w) after the
        # leading [0] index -- TODO confirm against the Unet output shape.
        ph, pw = cur_scores.shape
        if ph != h or pw != w:
            # BUG FIX: cur_scores is a numpy array at this point, so the
            # previous F.interpolate call (which requires a 4-D tensor and
            # rejects align_corners with mode='nearest') would always
            # raise. Resize with OpenCV instead.
            cur_scores = cv2.resize(
                cur_scores, (w, h), interpolation=cv2.INTER_NEAREST)
        # unpad to get valid area and resize to origin size
        valid_cur_pred = cur_scores[pad_direction[1]:refine_input_size[0]
                                    - pad_direction[3],
                                    pad_direction[0]:refine_input_size[1]
                                    - pad_direction[2], ]
        valid_cur_pred = cv2.resize(valid_cur_pred, (raw_w, raw_h))
        del cur_scores
        torch.cuda.empty_cache()
        # Use the module logger rather than a bare print().
        logger.info('get refine mask done')
        return valid_cur_pred
|
||||
@@ -836,6 +836,11 @@ TASK_OUTPUTS = {
|
||||
# }
|
||||
Tasks.product_segmentation: [OutputKeys.MASKS],
|
||||
|
||||
# image_skychange result for a single sample
|
||||
# {
|
||||
# "output_img": np.ndarray with shape [height, width, 3]
|
||||
# }
|
||||
Tasks.image_skychange: [OutputKeys.OUTPUT_IMG],
|
||||
# {
|
||||
# 'scores': [0.1, 0.2, 0.3, ...]
|
||||
# }
|
||||
|
||||
@@ -101,6 +101,10 @@ TASK_INPUTS = {
|
||||
'img': InputType.IMAGE,
|
||||
'mask': InputType.IMAGE,
|
||||
},
|
||||
Tasks.image_skychange: {
|
||||
'sky_image': InputType.IMAGE,
|
||||
'scene_image': InputType.IMAGE,
|
||||
},
|
||||
|
||||
# image generation task result for a single image
|
||||
Tasks.image_to_image_generation:
|
||||
|
||||
@@ -223,6 +223,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_swin-t_referring_video-object-segmentation'),
|
||||
Tasks.video_summarization: (Pipelines.video_summarization,
|
||||
'damo/cv_googlenet_pgl-video-summarization'),
|
||||
Tasks.image_skychange: (Pipelines.image_skychange,
|
||||
'damo/cv_hrnetocr_skychange'),
|
||||
Tasks.translation_evaluation:
|
||||
(Pipelines.translation_evaluation,
|
||||
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
|
||||
|
||||
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
|
||||
from .hand_static_pipeline import HandStaticPipeline
|
||||
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
|
||||
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
|
||||
from .image_skychange_pipeline import ImageSkychangePipeline
|
||||
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
|
||||
|
||||
else:
|
||||
@@ -155,6 +156,7 @@ else:
|
||||
'language_guided_video_summarization_pipeline': [
|
||||
'LanguageGuidedVideoSummarizationPipeline'
|
||||
],
|
||||
'image_skychange_pipeline': ['ImageSkychangePipeline'],
|
||||
'video_object_segmentation_pipeline': [
|
||||
'VideoObjectSegmentationPipeline'
|
||||
],
|
||||
|
||||
63
modelscope/pipelines/cv/image_skychange_pipeline.py
Normal file
63
modelscope/pipelines/cv/image_skychange_pipeline.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import pdb
|
||||
import time
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import PIL
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.image_skychange import ImageSkyChangePreprocessor
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Model, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.image_skychange, module_name=Pipelines.image_skychange)
class ImageSkychangePipeline(Pipeline):
    """Image Sky Change Pipeline. Given two images (sky_image and
    scene_image), the pipeline composites the sky taken from ``sky_image``
    into the sky region of ``scene_image``.

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline
    >>> detector = pipeline('image-skychange', 'damo/cv_hrnetocr_skychange')
    >>> detector({
            'sky_image': 'sky_image.jpg',  # sky_image path (str)
            'scene_image': 'scene_image.jpg',  # scene_image path (str)
        })
    {
        "output_img": [H * W * 3] 0~255, we can use cv2.imwrite to save output_img as an image.
    }
    ```
    """

    def __init__(self, model: str, **kwargs):
        """Create an image sky change pipeline for image editing.

        Args:
            model (`str` or `Model`): model_id on modelscope hub.
            preprocessor (`Preprocessor`, *optional*, defaults to None):
                falls back to a fresh `ImageSkyChangePreprocessor` when
                none is supplied.
        """
        super().__init__(model=model, **kwargs)
        if not isinstance(self.model, Model):
            logger.error('model object is not initialized.')
            raise Exception('model object is not initialized.')
        if self.preprocessor is None:
            self.preprocessor = ImageSkyChangePreprocessor()
        logger.info('load model done')

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        # Delegate to the model; wrap the blended image in the output dict.
        return {OutputKeys.OUTPUT_IMG: self.model.forward(**input)}

    def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # The model output is already in its final form.
        return inputs
|
||||
@@ -67,6 +67,7 @@ class CVTasks(object):
|
||||
image_denoising = 'image-denoising'
|
||||
image_portrait_enhancement = 'image-portrait-enhancement'
|
||||
image_inpainting = 'image-inpainting'
|
||||
image_skychange = 'image-skychange'
|
||||
|
||||
# image generation
|
||||
image_to_image_translation = 'image-to-image-translation'
|
||||
|
||||
48
tests/pipelines/test_image_skychange.py
Normal file
48
tests/pipelines/test_image_skychange.py
Normal file
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path as osp
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
|
||||
import modelscope
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.base import Pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
# NOTE(review): module-level print runs on every import of this test file;
# presumably a debugging aid -- consider removing or demoting to logging.
print(modelscope.version.__release_datetime__)
|
||||
|
||||
|
||||
class ImageSkychangeTest(unittest.TestCase):
    """Smoke tests for the image-skychange pipeline."""

    def setUp(self) -> None:
        # Hub model id plus the two bundled demo images used as input.
        self.model = 'damo/cv_hrnetocr_skychange'
        self.sky_image = 'data/test/images/sky_image.jpg'
        self.scene_image = 'data/test/images/scene_image.jpg'
        self.input = {
            'sky_image': self.sky_image,
            'scene_image': self.scene_image,
        }

    def pipeline_inference(self, pipeline: Pipeline, input: str):
        # Run the pipeline and dump the composited image for inspection.
        result = pipeline(input)
        if result is None:
            return
        cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
        print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        self.pipeline_inference(
            pipeline(Tasks.image_skychange, model=self.model), self.input)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        # Exercise the task's registered default model.
        self.pipeline_inference(
            pipeline(Tasks.image_skychange), self.input)
|
||||
|
||||
|
||||
# Allow running this test module directly (outside the test runner).
if __name__ == '__main__':
    unittest.main()
|
||||
@@ -41,6 +41,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run
|
||||
- test_image_matting.py
|
||||
- test_skin_retouching.py
|
||||
- test_table_recognition.py
|
||||
- test_image_skychange.py
|
||||
|
||||
envs:
|
||||
default: # default env, case not in other env will in default, pytorch.
|
||||
|
||||
Reference in New Issue
Block a user