From b36bb728693ec8de794f5e980745da117b1e7a3a Mon Sep 17 00:00:00 2001 From: "hannah.yh" Date: Wed, 21 Dec 2022 17:40:46 +0800 Subject: [PATCH] add image skychange Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/10947701 --- data/test/images/scene_image.jpg | 3 + data/test/images/sky_image.jpg | 3 + modelscope/metainfo.py | 3 + .../models/cv/image_skychange/__init__.py | 22 + .../models/cv/image_skychange/preprocessor.py | 245 +++++++ .../image_skychange/ptsemseg/BlockModules.py | 118 ++++ .../cv/image_skychange/ptsemseg/__init__.py | 0 .../ptsemseg/hrnet_backnone.py | 620 ++++++++++++++++++ .../ptsemseg/hrnet_super_and_ocr.py | 510 ++++++++++++++ .../cv/image_skychange/ptsemseg/unet.py | 229 +++++++ .../models/cv/image_skychange/skychange.py | 310 +++++++++ .../cv/image_skychange/skychange_model.py | 199 ++++++ modelscope/outputs/outputs.py | 5 + modelscope/pipeline_inputs.py | 4 + modelscope/pipelines/builder.py | 2 + modelscope/pipelines/cv/__init__.py | 2 + .../pipelines/cv/image_skychange_pipeline.py | 63 ++ modelscope/utils/constant.py | 1 + tests/pipelines/test_image_skychange.py | 48 ++ tests/run_config.yaml | 1 + 20 files changed, 2388 insertions(+) create mode 100644 data/test/images/scene_image.jpg create mode 100644 data/test/images/sky_image.jpg create mode 100644 modelscope/models/cv/image_skychange/__init__.py create mode 100644 modelscope/models/cv/image_skychange/preprocessor.py create mode 100644 modelscope/models/cv/image_skychange/ptsemseg/BlockModules.py create mode 100644 modelscope/models/cv/image_skychange/ptsemseg/__init__.py create mode 100644 modelscope/models/cv/image_skychange/ptsemseg/hrnet_backnone.py create mode 100644 modelscope/models/cv/image_skychange/ptsemseg/hrnet_super_and_ocr.py create mode 100644 modelscope/models/cv/image_skychange/ptsemseg/unet.py create mode 100644 modelscope/models/cv/image_skychange/skychange.py create mode 100644 modelscope/models/cv/image_skychange/skychange_model.py create mode 100644 
modelscope/pipelines/cv/image_skychange_pipeline.py create mode 100644 tests/pipelines/test_image_skychange.py diff --git a/data/test/images/scene_image.jpg b/data/test/images/scene_image.jpg new file mode 100644 index 00000000..d0ce5bee --- /dev/null +++ b/data/test/images/scene_image.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:260cd09f340b86007dd471cba742f82bae0fb5cfd4b8d87265bff5ad2c2c857f +size 652482 diff --git a/data/test/images/sky_image.jpg b/data/test/images/sky_image.jpg new file mode 100644 index 00000000..f00f9296 --- /dev/null +++ b/data/test/images/sky_image.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:679c86d5a82c9c1c4866b5e16b98a2128a57e3ea60f77d56e5f0fe79ab7d746f +size 505993 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 70fe52cc..45297fd7 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -57,6 +57,7 @@ class Models(object): face_emotion = 'face-emotion' product_segmentation = 'product-segmentation' image_body_reshaping = 'image-body-reshaping' + image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' video_object_segmentation = 'video-object-segmentation' @@ -243,6 +244,7 @@ class Pipelines(object): product_segmentation = 'product-segmentation' image_body_reshaping = 'flow-based-body-reshaping' referring_video_object_segmentation = 'referring-video-object-segmentation' + image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' video_object_segmentation = 'video-object-segmentation' @@ -389,6 +391,7 @@ class Preprocessors(object): movie_scene_segmentation_preprocessor = 'movie-scene-segmentation-preprocessor' image_classification_bypass_preprocessor = 'image-classification-bypass-preprocessor' object_detection_scrfd = 'object-detection-scrfd' + image_sky_change_preprocessor = 'image-sky-change-preprocessor' # nlp preprocessor sen_sim_tokenizer = 'sen-sim-tokenizer' diff --git 
a/modelscope/models/cv/image_skychange/__init__.py b/modelscope/models/cv/image_skychange/__init__.py new file mode 100644 index 00000000..955f26e6 --- /dev/null +++ b/modelscope/models/cv/image_skychange/__init__.py @@ -0,0 +1,22 @@ +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .skychange_model import ImageSkychange + from .preprocessor import ImageSkyChangePreprocessor + +else: + _import_structure = {'skychange_model': ['ImageSkychange']} + _import_structure = {'preprocessor': ['ImageSkyChangePreprocessor']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/image_skychange/preprocessor.py b/modelscope/models/cv/image_skychange/preprocessor.py new file mode 100644 index 00000000..570fb6be --- /dev/null +++ b/modelscope/models/cv/image_skychange/preprocessor.py @@ -0,0 +1,245 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +import numbers +import pdb +from typing import Any, Dict, Union + +import cv2 +import json +import numpy as np +import torch +from torchvision import transforms + +from modelscope.metainfo import Preprocessors +from modelscope.preprocessors import Preprocessor +from modelscope.preprocessors.builder import PREPROCESSORS +from modelscope.preprocessors.image import LoadImage +from modelscope.utils.constant import Fields, ModeKeys + +_cv2_pad_to_str = { + 'constant': cv2.BORDER_CONSTANT, + 'edge': cv2.BORDER_REPLICATE, + 'reflect': cv2.BORDER_REFLECT_101, + 'symmetric': cv2.BORDER_REFLECT, +} + + +@PREPROCESSORS.register_module( + Fields.cv, module_name=Preprocessors.image_sky_change_preprocessor) +class ImageSkyChangePreprocessor(Preprocessor): + + def __init__(self, + model_dir: str = None, + mode: str = ModeKeys.INFERENCE, + coarse_model_width=640, + coarse_model_height=640, + refine_model_width=1280, + refine_model_height=1280, + mean_vec=[0.485, 0.456, 0.406], + std_vec=[0.229, 0.224, 0.225], + *args, + **kwargs): + """ + Args: + model_dir (str): model directory to initialize some resource. + mode: The mode for the preprocessor. + coarse_model_width: required width of input tensor of coarse model. + coarse_model_height: required height of input tensor of coarse model. + refine_model_width: required width of input tensor of refine model. + refine_model_height: required height of input tensor of refine model. + mean_vec: mean of dataset(for transforms.Normalize), default is mean of Imagenet dataset. + std_vec: standard deviation of dataset(for transforms.Normalize), default is std of Imagenet dataset. 
+ """ + super().__init__(mode) + + # set preprocessor info + self.coarse_input_size = [coarse_model_width, coarse_model_height] + self.refine_input_size = [refine_model_width, refine_model_height] + self.normalize = transforms.Normalize(mean=mean_vec, std=std_vec) + + def __call__(self, data: Union[str, Dict], **kwargs) -> Dict[str, Any]: + """process the raw input data + Args: + data (dict): data dict containing following info: + sky_image, scene_image + example: + ```python + { + "sky_image": "xxx.jpg" # sky_image path(str) + "scene_image": "xxx.jpg", # scene_image path(str) + } + ``` + Returns: + Dict[str, Any]: the preprocessed data + { + "sky_image": the preprocessed sky image(origin size) + "sky_image_refine": the preprocessed resized sky image + "scene_image": the preprocessed scene image(origin size) + "scene_image_refine": the preprocessed resized scene image + "img_metas": informations of preprocessed images, e.g. origin shape, pad information, resized shape. + } + """ + if 'sky_image' not in data.keys(): + raise Exception('sky_image not in input data') + if 'scene_image' not in data.keys(): + raise Exception('scene_image not in input data') + if isinstance(data['sky_image'], str): + sky_image = LoadImage.convert_to_ndarray(data['sky_image']) + sky_image = sky_image.astype(np.uint8) # RGB + sky_image = cv2.cvtColor(sky_image, cv2.COLOR_RGB2BGR) # BGR + if sky_image is not None: + sky_image = self.check_image(sky_image) + else: + raise Exception('sky_image is None') + else: + raise Exception('sky_image(path of sky image) is not valid') + if isinstance(data['scene_image'], str): + scene_image = LoadImage.convert_to_ndarray(data['scene_image']) + scene_image = scene_image.astype(np.uint8) # RGB + scene_image = cv2.cvtColor(scene_image, cv2.COLOR_RGB2BGR) # BGR + if scene_image is not None: + scene_image = self.check_image(scene_image) + else: + raise Exception('scene_image is None') + else: + raise Exception('scene_image(path of scene image) is not valid') + 
data = {} + sky_image_refine, sky_img_metas = self.process_single_img(sky_image) + scene_image_refine, scene_img_metas = self.process_single_img( + scene_image) + data['sky_image'] = sky_image + data['sky_image_refine'] = sky_image_refine + data['scene_image'] = scene_image + data['scene_image_refine'] = scene_image_refine + data['img_metas'] = { + 'sky_img_metas': sky_img_metas, + 'scene_img_metas': scene_img_metas, + 'input_size': { + 'coarse_input_size': self.coarse_input_size, + 'refine_input_size': self.refine_input_size + } + } + return data + + def process_single_img(self, img): + img_metas = {} + img_metas['ori_shape'] = img.shape[0:2] # img: (origin_h, origin_w, 3) + img, pad_direction = get_refine_input(img, self.refine_input_size) + img = image_transform( + img, self.normalize) # torch.Size([3, refine_net_h, refine_net_w]) + img = img.unsqueeze(0) + img_metas['pad_direction'] = pad_direction + img_metas['refine_shape'] = img.shape[ + 2:] # torch.Size([1, 3, refine_net_h, refine_net_w]) + return img, img_metas + + def check_image(self, input_img): + whole_temp_shape = input_img.shape + if len(whole_temp_shape) == 2: + input_img = np.stack([input_img, input_img, input_img], axis=2) + elif whole_temp_shape[2] == 1: + input_img = np.concatenate([input_img, input_img, input_img], + axis=2) + elif whole_temp_shape[2] == 4: + input_img = input_img[:, :, + 0:3] * 1.0 * input_img[:, :, + 3:4] * 1.0 / 255.0 + return input_img + + +def get_refine_input(mat, refine_input_size): + # maxDimMatch: resize + mat = max_dim_match(mat, refine_input_size) + # pad image to refine net input size + mat, pad_direction = center_pad_image_withwh(mat, refine_input_size, 0) + return mat, pad_direction + + +def max_dim_match(image, refine_model_size): + h, w, c = np.shape(image) + resize_w, resize_h = refine_model_size + if h != resize_h or w != resize_w: + h_scale = float(resize_h) / h + w_scale = float(resize_w) / w + resize_scale = min(w_scale, h_scale) + new_h = int(h * 
resize_scale + 0.5) + new_w = int(w * resize_scale + 0.5) + image = cv2.resize( + image, (new_w, new_h), interpolation=cv2.INTER_LINEAR) + return image + + +def center_pad_image_withwh(image, + crop_size, + padvalue, + padding_mode='constant'): + pad_image = image + h, w = image.shape[0], image.shape[1] + pad_h = max(crop_size[1] - h, 0) + pad_w = max(crop_size[0] - w, 0) + pad_direction = (0, 0, 0, 0) + if pad_h > 0 or pad_w > 0: + half_w = int(pad_w / 2 + 0.5) + half_h = int(pad_h / 2 + 0.5) + pad_direction = (half_w, half_h, pad_w - half_w, pad_h - half_h) + pad_image = pad( + image, pad_direction, padvalue, padding_mode=padding_mode) + return pad_image, pad_direction + + +def pad(img, padding, fill=0, padding_mode='constant'): + if not is_numpy_image(img): + raise TypeError('img should be numpy ndarray. Got {}'.format( + type(img))) + if not isinstance(padding, + (numbers.Number, tuple, list)) or len(padding) != 4: + raise TypeError('Got inappropriate padding arg') + + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + + shape_len = len(img.shape) + if shape_len == 2: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_to_str[padding_mode], + value=fill, + ) + elif shape_len == 3 and img.shape[2] == 1: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_to_str[padding_mode], + value=fill, + )[:, :, np.newaxis] + else: + return cv2.copyMakeBorder( + img, + top=pad_top, + bottom=pad_bottom, + left=pad_left, + right=pad_right, + borderType=_cv2_pad_to_str[padding_mode], + value=fill, + ) + + +def image_transform(img, normalize): + img = img[:, :, ::-1] # BGR-->RGB to pil format + img = img.transpose((2, 0, 1)) # h,w,c --> c,h,w + img = img.astype(np.float32) / 255 + img = normalize(torch.from_numpy(img.copy())) + return img + + +def is_numpy_image(img): + return 
isinstance(img, np.ndarray) and (img.ndim in {2, 3}) diff --git a/modelscope/models/cv/image_skychange/ptsemseg/BlockModules.py b/modelscope/models/cv/image_skychange/ptsemseg/BlockModules.py new file mode 100644 index 00000000..507f3cc6 --- /dev/null +++ b/modelscope/models/cv/image_skychange/ptsemseg/BlockModules.py @@ -0,0 +1,118 @@ +# The implementation is adopted from ASPP made publicly available under the MIT License License +# at https://github.com/jfzhang95/pytorch-deeplab-xception +import torch +import torch.nn.functional as F +from torch import nn + +BatchNorm2d = nn.BatchNorm2d + + +class ASPPModule(nn.Module): + + def __init__(self, inplanes, planes, kernel_size, padding, dilation, + BatchNorm): + super(ASPPModule, self).__init__() + self.atrous_conv = nn.Conv2d( + inplanes, + planes, + kernel_size=kernel_size, + stride=1, + padding=padding, + dilation=dilation, + bias=False) + self.bn = BatchNorm(planes) + self.relu = nn.ReLU() + + self._init_weight() + + def forward(self, x): + x = self.atrous_conv(x) + x = self.bn(x) + + return self.relu(x) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + + +# this aspp is official version +# copy from :https://github.com/jfzhang95/pytorch-deeplab-xception +class ASPP(nn.Module): + + def __init__(self, inplanes, outplanes, dilations, drop_rate=0.1): + super(ASPP, self).__init__() + + self.aspp1 = ASPPModule( + inplanes, + outplanes, + 1, + padding=0, + dilation=dilations[0], + BatchNorm=BatchNorm2d) + self.aspp2 = ASPPModule( + inplanes, + outplanes, + 3, + padding=dilations[1], + dilation=dilations[1], + BatchNorm=BatchNorm2d) + self.aspp3 = ASPPModule( + inplanes, + outplanes, + 3, + padding=dilations[2], + dilation=dilations[2], + BatchNorm=BatchNorm2d) + self.aspp4 = 
ASPPModule( + inplanes, + outplanes, + 3, + padding=dilations[3], + dilation=dilations[3], + BatchNorm=BatchNorm2d) + + self.global_avg_pool = nn.Sequential( + nn.AdaptiveAvgPool2d((1, 1)), + nn.Conv2d(inplanes, outplanes, 1, stride=1, bias=False), + BatchNorm2d(outplanes), nn.ReLU()) + self.conv1 = nn.Conv2d(outplanes * 5, outplanes, 1, bias=False) + self.bn1 = BatchNorm2d(outplanes) + self.relu = nn.ReLU() + self.dropout = nn.Dropout(drop_rate) + self._init_weight() + + def forward(self, x): # [1, 256, 320, 320] + x1 = self.aspp1(x) # [1, 128, 160, 160] + x2 = self.aspp2(x) # [1, 128, 160, 160] + x3 = self.aspp3(x) # [1, 128, 160, 160] + x4 = self.aspp4(x) # [1, 128, 160, 160] + x5 = self.global_avg_pool(x) # b,c,h,w [1, 128, 1, 1] + x5 = F.interpolate( + x5, size=x4.size()[2:], mode='bilinear', + align_corners=True) # [1, 128, 160, 160] + x = torch.cat((x1, x2, x3, x4, x5), dim=1) # [1, 640, 160, 160] + + x = self.conv1(x) # [1, 640, 160, 160] + x = self.bn1(x) + x = self.relu(x) + + return self.dropout(x) + + def _init_weight(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + torch.nn.init.kaiming_normal_(m.weight) + elif isinstance(m, BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1) + m.bias.data.zero_() diff --git a/modelscope/models/cv/image_skychange/ptsemseg/__init__.py b/modelscope/models/cv/image_skychange/ptsemseg/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/image_skychange/ptsemseg/hrnet_backnone.py b/modelscope/models/cv/image_skychange/ptsemseg/hrnet_backnone.py new file mode 100644 index 00000000..66429d67 --- /dev/null +++ b/modelscope/models/cv/image_skychange/ptsemseg/hrnet_backnone.py @@ -0,0 +1,620 @@ +# The implementation is adopted from HRNet, made publicly available under the MIT License License +# at https://github.com/HRNet/HRNet-Semantic-Segmentation +import logging +import os + +import numpy as np 
+import torch +import torch.nn as nn +import torch.nn.functional as F + +BatchNorm2d = nn.BatchNorm2d + +BN_MOMENTUM = 0.1 +logger = logging.getLogger(__name__) + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(Bottleneck, self).__init__() + self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) + self.bn1 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + planes, + planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + self.bn2 = BatchNorm2d(planes, momentum=BN_MOMENTUM) + self.conv3 = nn.Conv2d( + planes, planes * self.expansion, kernel_size=1, bias=False) + self.bn3 = BatchNorm2d(planes * self.expansion, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = 
self.bn3(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + + def __init__(self, + num_branches, + blocks, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + multi_scale_output=True): + super(HighResolutionModule, self).__init__() + self._check_branches(num_branches, blocks, num_blocks, num_inchannels, + num_channels) + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches(num_branches, blocks, num_blocks, + num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, + branch_index, + block, + num_blocks, + num_channels, + stride=1): + downsample = None + if stride != 1 or self.num_inchannels[ + branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm2d( + num_channels[branch_index] * block.expansion, + momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append( + 
block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample)) + self.num_inchannels[ + branch_index] = num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append( + block(self.num_inchannels[branch_index], + num_channels[branch_index])) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels # tuple + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + BatchNorm2d( + num_inchannels[i], momentum=BN_MOMENTUM))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = [] + for k in range(i - j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + BatchNorm2d( + num_outchannels_conv3x3, + momentum=BN_MOMENTUM))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + num_inchannels[j], + num_outchannels_conv3x3, + 3, + 2, + 1, + bias=False), + BatchNorm2d( + num_outchannels_conv3x3, + momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for 
i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=(height_output, width_output), + mode='bilinear', + align_corners=True) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +model_w18v1 = { + 'STAGE1': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 1, + 'BLOCK': 'BOTTLENECK', + 'NUM_BLOCKS': (1), + 'NUM_CHANNELS': (32), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE2': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 2, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (2, 2), + 'NUM_CHANNELS': (16, 32), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE3': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 3, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (2, 2, 2), + 'NUM_CHANNELS': (16, 32, 64), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE4': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 4, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (2, 2, 2, 2), + 'NUM_CHANNELS': (16, 32, 64, 128), + 'FUSE_METHOD': 'SUM' + }, + 'FINAL_CONV_KERNEL': 1 +} + +model_w18v2 = { + 'STAGE1': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 1, + 'BLOCK': 'BOTTLENECK', + 'NUM_BLOCKS': (2), + 'NUM_CHANNELS': (64), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE2': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 2, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (2, 2), + 'NUM_CHANNELS': (18, 36), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE3': { + 'NUM_MODULES': 3, + 'NUM_BRANCHES': 3, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (2, 2, 2), + 'NUM_CHANNELS': (18, 36, 72), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE4': { + 'NUM_MODULES': 2, + 'NUM_BRANCHES': 4, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (2, 2, 2, 2), + 'NUM_CHANNELS': (18, 36, 72, 144), + 'FUSE_METHOD': 'SUM' + }, + 'FINAL_CONV_KERNEL': 1 +} + +model_w48 = { + 'STAGE1': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 
1, + 'BLOCK': 'BOTTLENECK', + 'NUM_BLOCKS': (4), + 'NUM_CHANNELS': (64), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE2': { + 'NUM_MODULES': 1, + 'NUM_BRANCHES': 2, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (4, 4), + 'NUM_CHANNELS': (48, 96), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE3': { + 'NUM_MODULES': 4, + 'NUM_BRANCHES': 3, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (4, 4, 4), + 'NUM_CHANNELS': (48, 96, 192), + 'FUSE_METHOD': 'SUM' + }, + 'STAGE4': { + 'NUM_MODULES': 3, + 'NUM_BRANCHES': 4, + 'BLOCK': 'BASIC', + 'NUM_BLOCKS': (4, 4, 4, 4), + 'NUM_CHANNELS': (48, 96, 192, 384), + 'FUSE_METHOD': 'SUM' + }, + 'FINAL_CONV_KERNEL': 1 +} + +model_version_dict = {} +model_version_dict['w48'] = model_w48 +model_version_dict['w18v1'] = model_w18v1 +model_version_dict['w18v2'] = model_w18v2 + +blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck} + + +class HrnetBackBone(nn.Module): + + def __init__(self, **kwargs): + super(HrnetBackBone, self).__init__() + + assert 'version' in kwargs, 'hrnet not exist model version' + extra = model_version_dict[kwargs['version']] + + # stem net + self.conv1 = nn.Conv2d( + 3, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn1 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.conv2 = nn.Conv2d( + 64, 64, kernel_size=3, stride=2, padding=1, bias=False) + self.bn2 = BatchNorm2d(64, momentum=BN_MOMENTUM) + self.relu = nn.ReLU(inplace=True) + + self.stage1_cfg = extra['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion * num_channels + + self.stage2_cfg = extra['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition1 = self._make_transition_layer([stage1_out_channel], + num_channels) 
+ self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + self.stage3_cfg = extra['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition2 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + self.stage4_cfg = extra['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion + for i in range(len(num_channels)) + ] + self.transition3 = self._make_transition_layer(pre_stage_channels, + num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + self.backbone_last_inp_channels = np.int(np.sum(pre_stage_channels)) + + def _make_transition_layer(self, num_channels_pre_layer, + num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append( + nn.Sequential( + nn.Conv2d( + num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + BatchNorm2d( + num_channels_cur_layer[i], + momentum=BN_MOMENTUM), nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i + 1 - num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[ + i] if j == i - num_branches_pre else inchannels + conv3x3s.append( + nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + BatchNorm2d(outchannels, momentum=BN_MOMENTUM), + nn.ReLU(inplace=True))) + 
transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d( + inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + bias=False), + BatchNorm2d(planes * block.expansion, momentum=BN_MOMENTUM), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes)) + + return nn.Sequential(*layers) + + def _make_stage(self, + layer_config, + num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + modules.append( + HighResolutionModule(num_branches, block, num_blocks, + num_inchannels, num_channels, fuse_method, + reset_multi_scale_output)) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + def _backbone_forward(self, x): + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + x = self.layer1(x) + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + 
x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True) + + x = torch.cat([x[0], x1, x2, x3], 1) + return x + + def init_weights(self, url, cache_file=''): + pretrained_dict = load_state(url, model_dir=cache_file) + model_dict = self.state_dict() + + model_len = len(model_dict) + pretrain_len = len(pretrained_dict) + common_dict = {} + valid_layer_num = 0 + for k, v in pretrained_dict.items(): + if k in model_dict: + common_dict[k] = v + valid_layer_num += 1 + + print('*' * 50) + print('Model Param Num:{} Pretrained Param Num:{} ' + 'Commmon Num:{}'.format(model_len, pretrain_len, + valid_layer_num)) + print('-' * 50) + print('Model Extra Param Names:\n\t{}'.format( + set(model_dict) - set(pretrained_dict))) + print('-' * 50) + print('Pretrained Extra Param Names:\n\t{}'.format( + set(pretrained_dict) - set(model_dict))) + print('*' * 50) + + model_dict.update(common_dict) + self.load_state_dict(model_dict) diff --git a/modelscope/models/cv/image_skychange/ptsemseg/hrnet_super_and_ocr.py b/modelscope/models/cv/image_skychange/ptsemseg/hrnet_super_and_ocr.py new file mode 100644 index 00000000..09768451 --- /dev/null +++ b/modelscope/models/cv/image_skychange/ptsemseg/hrnet_super_and_ocr.py @@ -0,0 +1,510 @@ +# Part of the implementation is borrowed and modified from HRNet, +# publicly available under the MIT License License at https://github.com/HRNet/HRNet-Semantic-Segmentation +from 
__future__ import absolute_import, division, print_function + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .BlockModules import ASPP +from .hrnet_backnone import BatchNorm2d, HrnetBackBone, blocks_dict + +ALIGN_CORNERS = True +BN_MOMENTUM = 0.1 + + +class ModuleHelper: + + @staticmethod + def BNReLU(num_features, bn_type=None, **kwargs): + return nn.Sequential(BatchNorm2d(num_features, **kwargs), nn.ReLU()) + + @staticmethod + def BatchNorm2d(*args, **kwargs): + return BatchNorm2d + + +def conv3x3(in_planes, out_planes, stride=1): + """3x3 convolution with padding""" + return nn.Conv2d( + in_planes, + out_planes, + kernel_size=3, + stride=stride, + padding=1, + bias=False) + + +class SpatialGatherModule(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. + """ + + def __init__(self, cls_num=0, scale=1): + super(SpatialGatherModule, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, _, _ = probs.size(0), probs.size(1), probs.size( + 2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw + ocr_context = torch.matmul(probs, feats) # batch x k x c + + ocr_context = ocr_context.permute(0, 2, + 1).unsqueeze(3) # batch x c x k x 1 + return ocr_context + + +class ObjectAttentionBlock(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + + def __init__(self, 
in_channels, key_channels, scale=1, bn_type=None): + super(ObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_object = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_down = nn.Sequential( + nn.Conv2d( + in_channels=self.in_channels, + out_channels=self.key_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_up = nn.Sequential( + nn.Conv2d( + in_channels=self.key_channels, + out_channels=self.in_channels, + kernel_size=1, + stride=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type), + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + 
sim_map = (self.key_channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... + context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate( + input=context, + size=(h, w), + mode='bilinear', + align_corners=ALIGN_CORNERS) + + return context + + +class ObjectAttentionBlock2D(ObjectAttentionBlock): + + def __init__(self, in_channels, key_channels, scale=1, bn_type=None): + super(ObjectAttentionBlock2D, self).__init__( + in_channels, key_channels, scale, bn_type=bn_type) + + +class SpatialOCRModule(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. + """ + + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + bn_type=None): + super(SpatialOCRModule, self).__init__() + self.object_context_block = ObjectAttentionBlock2D( + in_channels, key_channels, scale, bn_type) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d( + _in_channels, + out_channels, + kernel_size=1, + padding=0, + bias=False), + ModuleHelper.BNReLU(out_channels, bn_type=bn_type), + nn.Dropout2d(dropout)) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + +class HrnetSuperAndOcr(HrnetBackBone): + + def __init__(self, **kwargs): + super(HrnetSuperAndOcr, self).__init__(**kwargs) + if 'architecture' not in kwargs: + raise Exception('HrnetSuperAndOcr not exist architecture param!') + self.architecture = kwargs['architecture'] + + if 'class_num' not in kwargs: + raise Exception('HrnetSuperAndOcr not exist class_num param!') + self.class_num = kwargs['class_num'] + + if 'ocr' not in kwargs: + raise 
Exception('HrnetSuperAndOcr not exist ocr param!') + ocr_mid_channels = kwargs['ocr']['mid_channels'] + ocr_key_channels = kwargs['ocr']['key_channels'] + dropout_rate = kwargs['ocr']['dropout_rate'] + scale = kwargs['ocr']['scale'] + + if 'super_param' not in kwargs: + raise Exception('HrnetSuperAndOcr not exist super_param param!') + + self.super_dict = kwargs['super_param'] + + self.is_export_onnx = False + self.is_export_full_onnx = False + + self.is_contain_tail = True if 'tail_param' in kwargs else False + if self.is_contain_tail: + self.stage_tail_dict = kwargs['tail_param'] + num_channels = self.stage_tail_dict['NUM_CHANNELS'][0] + block = blocks_dict[self.stage_tail_dict['BLOCK']] + num_blocks = self.stage_tail_dict['NUM_BLOCKS'][0] + self.stage_tail = self._make_layer(block, + self.backbone_last_inp_channels, + num_channels, num_blocks) + last_inp_channels = block.expansion * num_channels + else: + last_inp_channels = self.backbone_last_inp_channels + + self.is_contain_aspp = True if 'aspp' in kwargs else False + + if self.architecture == 'hrnet_super_ocr': + self.is_ocr_first = False + num_channels = [64, last_inp_channels] + self.stage_super, super_stage_channels = self._make_stage( + self.super_dict, num_channels) + last_inp_channels = np.int(np.sum(super_stage_channels)) + + if self.is_contain_aspp: + aspp_param = kwargs['aspp'] + self.aspp_layer = ASPP( + inplanes=last_inp_channels, + outplanes=aspp_param['outplanes'], + dilations=aspp_param['dilations'], + drop_rate=aspp_param['drop_rate']) + last_inp_channels = aspp_param['outplanes'] + + self.aux_head = nn.Sequential( + nn.Conv2d( + last_inp_channels, + last_inp_channels, + kernel_size=1, + stride=1, + padding=0), BatchNorm2d(last_inp_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + last_inp_channels, + self.class_num, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d( + last_inp_channels, + ocr_mid_channels, + kernel_size=3, + stride=1, + 
padding=1), + BatchNorm2d(ocr_mid_channels), + nn.ReLU(inplace=True), + ) + self.ocr_gather_head = SpatialGatherModule(self.class_num) + + self.ocr_distri_head = SpatialOCRModule( + in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=scale, + dropout=dropout_rate, + ) + + self.cls_head = nn.Sequential( + nn.Conv2d( + ocr_mid_channels, + ocr_mid_channels, + kernel_size=1, + stride=1, + padding=0), BatchNorm2d(ocr_mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + ocr_mid_channels, + self.class_num, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + else: + self.is_ocr_first = True + + if self.is_contain_aspp: + aspp_param = kwargs['aspp'] + self.aspp_layer = ASPP( + inplanes=last_inp_channels, + outplanes=aspp_param['outplanes'], + dilations=aspp_param['dilations'], + drop_rate=aspp_param['drop_rate']) + last_inp_channels = aspp_param['outplanes'] + + self.aux_head = nn.Sequential( + nn.Conv2d( + last_inp_channels, + last_inp_channels, + kernel_size=1, + stride=1, + padding=0), BatchNorm2d(last_inp_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + last_inp_channels, + self.class_num, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d( + last_inp_channels, + ocr_mid_channels, + kernel_size=3, + stride=1, + padding=1), + BatchNorm2d(ocr_mid_channels), + nn.ReLU(inplace=True), + ) + self.ocr_gather_head = SpatialGatherModule(self.class_num) + + self.ocr_distri_head = SpatialOCRModule( + in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=scale, + dropout=dropout_rate, + ) + + num_channels = [64, ocr_mid_channels] + self.stage_super, super_stage_channels = self._make_stage( + self.super_dict, num_channels) + last_inp_channels = np.int(np.sum(super_stage_channels)) + + self.cls_head = nn.Sequential( + nn.Conv2d( + last_inp_channels, + last_inp_channels, + kernel_size=1, + stride=1, + padding=0), 
BatchNorm2d(last_inp_channels), + nn.ReLU(inplace=True), + nn.Conv2d( + last_inp_channels, + self.class_num, + kernel_size=1, + stride=1, + padding=0, + bias=True)) + + def forward(self, x): + if self.is_export_onnx: + x = x.permute(0, 3, 1, 2) + raw_h, raw_w = x.size(2), x.size(3) + if self.is_export_full_onnx: + raw_h, raw_w = x.size(2), x.size(3) + + x = self.conv1(x) + x = self.bn1(x) # 5, 64, 320, 320 + x_stem = self.relu(x) + x = self.conv2(x_stem) + x = self.bn2(x) + x = self.relu(x) # 5, 64, 160, 160 + + x = self.layer1(x) # 5, 256=64*4, 160, 160 + + x_list = [] + for i in range(self.stage2_cfg['NUM_BRANCHES']): + if self.transition1[i] is not None: + x_list.append(self.transition1[i](x)) + else: + x_list.append(x) # [[5, 18, 160, 160],[5, 36, 80, 80]] + y_list = self.stage2(x_list) + + x_list = [] + for i in range(self.stage3_cfg['NUM_BRANCHES']): + if self.transition2[i] is not None: + x_list.append(self.transition2[i](y_list[-1])) + else: + x_list.append(y_list[i]) + y_list = self.stage3(x_list) + + x_list = [] + for i in range(self.stage4_cfg['NUM_BRANCHES']): + if self.transition3[i] is not None: + x_list.append(self.transition3[i](y_list[-1])) + else: + x_list.append(y_list[i]) + x = self.stage4(x_list) + + # Upsampling + x0_h, x0_w = x[0].size(2), x[0].size(3) + x1 = F.interpolate( + x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=True) + x2 = F.interpolate( + x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=True) + x3 = F.interpolate( + x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=True) + + feats = torch.cat([x[0], x1, x2, x3], 1) + + if self.is_contain_tail: + feats = self.stage_tail(feats) + + if self.is_ocr_first: + + if self.is_contain_aspp: + feats = self.aspp_layer(feats) + # compute contrast feature + out_aux = self.aux_head(feats) + + feats = self.conv3x3_ocr(feats) + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + feats = [x_stem, feats] # 320*320 2X + x_super = 
self.stage_super(feats) + + xsuper_h, xsuper_w = x_super[0].size(2), x_super[0].size(3) + x_super1 = F.interpolate( + x_super[1], + size=(xsuper_h, xsuper_w), + mode='bilinear', + align_corners=True) + x_super = torch.cat([x_super[0], x_super1], 1) + out = self.cls_head(x_super) + + else: + x_super = [x_stem, feats] # 320*320 2X, 160*160 4X + x_super = self.stage_super(x_super) + + xsuper_h, xsuper_w = x_super[0].size(2), x_super[0].size(3) + x_super1 = F.interpolate( + x_super[1], + size=(xsuper_h, xsuper_w), + mode='bilinear', + align_corners=True) + x_super = torch.cat([x_super[0], x_super1], 1) + + if self.is_contain_aspp: + x_super = self.aspp_layer(x_super) + out_aux = self.aux_head(x_super) + + feats = self.conv3x3_ocr(x_super) + context = self.ocr_gather_head(feats, out_aux) + feats = self.ocr_distri_head(feats, context) + + out = self.cls_head(feats) + + if self.is_export_onnx or self.is_export_full_onnx: + x_class = F.interpolate( + out, size=(raw_h, raw_w), mode='bilinear', align_corners=True) + x_class = torch.softmax(x_class, dim=1) + _, x_class = torch.max(x_class, dim=1, keepdim=True) + x_class = x_class.float() + return x_class + else: + out_aux_seg = [ + out_aux, out + ] # out_aux: 5, 2, 160, 160(HRNet origin res); out: 5, 2, 320, 320(HRNet res+tail+aspp+ocr) + return out_aux_seg + + +def get_seg_model(cfg, **kwargs): + model = HrnetSuperAndOcr(cfg, **kwargs) + model.init_weights(cfg.MODEL.PRETRAINED) + + return model diff --git a/modelscope/models/cv/image_skychange/ptsemseg/unet.py b/modelscope/models/cv/image_skychange/ptsemseg/unet.py new file mode 100644 index 00000000..20affbef --- /dev/null +++ b/modelscope/models/cv/image_skychange/ptsemseg/unet.py @@ -0,0 +1,229 @@ +# Copyright 2021-2022 The Alibaba Vision Team Authors. All rights reserved. 
import torch
import torch.nn as nn

from .BlockModules import ASPP


class Conv2DBatchNormRelu(nn.Module):
    """Conv2d optionally followed by BatchNorm2d and/or ReLU."""

    def __init__(self,
                 in_channels,
                 n_filters,
                 k_size,
                 stride,
                 padding,
                 bias=True,
                 dilation=1,
                 with_bn=True,
                 with_relu=True):
        super(Conv2DBatchNormRelu, self).__init__()

        stages = [
            nn.Conv2d(
                int(in_channels),
                int(n_filters),
                kernel_size=k_size,
                padding=padding,
                stride=stride,
                bias=bias,
                dilation=dilation,
            )
        ]
        if with_bn:
            stages.append(nn.BatchNorm2d(int(n_filters)))
        if with_relu:
            stages.append(nn.ReLU(inplace=True))
        self.cbr_unit = nn.Sequential(*stages)

    def forward(self, inputs):
        return self.cbr_unit(inputs)


class SegnetDown2(nn.Module):
    """Two conv-bn-relu blocks followed by 2x2 max-pooling with indices."""

    def __init__(self, in_size, out_size):
        super(SegnetDown2, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(
            in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        features = self.conv2(self.conv1(inputs))
        unpooled_shape = features.size()
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, unpooled_shape


class SegnetDown3(nn.Module):
    """Three conv-bn-relu blocks followed by 2x2 max-pooling with indices."""

    def __init__(self, in_size, out_size):
        super(SegnetDown3, self).__init__()
        self.conv1 = Conv2DBatchNormRelu(
            in_size, out_size, k_size=3, stride=1, padding=1)
        self.conv2 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.conv3 = Conv2DBatchNormRelu(
            out_size, out_size, k_size=3, stride=1, padding=1)
        self.maxpool_with_argmax = nn.MaxPool2d(2, 2, return_indices=True)

    def forward(self, inputs):
        features = self.conv3(self.conv2(self.conv1(inputs)))
        unpooled_shape = features.size()
        pooled, indices = self.maxpool_with_argmax(features)
        return pooled, indices, unpooled_shape


class SegnetUp1(nn.Module):
    """Max-unpooling with stored indices, then a 5x5 conv-bn block (no ReLU)."""

    def __init__(self, in_size, out_size):
        super(SegnetUp1, self).__init__()
        self.unpool = nn.MaxUnpool2d(2, 2)
        self.conv = Conv2DBatchNormRelu(
            in_size, out_size, k_size=5, stride=1, padding=2, with_relu=False)

    def forward(self, inputs, indices, output_shape):
        unpooled = self.unpool(
            input=inputs, indices=indices, output_size=output_shape)
        return self.conv(unpooled)


class Unet(nn.Module):
    """SegNet-style encoder/decoder used as the sky-matting refine network.

    Input is an image plus a coarse mask channel (in_channels=4 by default);
    output is a per-pixel sigmoid score map with the channel dim squeezed.
    """

    def __init__(self,
                 n_classes=2,
                 in_channels=4,
                 is_unpooling=True,
                 pretrain=True,
                 **kwargs):
        super(Unet, self).__init__()
        print('Load Unet')
        self.in_channels = in_channels
        self.is_unpooling = is_unpooling
        self.pretrain = pretrain
        self.is_contain_aspp = 'aspp' in kwargs

        if self.is_contain_aspp:
            aspp_param = kwargs['aspp']
            self.aspp_layer = ASPP(
                inplanes=128,
                outplanes=aspp_param['outplanes'],
                dilations=aspp_param['dilations'],
                drop_rate=aspp_param['drop_rate'])
            self.aspp_channels = aspp_param['outplanes']

        # Encoder: 4 -> 64 -> 128 -> 256 -> 512 -> 512.
        self.down1 = SegnetDown2(self.in_channels, 64)
        self.down2 = SegnetDown2(64, 128)
        self.down3 = SegnetDown3(128, 256)
        self.down4 = SegnetDown3(256, 512)
        self.down5 = SegnetDown3(512, 512)

        # Decoder mirrors the encoder using pooling indices.
        self.up5 = SegnetUp1(512, 512)
        self.up4 = SegnetUp1(512, 256)
        self.up3 = SegnetUp1(256, 128)

        if self.is_contain_aspp:
            # Fuses ASPP context back into the 128-channel decoder feature.
            self.conv_1x1_aspp = Conv2DBatchNormRelu(
                128 + self.aspp_channels,
                128,
                k_size=1,
                stride=1,
                padding=0,
                with_relu=False)

        self.up2 = SegnetUp1(128, 64)
        self.up1 = SegnetUp1(64, n_classes)
        self.sigmoid = nn.Sigmoid()

        if self.pretrain:
            import torchvision.models as models
            # NOTE(review): models.vgg16() is constructed without pretrained
            # weights here, so the copied conv weights are random — confirm
            # whether pretrained VGG16 weights were intended.
            vgg16 = models.vgg16()
            self.init_vgg16_params(vgg16)

    def forward(self, inputs):
        # Encoder with saved pooling indices/shapes for later unpooling.
        down1, indices_1, unpool_shape1 = self.down1(inputs)
        down2, indices_2, unpool_shape2 = self.down2(down1)
        down3, indices_3, unpool_shape3 = self.down3(down2)
        torch.cuda.empty_cache()
        if self.is_contain_aspp:
            # ASPP branch taken from the 128-channel encoder feature.
            aspp_output = self.aspp_layer(down2)

        down4, indices_4, unpool_shape4 = self.down4(down3)
        down5, indices_5, unpool_shape5 = self.down5(down4)
        torch.cuda.empty_cache()

        up5 = self.up5(down5, indices_5, unpool_shape5)
        up4 = self.up4(up5, indices_4, unpool_shape4)
        torch.cuda.empty_cache()
        up3 = self.up3(up4, indices_3, unpool_shape3)
        if self.is_contain_aspp:
            up3 = torch.cat([up3, aspp_output], 1)
            up3 = self.conv_1x1_aspp(up3)

        up2 = self.up2(up3, indices_2, unpool_shape2)
        up1 = self.up1(up2, indices_1, unpool_shape1)

        # [N, 1, H, W] -> [N, H, W], squashed to (0, 1).
        scores = torch.squeeze(up1, dim=1)
        return self.sigmoid(scores)

    def init_vgg16_params(self, vgg16):
        """Copy matching conv weights/biases from a torchvision VGG16 into
        the encoder blocks (layers with mismatched shapes are skipped)."""
        blocks = [self.down1, self.down2, self.down3, self.down4, self.down5]

        vgg_layers = [
            layer for layer in vgg16.features.children()
            if isinstance(layer, nn.Conv2d)
        ]

        merged_layers = []
        for idx, conv_block in enumerate(blocks):
            if idx < 2:
                units = [conv_block.conv1.cbr_unit, conv_block.conv2.cbr_unit]
            else:
                units = [
                    conv_block.conv1.cbr_unit,
                    conv_block.conv2.cbr_unit,
                    conv_block.conv3.cbr_unit,
                ]
            for unit in units:
                merged_layers.extend(
                    layer for layer in unit if isinstance(layer, nn.Conv2d))

        assert len(vgg_layers) == len(merged_layers)

        for src, dst in zip(vgg_layers, merged_layers):
            if isinstance(src, nn.Conv2d) and isinstance(dst, nn.Conv2d):
                if (src.weight.size() == dst.weight.size()
                        and src.bias.size() == dst.bias.size()):
                    dst.weight.data = src.weight.data
                    dst.bias.data = src.bias.data


# ==== file: modelscope/models/cv/image_skychange/skychange.py ====
# Copyright (c) Alibaba, Inc. and its affiliates.
import numbers
import os
import pdb
from collections import deque

import cv2
import json
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

torch.backends.cudnn.enabled = True

IMAGE_MAX_DIM = 3000
IMAGE_MIN_DIM = 50
IMAGE_MAX_RATIO = 10.0
IMAGE_BLENDER_MASK_RESIZE_SCALE = 10.0
IMAGE_BLENDER_INNER_RECT_MAX_DIM = 256
IMAGE_BLENDER_DILATE_KERNEL_SIZE = 7
IMAGE_BLENDER_VALID_MASK_THRESHOLD = 100
IMAGE_BLENDER_MIN_VALID_SKY_AREA = 100
IMAGE_BLENDER_MIN_RESIZE_DIM = 10
IMAGE_BLENDER_BLUR_KERNEL_SIZE = 5
IMAGE_BLENDER_DILATE_KERNEL_SIZE + 0.5)) + + element = cv2.getStructuringElement(cv2.MORPH_RECT, + (kernelSize, kernelSize)) + resize_mask = cv2.morphologyEx(resize_mask, cv2.MORPH_CLOSE, element) + + max_inner_rect, area = get_max_inner_rect( + resize_mask, IMAGE_BLENDER_VALID_MASK_THRESHOLD, True) + + if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA: + raise Exception( + '[extractSkyImage]failed!! Valid sky region is too small') + + scale = 1.0 / scale + # max_inner_rect: left top(x,y), right bottome(x,y); raw_inner_rect:left top x,y,w(of bbox),h(of bbox) + raw_inner_rect = scale_rect(max_inner_rect, in_sky_mask, scale) + out_sky_image = in_sky_image[raw_inner_rect[1]:raw_inner_rect[1] + + raw_inner_rect[3] + 1, + raw_inner_rect[0]:raw_inner_rect[0] + + raw_inner_rect[2] + 1, ].copy() + return out_sky_image + + +def blend(scene_image, scene_mask, sky_image, sky_mask, inBlendLevelNum=10): + if torch.cuda.is_available(): + scene_image = scene_image.cpu().numpy() + sky_image = sky_image.cpu().numpy() + else: + scene_image = scene_image.numpy() + sky_image = sky_image.numpy() + sky_image_h, sky_image_w = sky_image.shape[0:2] + sky_mask_h, sky_mask_w = sky_mask.shape[0:2] + + scene_image_h, scene_image_w = scene_image.shape[0:2] + scene_mask_h, scene_mask_w = scene_mask.shape[0:2] + + if sky_image_h != sky_mask_h or sky_image_w != sky_mask_w: + raise Exception( + '[blend]failed!! sky_image shape not equal with sky_image_mask shape' + ) + + if scene_image_h != scene_mask_h or scene_image_w != scene_mask_w: + raise Exception( + '[blend]failed!! 
scene_image shape not equal with scene_image_mask shape' + ) + + valid_sky_image = extract_sky_image(sky_image, sky_mask) + out_blend_image = blend_merge(scene_image, scene_mask, valid_sky_image, + inBlendLevelNum) + return out_blend_image + + +def get_max_inner_rect(in_image_mask, in_alpha_threshold, is_bigger_valid): + res = 0 + row, col = in_image_mask.shape[0:2] + i0, j0, i1, j1 = 0, 0, 0, 0 + height = [0] * (col + 1) + + for i in range(0, row): + s = deque() + for j in range(0, col + 1): + if j < col: + if is_bigger_valid: + height[j] = ( + height[j] + + 1 if in_image_mask[i, j] > in_alpha_threshold else 0) + else: + height[j] = ( + height[j] + 1 + if in_image_mask[i, j] <= in_alpha_threshold else 0) + + while len(s) != 0 and height[s[-1]] >= height[j]: + cur = s[-1] + s.pop() + _h = height[cur] + _w = j if len(s) == 0 else j - s[-1] - 1 + curArea = _h * _w + if curArea > res: + res = curArea + i1 = i + i0 = i1 - _h + 1 + j1 = j - 1 + j0 = j1 - _w + 1 + s.append(j) + + out_rect = ( + j0, + i0, + j1 - j0 + 1, + i1 - i0 + 1, + ) + return out_rect, res + + +def scale_rect(in_rect, in_image_size, in_scale): + tlX = int(in_rect[0] * in_scale + 0.5) + tlY = int(in_rect[1] * in_scale + 0.5) + in_image_size_h, in_image_size_w = in_image_size.shape[0:2] + brX = min(int(in_rect[2] * in_scale + 0.5), in_image_size_w) + brY = min(int(in_rect[3] * in_scale + 0.5), in_image_size_h) + out_rect = (tlX, tlY, brX - tlX, brY - tlY) + return out_rect + + +def get_fast_valid_rect(in_mask, in_threshold=0): + # mask: np.array [0~1] + in_mask = in_mask > in_threshold + locations = cv2.findNonZero(in_mask.astype(np.uint8)) + output_rect = cv2.boundingRect(locations) # x,y,w,h + return output_rect + + +def min_size_match(in_image, in_min_size, type=cv2.INTER_LINEAR): + resize_image = in_image.copy() + width, height = in_min_size + resize_img_height, resize_img_width = in_image.shape[0:2] + height_scale = height / resize_img_height + widht_scale = width / resize_img_width + scale = 
height_scale if height_scale > widht_scale else widht_scale + new_size = ( + max(int(resize_img_width * scale + 0.5), 1), + max(int(resize_img_height * scale + 0.5), 1), + ) + + resize_image = cv2.resize(resize_image, new_size, 0, 0, type) + return resize_image + + +def center_crop(in_image, in_size): + in_size_w, in_size_h = in_size + in_image_h, in_image_w = in_image.shape[0:2] + + half_height = (in_image_h - in_size_h) // 2 + half_width = (in_image_w - in_size_w) // 2 + + out_crop_image = in_image.copy() + out_crop_image = out_crop_image[half_height:half_height + in_size_h, + half_width:half_width + in_size_w] + return out_crop_image + + +def safe_roi_pad(in_pad_image, in_rect, out_base_image): + in_rect_x, in_rect_y, in_rect_w, in_rect_h = in_rect + + if in_rect_x < 0 or in_rect_y < 0 or in_rect_w <= 0 or in_rect_h <= 0: + raise Exception('[safe_roi_pad] Failed!! x,y,w,h of rect are illegal') + + if in_rect_w != in_pad_image.shape[1] or in_rect_h != in_pad_image.shape[0]: + raise Exception('[safe_roi_pad] Failed!!') + + if (in_rect_x + in_rect_w > out_base_image.shape[1] + or in_rect_y + in_rect_h > out_base_image.shape[0]): + raise Exception('[safe_roi_pad] Failed!!') + + out_base_image[in_rect_y:in_rect_y + in_rect_h, + in_rect_x:in_rect_x + in_rect_w] = in_pad_image + + +def merge_image(in_base_image, in_merge_image, in_merge_mask, in_point): + if in_merge_image.shape[0:2] != in_merge_mask.shape[0:2]: + raise Exception( + '[merge_image] Failed!! in_merge_image.shape != in_merge_mask.shape!!' + ) + + in_point_x, in_point_y = in_point + in_merge_image_rows, in_merge_image_cols = in_merge_image.shape[0:2] + in_base_image_rows, in_base_image_cols = in_base_image.shape[0:2] + + if (in_point_x + in_merge_image_cols > in_base_image_cols + or in_point_y + in_merge_image_rows > in_base_image_rows): + raise Exception( + '[merge_image] Failed!! 
merge_image:image rect not in image') + + base_roi_image = in_base_image[in_point_y:in_point_y + in_merge_image_rows, + in_point_x:in_point_x + + in_merge_image_cols, ] + + merge_image = in_merge_image.copy() + merge_alpha = in_merge_mask.copy() + base_roi_image = np.float32(base_roi_image) + merge_alpha = np.repeat(merge_alpha[:, :, np.newaxis], 3, axis=2) + merge_alpha = merge_alpha / 255.0 + + base_roi_image = ( + 1 - merge_alpha) * base_roi_image + merge_alpha * merge_image + base_roi_image = np.clip(base_roi_image, 0, 255) + base_roi_image = base_roi_image.astype('uint8') + + roi_rect = (in_point_x, in_point_y, in_merge_image_cols, + in_merge_image_rows) + safe_roi_pad(base_roi_image, roi_rect, in_base_image) + return in_base_image + + +def blend_merge(in_scene_image, + in_scene_mask, + in_valid_sky_image, + inBlendLevelNum=5): + scene_sky_rect = get_fast_valid_rect(in_scene_mask, 1) + area = scene_sky_rect[2] * scene_sky_rect[3] + + if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA: + raise Exception( + '[blend_merge] Failed!! 
Scene Image Valid sky region is too small') + + valid_sky_image = min_size_match(in_valid_sky_image, scene_sky_rect[2:]) + valid_sky_image = center_crop(valid_sky_image, scene_sky_rect[2:]) + + # resizeSceneMask + sky_size = ( + max( + int(in_scene_mask.shape[1] * IMAGE_BLENDER_MASK_RESIZE_SCALE + + 0.5), + IMAGE_BLENDER_MIN_RESIZE_DIM, + ), + max( + int(in_scene_mask.shape[0] * IMAGE_BLENDER_MASK_RESIZE_SCALE + + 0.5), + IMAGE_BLENDER_MIN_RESIZE_DIM, + ), + ) + + resize_scene_mask = cv2.resize(in_scene_mask, sky_size, cv2.INTER_LINEAR) + resize_scene_mask = cv2.blur( + resize_scene_mask, + (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE), + ) + + element = cv2.getStructuringElement( + cv2.MORPH_RECT, + (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE)) + sky_mask = cv2.dilate(resize_scene_mask, element) # enlarge sky region + scene_mask = cv2.erode(resize_scene_mask, element) # enlarge scene region + scene_mask = 255 - scene_mask + + sky_mask = cv2.resize(sky_mask, in_scene_mask.shape[0:2][::-1]) + scene_mask = cv2.resize(scene_mask, in_scene_mask.shape[0:2][::-1]) + + x, y, w, h = scene_sky_rect + valid_sky_mask = sky_mask[y:y + h, x:x + w] + + pano_sky_image = in_scene_image.copy() + + pano_sky_image = merge_image(pano_sky_image, valid_sky_image, + valid_sky_mask, scene_sky_rect[0:2]) + blend_images = [] + blend_images.append(in_scene_image) + blend_images.append(pano_sky_image) + + blend_masks = [] + blend_masks.append(scene_mask.astype(np.uint8)) + blend_masks.append(sky_mask.astype(np.uint8)) + + panorama_rect = (0, 0, in_scene_image.shape[1], in_scene_image.shape[0]) + + blender = cv2.detail_MultiBandBlender(1, inBlendLevelNum) + blender.prepare(panorama_rect) + + for i in range(0, len(blend_images)): + blender.feed(blend_images[i], blend_masks[i], (0, 0)) + pano_mask = ( + np.ones( + (in_scene_image.shape[1], in_scene_image.shape[0]), dtype='uint8') + * 255) + out_blend_image = np.zeros_like(in_scene_image) + result = 
blender.blend(out_blend_image, pano_mask) + return result[0] diff --git a/modelscope/models/cv/image_skychange/skychange_model.py b/modelscope/models/cv/image_skychange/skychange_model.py new file mode 100644 index 00000000..fa16c38c --- /dev/null +++ b/modelscope/models/cv/image_skychange/skychange_model.py @@ -0,0 +1,199 @@ +import math +import os +import pdb +import time +from collections import OrderedDict +from typing import Any, Dict, List, Union + +import cv2 +import json +import torch +import torch.nn.functional as F + +from modelscope.metainfo import Models +from modelscope.models import Model +from modelscope.models.base import TorchModel +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger +from .ptsemseg.hrnet_super_and_ocr import HrnetSuperAndOcr +from .ptsemseg.unet import Unet +from .skychange import blend + +logger = get_logger() + + +@MODELS.register_module( + Tasks.image_skychange, module_name=Models.image_skychange) +class ImageSkychange(TorchModel): + + def __init__(self, model_dir, refine_cfg, coarse_cfg, *args, **kwargs): + """ + Args: + model_dir (str): model directory to initialize some resource. + refine_cfg: configuration of refine model. + coarse_cfg: configuration of coarse model. 
+ """ + super().__init__(model_dir=model_dir, *args, **kwargs) + if torch.cuda.is_available(): + self.device = torch.device('cuda') + logger.info('Use GPU: {}'.format(self.device)) + else: + self.device = torch.device('cpu') + logger.info('Use CPU: {}'.format(self.device)) + + coarse_model_path = '{}/{}'.format(model_dir, + ModelFile.TORCH_MODEL_FILE) + refine_model_path = '{}/{}'.format(model_dir, + 'unet_sky_matting_final_model.pkl') + + logger.info( + '####################### load refine models ################################' + ) + self.refine_model = Unet(**refine_cfg['Model']) + self.load_model(self.refine_model, refine_model_path) + self.refine_model.eval() + logger.info( + '####################### load refine models done ############################' + ) + + logger.info( + '####################### load coarse models ################################' + ) + self.coarse_model = HrnetSuperAndOcr(**coarse_cfg['Model']) + self.load_model(self.coarse_model, coarse_model_path) + self.coarse_model.eval() + logger.info( + '####################### load coarse models done ############################' + ) + + def load_model(self, seg_model, input_model_path): + if not os.path.isfile(input_model_path): + logger.error( + '[checkModelPath]:model path dose not exits!!! 
model Path:' + + input_model_path) + raise Exception('[checkModelPath]:model path dose not exits!') + + if torch.cuda.is_available(): + checkpoint = torch.load(input_model_path) + model_state = self.convert_state_dict(checkpoint['model_state']) + seg_model.load_state_dict(model_state) + seg_model.to(self.device) + else: + checkpoint = torch.load(input_model_path, map_location='cpu') + model_state = self.convert_state_dict(checkpoint['model_state']) + seg_model.load_state_dict(model_state) + + def convert_state_dict(self, state_dict): + """Converts a state dict saved from a dataParallel module to normal + module state_dict inplace + :param state_dict is the loaded DataParallel model_state + """ + if not next(iter(state_dict)).startswith('module.'): + return state_dict # abort if dict is not a DataParallel model_state + new_state_dict = OrderedDict() + + split_index = 0 + for cur_key, cur_value in state_dict.items(): + if cur_key.startswith('module.model'): + split_index = 13 + elif cur_key.startswith('module'): + split_index = 7 + + break + + for k, v in state_dict.items(): + name = k[split_index:] # remove `module.` + new_state_dict[name] = v + return new_state_dict + + def forward( + self, + sky_image: torch.Tensor, + sky_image_refine: torch.Tensor, + scene_image: torch.Tensor, + scene_image_refine: torch.Tensor, + img_metas: Dict[str, Any], + ): + """ + Args: + sky_image (`torch.Tensor`): batched image tensor, shape is [1, 3, h', w']. + sky_image_refine (`torch.Tensor`): batched image tensor, shape is [1, 3, refine_net_h, refine_net_w]. + scene_image (`torch.Tensor`): batched image tensor, shape is [1, 3, h, w]. + scene_image_refine (`torch.Tensor`): batched image tensor, shape is [1, 3, refine_net_h, refine_net_w]. + img_metas (`Dict[str, Any]`): image meta info. 
+ Return: + `IMAGE: shape is [h, w, 3] (0~255)` + """ + start = time.time() + sky_img_metas, scene_img_metas, input_size = img_metas[ + 'sky_img_metas'], img_metas['scene_img_metas'], img_metas[ + 'input_size'] + sky_mask = self.inference_mask(sky_image_refine, sky_img_metas, + input_size) + scene_mask = self.inference_mask(scene_image_refine, scene_img_metas, + input_size) + end = time.time() + logger.info( + 'Time of inferencing mask of sky and scene images:{}'.format( + end - start)) + start = time.time() + scene_mask = scene_mask * 255 + sky_mask = sky_mask * 255 + res = blend(scene_image, scene_mask, sky_image, sky_mask) + end = time.time() + logger.info('Time of blending: {}'.format(end - start)) + return res + + @torch.no_grad() + def inference_mask(self, img, img_metas, input_size): + self.eval() + raw_h, raw_w = img_metas['ori_shape'] + pad_direction = img_metas['pad_direction'] + coarse_input_size = input_size['coarse_input_size'] + refine_input_size = input_size['refine_input_size'] + h, w = img_metas['refine_shape'] + resize_images = F.interpolate( + img, coarse_input_size, mode='bilinear', align_corners=True) + # get coarse result + pred_scores = self.coarse_model(resize_images) + if isinstance(pred_scores, (tuple, list)): + pred_scores = pred_scores[-1] + score = F.interpolate( + input=pred_scores, + size=refine_input_size, + mode='bilinear', + align_corners=True, + ) + _, coarse_pred = torch.max(score, dim=1) # [B, h, w] + coarse_pred = coarse_pred.unsqueeze(1).type(img.dtype) + img = torch.cat([img, coarse_pred], dim=1) # [B, c=4, h, w] + del resize_images + del pred_scores + del score + del coarse_pred + torch.cuda.empty_cache() + cur_scores = self.refine_model(img) + del img + torch.cuda.empty_cache() + cur_scores = torch.clip(cur_scores, 0, 1) + cur_scores = cur_scores.detach().cpu().numpy()[0] + + # resize if cur_scores shape are not compatible with origin image shape + ph, pw = cur_scores.shape + if ph != h or pw != w: + cur_scores = 
F.interpolate( + input=cur_scores, + size=(h, w), + mode='nearest', + align_corners=True) + # unpad to get valid area and resize to origin size + valid_cur_pred = cur_scores[pad_direction[1]:refine_input_size[0] + - pad_direction[3], + pad_direction[0]:refine_input_size[1] + - pad_direction[2], ] + valid_cur_pred = cv2.resize(valid_cur_pred, (raw_w, raw_h)) + del cur_scores + torch.cuda.empty_cache() + print('get refine mask done') + return valid_cur_pred diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index f2aaf48f..b1da7eb7 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -836,6 +836,11 @@ TASK_OUTPUTS = { # } Tasks.product_segmentation: [OutputKeys.MASKS], + # image_skychange result for a single sample + # { + # "output_img": np.ndarray with shape [height, width, 3] + # } + Tasks.image_skychange: [OutputKeys.OUTPUT_IMG], # { # 'scores': [0.1, 0.2, 0.3, ...] # } diff --git a/modelscope/pipeline_inputs.py b/modelscope/pipeline_inputs.py index 6d4f7794..50818dff 100644 --- a/modelscope/pipeline_inputs.py +++ b/modelscope/pipeline_inputs.py @@ -101,6 +101,10 @@ TASK_INPUTS = { 'img': InputType.IMAGE, 'mask': InputType.IMAGE, }, + Tasks.image_skychange: { + 'sky_image': InputType.IMAGE, + 'scene_image': InputType.IMAGE, + }, # image generation task result for a single image Tasks.image_to_image_generation: diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index c1a4d86b..fb1c53da 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -223,6 +223,8 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_swin-t_referring_video-object-segmentation'), Tasks.video_summarization: (Pipelines.video_summarization, 'damo/cv_googlenet_pgl-video-summarization'), + Tasks.image_skychange: (Pipelines.image_skychange, + 'damo/cv_hrnetocr_skychange'), Tasks.translation_evaluation: (Pipelines.translation_evaluation, 
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'), diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 6e80a6b9..cff2138d 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -67,6 +67,7 @@ if TYPE_CHECKING: from .hand_static_pipeline import HandStaticPipeline from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline + from .image_skychange_pipeline import ImageSkychangePipeline from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline else: @@ -155,6 +156,7 @@ else: 'language_guided_video_summarization_pipeline': [ 'LanguageGuidedVideoSummarizationPipeline' ], + 'image_skychange_pipeline': ['ImageSkychangePipeline'], 'video_object_segmentation_pipeline': [ 'VideoObjectSegmentationPipeline' ], diff --git a/modelscope/pipelines/cv/image_skychange_pipeline.py b/modelscope/pipelines/cv/image_skychange_pipeline.py new file mode 100644 index 00000000..c71135b5 --- /dev/null +++ b/modelscope/pipelines/cv/image_skychange_pipeline.py @@ -0,0 +1,63 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import pdb +import time +from typing import Any, Dict, Union + +import cv2 +import numpy as np +import PIL + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.image_skychange import ImageSkyChangePreprocessor +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Model, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.utils.constant import Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +@PIPELINES.register_module( + Tasks.image_skychange, module_name=Pipelines.image_skychange) +class ImageSkychangePipeline(Pipeline): + """ Image Sky Change Pipeline. 
Given two images(sky_image and scene_image), + pipeline will replace the sky style of sky_image with the sky style of scene_image. + Example: + + ```python + >>> from modelscope.pipelines import pipeline + >>> detector = pipeline('image-skychange', 'damo/cv_hrnetocr_skychange') + >>> detector({ + 'sky_image': 'sky_image.jpg', # sky_image path (str) + 'scene_image': 'scene_image.jpg', # scene_image path (str) + }) + { + "output_img": [H * W * 3] 0~255, we can use cv2.imwrite to save output_img as an image. + } + >>> # + ``` + """ + + def __init__(self, model: str, **kwargs): + """ + use `model` to create a image sky change pipeline for image editing + Args: + model (`str` or `Model`): model_id on modelscope hub + preprocessor(`Preprocessor`, *optional*, defaults to None): `ImageSkyChangePreprocessor`. + """ + super().__init__(model=model, **kwargs) + if not isinstance(self.model, Model): + logger.error('model object is not initialized.') + raise Exception('model object is not initialized.') + if self.preprocessor is None: + self.preprocessor = ImageSkyChangePreprocessor() + logger.info('load model done') + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + res = self.model.forward(**input) + return {OutputKeys.OUTPUT_IMG: res} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index d527e4c9..f7168656 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -67,6 +67,7 @@ class CVTasks(object): image_denoising = 'image-denoising' image_portrait_enhancement = 'image-portrait-enhancement' image_inpainting = 'image-inpainting' + image_skychange = 'image-skychange' # image generation image_to_image_translation = 'image-to-image-translation' diff --git a/tests/pipelines/test_image_skychange.py b/tests/pipelines/test_image_skychange.py new file mode 100644 index 00000000..cd0916e4 --- /dev/null +++ 
b/tests/pipelines/test_image_skychange.py @@ -0,0 +1,48 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os +import os.path as osp +import unittest + +import cv2 +import torch + +import modelscope +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.base import Pipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.test_utils import test_level + +print(modelscope.version.__release_datetime__) + + +class ImageSkychangeTest(unittest.TestCase): + + def setUp(self) -> None: + self.model = 'damo/cv_hrnetocr_skychange' + self.sky_image = 'data/test/images/sky_image.jpg' + self.scene_image = 'data/test/images/scene_image.jpg' + self.input = { + 'sky_image': self.sky_image, + 'scene_image': self.scene_image, + } + + def pipeline_inference(self, pipeline: Pipeline, input: str): + result = pipeline(input) + if result is not None: + cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG]) + print(f'Output written to {osp.abspath("result.png")}') + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_modelhub(self): + image_skychange = pipeline(Tasks.image_skychange, model=self.model) + self.pipeline_inference(image_skychange, self.input) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_modelhub_default_model(self): + image_skychange = pipeline(Tasks.image_skychange) + self.pipeline_inference(image_skychange, self.input) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run_config.yaml b/tests/run_config.yaml index cb90852f..07cb3ddd 100644 --- a/tests/run_config.yaml +++ b/tests/run_config.yaml @@ -41,6 +41,7 @@ isolated: # test cases that may require excessive anmount of GPU memory or run - test_image_matting.py - test_skin_retouching.py - test_table_recognition.py + - test_image_skychange.py envs: default: # default env, case not in other env will in default, pytorch.