# The implementation is adopted from ECBSR,
# made publicly available under the Apache 2.0 License at
# https://github.com/xindongzhang/ECBSR/blob/main/models/ecb.py

import torch
import torch.nn as nn
import torch.nn.functional as F

# Fixed 3x3 edge-detector kernels shared by the sobel/laplacian seq-conv
# types.  Values are identical to the original per-channel mask init.
_EDGE_KERNELS = {
    'conv1x1-sobelx': [[1.0, 0.0, -1.0], [2.0, 0.0, -2.0], [1.0, 0.0, -1.0]],
    'conv1x1-sobely': [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]],
    'conv1x1-laplacian': [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]],
}


class SeqConv3x3(nn.Module):
    """A sequential 1x1-then-3x3 convolution that can be re-parameterized
    into a single 3x3 convolution at inference time.

    Args:
        seq_type (str): one of 'conv1x1-conv3x3', 'conv1x1-sobelx',
            'conv1x1-sobely', 'conv1x1-laplacian'.
        inp_planes (int): number of input channels.
        out_planes (int): number of output channels.
        depth_multiplier (float): expansion ratio of the hidden 1x1 conv;
            only used by the 'conv1x1-conv3x3' type.

    Raises:
        ValueError: if `seq_type` is not one of the supported types.
    """

    def __init__(self, seq_type, inp_planes, out_planes, depth_multiplier):
        super(SeqConv3x3, self).__init__()

        self.type = seq_type
        self.inp_planes = inp_planes
        self.out_planes = out_planes

        if self.type == 'conv1x1-conv3x3':
            self.mid_planes = int(out_planes * depth_multiplier)
            conv0 = torch.nn.Conv2d(
                self.inp_planes, self.mid_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # padding=0 on purpose: forward() pads the border explicitly
            # with conv0's bias so the fused conv is exactly equivalent.
            conv1 = torch.nn.Conv2d(
                self.mid_planes, self.out_planes, kernel_size=3)
            self.k1 = conv1.weight
            self.b1 = conv1.bias

        elif self.type in _EDGE_KERNELS:
            conv0 = torch.nn.Conv2d(
                self.inp_planes, self.out_planes, kernel_size=1, padding=0)
            self.k0 = conv0.weight
            self.b0 = conv0.bias

            # learnable per-channel scale & bias applied to the fixed kernel
            scale = torch.randn(size=(self.out_planes, 1, 1, 1)) * 1e-3
            self.scale = nn.Parameter(scale)
            bias = torch.randn(self.out_planes) * 1e-3
            self.bias = nn.Parameter(bias)

            # one copy of the fixed edge kernel per output channel; kept as a
            # frozen nn.Parameter (not a buffer) so the state_dict key 'mask'
            # stays compatible with previously trained checkpoints
            kernel = torch.tensor(_EDGE_KERNELS[self.type],
                                  dtype=torch.float32)
            mask = kernel.view(1, 1, 3, 3).repeat(self.out_planes, 1, 1, 1)
            self.mask = nn.Parameter(data=mask, requires_grad=False)
        else:
            raise ValueError('the type of seqconv is not supported!')

    def forward(self, x):
        # conv-1x1
        y0 = F.conv2d(input=x, weight=self.k0, bias=self.b0, stride=1)
        # explicitly pad the border with the conv-1x1 bias: a zero-padded
        # input through conv0 would produce exactly b0 at the border, which
        # makes the subsequent valid 3x3 conv fusable into one padded conv
        y0 = F.pad(y0, (1, 1, 1, 1), 'constant', 0)
        b0_pad = self.b0.view(1, -1, 1, 1)
        y0[:, :, 0:1, :] = b0_pad
        y0[:, :, -1:, :] = b0_pad
        y0[:, :, :, 0:1] = b0_pad
        y0[:, :, :, -1:] = b0_pad
        # conv-3x3
        if self.type == 'conv1x1-conv3x3':
            y1 = F.conv2d(input=y0, weight=self.k1, bias=self.b1, stride=1)
        else:
            # depthwise conv with the scaled fixed edge kernel
            y1 = F.conv2d(
                input=y0,
                weight=self.scale * self.mask,
                bias=self.bias,
                stride=1,
                groups=self.out_planes)
        return y1

    def rep_params(self):
        """Fuse the two-step convolution into a single 3x3 conv.

        Returns:
            (RK, RB): weight and bias of the equivalent fused 3x3 conv.
        """
        device = self.k0.get_device()
        if device < 0:
            device = None  # CPU tensor: get_device() returns -1

        if self.type == 'conv1x1-conv3x3':
            # re-param conv kernel: compose the 1x1 and 3x3 weights
            RK = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias: push b0 through the 3x3 conv, then add b1
            RB = torch.ones(
                1, self.mid_planes, 3, 3, device=device) * self.b0.view(
                    1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=self.k1).view(-1, ) + self.b1
        else:
            # expand the depthwise edge kernel to a dense (out, out, 3, 3)
            tmp = self.scale * self.mask
            k1 = torch.zeros((self.out_planes, self.out_planes, 3, 3),
                             device=device)
            for i in range(self.out_planes):
                k1[i, i, :, :] = tmp[i, 0, :, :]
            b1 = self.bias
            # re-param conv kernel
            RK = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
            # re-param conv bias
            RB = torch.ones(
                1, self.out_planes, 3, 3, device=device) * self.b0.view(
                    1, -1, 1, 1)
            RB = F.conv2d(input=RB, weight=k1).view(-1, ) + b1
        return RK, RB


class ECB(nn.Module):
    """Edge-oriented Convolution Block.

    In training mode the output is the sum of five parallel branches (plain
    3x3, expanded 1x1-3x3, and three fixed edge-filter branches), plus an
    optional identity shortcut.  In eval mode all branches are fused into a
    single 3x3 convolution via rep_params().

    Args:
        inp_planes (int): number of input channels.
        out_planes (int): number of output channels.
        depth_multiplier (float): expansion ratio of the 1x1-3x3 branch.
        act_type (str): 'prelu' | 'relu' | 'rrelu' | 'softplus' | 'linear'.
        with_idt (bool): add an identity shortcut; only honored when
            inp_planes == out_planes.

    Raises:
        ValueError: if `act_type` is not one of the supported activations.
    """

    def __init__(self,
                 inp_planes,
                 out_planes,
                 depth_multiplier,
                 act_type='prelu',
                 with_idt=False):
        super(ECB, self).__init__()

        self.depth_multiplier = depth_multiplier
        self.inp_planes = inp_planes
        self.out_planes = out_planes
        self.act_type = act_type

        # identity shortcut is only shape-valid when channels match
        self.with_idt = bool(with_idt
                             and (self.inp_planes == self.out_planes))

        self.conv3x3 = torch.nn.Conv2d(
            self.inp_planes, self.out_planes, kernel_size=3, padding=1)
        self.conv1x1_3x3 = SeqConv3x3('conv1x1-conv3x3', self.inp_planes,
                                      self.out_planes, self.depth_multiplier)
        # depth_multiplier is irrelevant for the fixed-kernel branches
        self.conv1x1_sbx = SeqConv3x3('conv1x1-sobelx', self.inp_planes,
                                      self.out_planes, -1)
        self.conv1x1_sby = SeqConv3x3('conv1x1-sobely', self.inp_planes,
                                      self.out_planes, -1)
        self.conv1x1_lpl = SeqConv3x3('conv1x1-laplacian', self.inp_planes,
                                      self.out_planes, -1)

        if self.act_type == 'prelu':
            self.act = nn.PReLU(num_parameters=self.out_planes)
        elif self.act_type == 'relu':
            self.act = nn.ReLU(inplace=True)
        elif self.act_type == 'rrelu':
            self.act = nn.RReLU(lower=-0.05, upper=0.05)
        elif self.act_type == 'softplus':
            self.act = nn.Softplus()
        elif self.act_type == 'linear':
            pass
        else:
            raise ValueError('the type of activation is not supported!')

    def forward(self, x):
        if self.training:
            # multi-branch training path
            y = self.conv3x3(x) + \
                self.conv1x1_3x3(x) + \
                self.conv1x1_sbx(x) + \
                self.conv1x1_sby(x) + \
                self.conv1x1_lpl(x)
            if self.with_idt:
                y += x
        else:
            # single fused conv for inference; rep_params() already folds in
            # the identity shortcut, so it is not added again here
            RK, RB = self.rep_params()
            y = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
        if self.act_type != 'linear':
            y = self.act(y)
        return y

    def rep_params(self):
        """Fuse all branches (and the identity, if any) into one 3x3 conv.

        Returns:
            (RK, RB): weight and bias of the equivalent fused 3x3 conv.
        """
        K0, B0 = self.conv3x3.weight, self.conv3x3.bias
        K1, B1 = self.conv1x1_3x3.rep_params()
        K2, B2 = self.conv1x1_sbx.rep_params()
        K3, B3 = self.conv1x1_sby.rep_params()
        K4, B4 = self.conv1x1_lpl.rep_params()
        RK, RB = (K0 + K1 + K2 + K3 + K4), (B0 + B1 + B2 + B3 + B4)

        if self.with_idt:
            device = RK.get_device()
            if device < 0:
                device = None
            # identity as a 3x3 conv: a centered one-hot kernel per channel
            K_idt = torch.zeros(
                self.out_planes, self.out_planes, 3, 3, device=device)
            for i in range(self.out_planes):
                K_idt[i, i, 1, 1] = 1.0
            B_idt = 0.0
            RK, RB = RK + K_idt, RB + B_idt
        return RK, RB


if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # test seq-conv: fused conv must reproduce the two-step forward
    x = torch.randn(1, 3, 5, 5).to(device)
    conv = SeqConv3x3('conv1x1-conv3x3', 3, 3, 2).to(device)
    y0 = conv(x)
    RK, RB = conv.rep_params()
    y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
    print(y0 - y1)

    # test ecb: training-mode output vs fused conv
    x = torch.randn(1, 3, 5, 5).to(device) * 200
    ecb = ECB(3, 3, 2, act_type='linear', with_idt=True).to(device)
    y0 = ecb(x)

    RK, RB = ecb.rep_params()
    y1 = F.conv2d(input=x, weight=RK, bias=RB, stride=1, padding=1)
    print(y0 - y1)
+ + """ + super().__init__(model_dir, *args, **kwargs) + self.model_dir = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + # network architecture + self.module_nums = self.config.model.model_args.module_nums + self.channel_nums = self.config.model.model_args.channel_nums + self.scale = self.config.model.model_args.scale + self.colors = self.config.model.model_args.colors + self.with_idt = self.config.model.model_args.with_idt + self.act_type = self.config.model.model_args.act_type + + backbone = [] + backbone += [ + ECB(self.colors, + self.channel_nums, + depth_multiplier=2.0, + act_type=self.act_type, + with_idt=self.with_idt) + ] + for i in range(self.module_nums): + backbone += [ + ECB(self.channel_nums, + self.channel_nums, + depth_multiplier=2.0, + act_type=self.act_type, + with_idt=self.with_idt) + ] + backbone += [ + ECB(self.channel_nums, + self.colors * self.scale * self.scale, + depth_multiplier=2.0, + act_type='linear', + with_idt=self.with_idt) + ] + + self.backbone = nn.Sequential(*backbone) + self.upsampler = nn.PixelShuffle(self.scale) + + self.interp = nn.Upsample(scale_factor=self.scale, mode='nearest') + + def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: + output = self.backbone(input) + output = self.upsampler(output) + self.interp(input) + return {'outputs': output} + + def forward(self, inputs: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + inputs (Tensor): the preprocessed data + + Returns: + Dict[str, Tensor]: results + """ + return self._inference_forward(**inputs) + + @classmethod + def _instantiate(cls, **kwargs): + model_file = kwargs.get('am_model_name', ModelFile.TORCH_MODEL_FILE) + model_dir = kwargs['model_dir'] + ckpt_path = os.path.join(model_dir, model_file) + logger.info(f'loading model from {ckpt_path}') + model_dir = kwargs.pop('model_dir') + model = cls(model_dir=model_dir, **kwargs) + ckpt_path = 
os.path.join(model_dir, model_file) + model.load_state_dict(torch.load(ckpt_path, map_location='cpu')) + return model diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index bd9af367..c69dd37c 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -100,6 +100,7 @@ if TYPE_CHECKING: from .ddpm_semantic_segmentation_pipeline import DDPMImageSemanticSegmentationPipeline from .image_inpainting_sdv2_pipeline import ImageInpaintingSDV2Pipeline from .image_quality_assessment_mos_pipeline import ImageQualityAssessmentMosPipeline + from .mobile_image_super_resolution_pipeline import MobileImageSuperResolutionPipeline from .nerf_recon_acc_pipeline import NeRFReconAccPipeline else: @@ -245,6 +246,9 @@ else: 'image_quality_assessment_mos_pipeline': [ 'ImageQualityAssessmentMosPipeline' ], + 'mobile_image_super_resolution_pipeline': [ + 'MobileImageSuperResolutionPipeline' + ], 'nerf_recon_acc_pipeline': ['NeRFReconAccPipeline'], } diff --git a/modelscope/pipelines/cv/mobile_image_super_resolution_pipeline.py b/modelscope/pipelines/cv/mobile_image_super_resolution_pipeline.py new file mode 100644 index 00000000..4ff98c8f --- /dev/null +++ b/modelscope/pipelines/cv/mobile_image_super_resolution_pipeline.py @@ -0,0 +1,104 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
@PIPELINES.register_module(
    Tasks.image_super_resolution,
    module_name=Pipelines.mobile_image_super_resolution)
class MobileImageSuperResolutionPipeline(Pipeline):

    def __init__(self,
                 model: Union[ECBSRModel, str],
                 preprocessor=None,
                 **kwargs):
        """The inference pipeline for mobile image super-resolution (ECBSR).

        Args:
            model (`str` or `Model` or module instance): A model instance or a model local dir
                or a model id in the model hub.
            preprocessor (`Preprocessor`, `optional`): A Preprocessor instance.
            kwargs (dict, `optional`):
                Extra kwargs passed into the preprocessor's constructor.

        Example:
            >>> import cv2
            >>> from modelscope.outputs import OutputKeys
            >>> from modelscope.pipelines import pipeline
            >>> from modelscope.utils.constant import Tasks
            >>> sr = pipeline(Tasks.image_super_resolution, model='damo/cv_ecbsr_image-super-resolution_mobile')
            >>> result = sr('data/test/images/butterfly_lrx2_y.png')
            >>> cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.model.eval()
        self.config = self.model.config

        # y_input: model consumes only the Y (luma) channel of YCbCr;
        # tensor_max_value == 1.0 means the network expects inputs in [0, 1]
        self.y_input = self.model.config.model.y_input
        self.tensor_max_value = self.model.config.model.tensor_max_value

        # NOTE(review): only the input tensor is moved to this device in
        # preprocess(); the model placement is assumed to be handled by the
        # Pipeline base class — confirm to avoid a device mismatch.
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        logger.info('load image mobile sr model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load the image and convert it to a normalized NCHW float tensor."""
        img = LoadImage.convert_to_img(input)

        if self.y_input:
            # keep only the luminance channel: HxWx3 -> HxWx1
            img = sc.rgb2ycbcr(img)[:, :, 0:1]

        # HWC -> CHW, contiguous for torch.from_numpy
        img = np.ascontiguousarray(img.transpose((2, 0, 1)))
        img = torch.from_numpy(img).to(self._device)

        img = img.float()
        if self.tensor_max_value == 1.0:
            img /= 255.0

        result = {'input': img.unsqueeze(0)}
        return result

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model in eval mode without tracking gradients."""
        self.model.eval()
        with torch.no_grad():
            output = self.model(input)  # output Tensor

        return {'output_tensor': output['outputs']}

    def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the output tensor back to a uint8 HWC numpy image."""
        output = input['output_tensor'].squeeze(0)
        if self.tensor_max_value == 1.0:
            # undo the [0, 1] normalization applied in preprocess()
            output *= 255.0

        output = output.clamp(0, 255).to(torch.uint8)
        output = output.permute(1, 2, 0).contiguous().cpu().numpy()

        return {OutputKeys.OUTPUT_IMG: output}
class MobileImageSuperResolutionTest(unittest.TestCase,
                                     DemoCompatibilityCheck):
    """Pipeline tests for the mobile (ECBSR) image super-resolution model."""

    def setUp(self) -> None:
        # model id, test image and task shared by all test cases below
        self.model_id = 'damo/cv_ecbsr_image-super-resolution_mobile'
        self.img = 'data/test/images/butterfly_lrx2_y.png'
        self.task = Tasks.image_super_resolution

    def pipeline_inference(self, pipeline: Pipeline, img: str):
        # run the pipeline and dump the up-scaled image to disk
        result = pipeline(img)
        if result is None:
            return
        cv2.imwrite('result.png', result[OutputKeys.OUTPUT_IMG])
        print(f'Output written to {osp.abspath("result.png")}')

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_modelhub(self):
        sr_pipeline = pipeline(
            Tasks.image_super_resolution, model=self.model_id)
        self.pipeline_inference(sr_pipeline, self.img)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_modelhub_default_model(self):
        sr_pipeline = pipeline(Tasks.image_super_resolution)
        self.pipeline_inference(sr_pipeline, self.img)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()