mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
[to #42322933][730 model] add image denoise
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9491966
This commit is contained in:
3
data/test/images/noisy-demo-0.png
Normal file
3
data/test/images/noisy-demo-0.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:403034182fa320130dae0d75b92e85e0850771378e674d65455c403a4958e29c
|
||||
size 170716
|
||||
3
data/test/images/noisy-demo-1.png
Normal file
3
data/test/images/noisy-demo-1.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:ebd5dacad9b75ef80f87eb785d7818421dadb63257da0e91e123766c5913f855
|
||||
size 149971
|
||||
@@ -10,6 +10,7 @@ class Models(object):
|
||||
Model name should only contain model info but not task info.
|
||||
"""
|
||||
# vision models
|
||||
nafnet = 'nafnet'
|
||||
csrnet = 'csrnet'
|
||||
cascade_mask_rcnn_swin = 'cascade_mask_rcnn_swin'
|
||||
|
||||
@@ -59,6 +60,7 @@ class Pipelines(object):
|
||||
"""
|
||||
# vision tasks
|
||||
image_matting = 'unet-image-matting'
|
||||
image_denoise = 'nafnet-image-denoise'
|
||||
person_image_cartoon = 'unet-person-image-cartoon'
|
||||
ocr_detection = 'resnet18-ocr-detection'
|
||||
action_recognition = 'TAdaConv_action-recognition'
|
||||
@@ -132,6 +134,7 @@ class Preprocessors(object):
|
||||
|
||||
# cv preprocessor
|
||||
load_image = 'load-image'
|
||||
image_denoie_preprocessor = 'image-denoise-preprocessor'
|
||||
image_color_enhance_preprocessor = 'image-color-enhance-preprocessor'
|
||||
image_instance_segmentation_preprocessor = 'image-instance-segmentation-preprocessor'
|
||||
|
||||
@@ -167,6 +170,9 @@ class Metrics(object):
|
||||
# accuracy
|
||||
accuracy = 'accuracy'
|
||||
|
||||
# metrics for image denoise task
|
||||
image_denoise_metric = 'image-denoise-metric'
|
||||
|
||||
# metric for image instance segmentation task
|
||||
image_ins_seg_coco_metric = 'image-ins-seg-coco-metric'
|
||||
# metrics for sequence classification task
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from .base import Metric
|
||||
from .builder import METRICS, build_metric, task_default_metrics
|
||||
from .image_color_enhance_metric import ImageColorEnhanceMetric
|
||||
from .image_denoise_metric import ImageDenoiseMetric
|
||||
from .image_instance_segmentation_metric import \
|
||||
ImageInstanceSegmentationCOCOMetric
|
||||
from .sequence_classification_metric import SequenceClassificationMetric
|
||||
|
||||
@@ -22,6 +22,7 @@ task_default_metrics = {
|
||||
Tasks.sentence_similarity: [Metrics.seq_cls_metric],
|
||||
Tasks.sentiment_classification: [Metrics.seq_cls_metric],
|
||||
Tasks.text_generation: [Metrics.text_gen_metric],
|
||||
Tasks.image_denoise: [Metrics.image_denoise_metric],
|
||||
Tasks.image_color_enhance: [Metrics.image_color_enhance_metric]
|
||||
}
|
||||
|
||||
|
||||
45
modelscope/metrics/image_denoise_metric.py
Normal file
45
modelscope/metrics/image_denoise_metric.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from typing import Dict
|
||||
|
||||
import numpy as np
|
||||
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.utils.registry import default_group
|
||||
from modelscope.utils.tensor_utils import (torch_nested_detach,
|
||||
torch_nested_numpify)
|
||||
from .base import Metric
|
||||
from .builder import METRICS, MetricKeys
|
||||
|
||||
|
||||
@METRICS.register_module(
    group_key=default_group, module_name=Metrics.image_denoise_metric)
class ImageDenoiseMetric(Metric):
    """Accumulate denoised / ground-truth image pairs and report PSNR and SSIM.

    `add` stores one entry per call (read from the model outputs under the
    keys `pred` and `target`); `evaluate` averages the two scores over every
    stored image pair.
    """
    pred_name = 'pred'
    label_name = 'target'

    def __init__(self):
        # One entry per `add` call; each entry is indexed along its first
        # axis to obtain the individual images.
        self.preds = []
        self.labels = []

    def add(self, outputs: Dict, inputs: Dict):
        """Store the predictions and targets of one evaluation step.

        Args:
            outputs (Dict): model outputs; must contain `pred` and `target`.
            inputs (Dict): model inputs; not used by this metric.
        """
        ground_truths = outputs[ImageDenoiseMetric.label_name]
        eval_results = outputs[ImageDenoiseMetric.pred_name]
        self.preds.append(
            torch_nested_numpify(torch_nested_detach(eval_results)))
        self.labels.append(
            torch_nested_numpify(torch_nested_detach(ground_truths)))

    def evaluate(self):
        """Return the mean PSNR and mean SSIM over all stored image pairs.

        Returns:
            dict: `{MetricKeys.PSNR: ..., MetricKeys.SSIM: ...}`.
        """
        psnr_list, ssim_list = [], []
        for batch_preds, batch_labels in zip(self.preds, self.labels):
            # Score every image of the step, not only the first one, so the
            # result does not depend on the evaluation batch size.
            for pred, label in zip(batch_preds, batch_labels):
                psnr_list.append(
                    peak_signal_noise_ratio(label, pred, data_range=255))
                # NOTE(review): newer scikit-image releases replace
                # `multichannel` with `channel_axis` - confirm against the
                # pinned version before upgrading.
                ssim_list.append(
                    structural_similarity(
                        label, pred, multichannel=True, data_range=255))
        return {
            MetricKeys.PSNR: np.mean(psnr_list),
            MetricKeys.SSIM: np.mean(ssim_list)
        }
|
||||
@@ -22,6 +22,7 @@ except ModuleNotFoundError as e:
|
||||
|
||||
try:
|
||||
from .multi_modal import OfaForImageCaptioning
|
||||
from .cv import NAFNetForImageDenoise
|
||||
from .nlp import (BertForMaskedLM, BertForSequenceClassification,
|
||||
SbertForNLI, SbertForSentenceSimilarity,
|
||||
SbertForSentimentClassification,
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from .image_color_enhance.image_color_enhance import ImageColorEnhance
|
||||
from .image_denoise.nafnet_for_image_denoise import * # noqa F403
|
||||
|
||||
0
modelscope/models/cv/image_denoise/__init__.py
Normal file
0
modelscope/models/cv/image_denoise/__init__.py
Normal file
233
modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py
Normal file
233
modelscope/models/cv/image_denoise/nafnet/NAFNet_arch.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .arch_util import LayerNorm2d
|
||||
|
||||
|
||||
class SimpleGate(nn.Module):
    """Split the input in two along dim 1 and multiply the halves."""

    def forward(self, x):
        first_half, second_half = torch.chunk(x, 2, dim=1)
        return first_half * second_half
|
||||
|
||||
|
||||
class NAFBlock(nn.Module):
    """Residual block made of a gated convolution branch and a gated
    feed-forward branch.

    Each branch is normalised with `LayerNorm2d`, gated with `SimpleGate`
    and added back to its input scaled by a learnable per-channel factor
    (`beta` for the first branch, `gamma` for the second), both initialised
    to zero.

    Args:
        c (int): number of input and output channels.
        DW_Expand (int): channel multiplier of the first branch.
        FFN_Expand (int): channel multiplier of the second branch.
        drop_out_rate (float): dropout probability; dropout is replaced by
            an identity when it is not positive.
    """

    def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
        super().__init__()

        def pointwise(n_in, n_out):
            # 1x1 convolution shared by every channel-mixing layer below.
            return nn.Conv2d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=1,
                padding=0,
                stride=1,
                groups=1,
                bias=True)

        dw_channel = c * DW_Expand
        gated_channel = dw_channel // 2  # SimpleGate halves the channels

        # Layers are created in the same order as before so that parameter
        # ordering and random initialisation are unchanged.
        self.conv1 = pointwise(c, dw_channel)
        self.conv2 = nn.Conv2d(
            in_channels=dw_channel,
            out_channels=dw_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=dw_channel,
            bias=True)
        self.conv3 = pointwise(gated_channel, c)

        # Simplified Channel Attention
        self.sca = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            pointwise(gated_channel, gated_channel),
        )

        # SimpleGate
        self.sg = SimpleGate()

        ffn_channel = FFN_Expand * c
        self.conv4 = pointwise(c, ffn_channel)
        self.conv5 = pointwise(ffn_channel // 2, c)

        self.norm1 = LayerNorm2d(c)
        self.norm2 = LayerNorm2d(c)

        if drop_out_rate > 0.:
            self.dropout1 = nn.Dropout(drop_out_rate)
            self.dropout2 = nn.Dropout(drop_out_rate)
        else:
            self.dropout1 = nn.Identity()
            self.dropout2 = nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(
            torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        # First branch: norm -> 1x1 conv -> depthwise 3x3 conv -> gate ->
        # channel attention -> 1x1 conv -> dropout.
        branch = self.conv2(self.conv1(self.norm1(inp)))
        branch = self.sg(branch)
        branch = branch * self.sca(branch)
        branch = self.dropout1(self.conv3(branch))
        y = inp + branch * self.beta

        # Second branch: norm -> 1x1 conv -> gate -> 1x1 conv -> dropout.
        ffn = self.conv5(self.sg(self.conv4(self.norm2(y))))
        ffn = self.dropout2(ffn)
        return y + ffn * self.gamma
|
||||
|
||||
|
||||
class NAFNet(nn.Module):
    """U-shaped encoder/decoder network built from `NAFBlock` stages.

    Args:
        img_channel (int): channels of the input and output image.
        width (int): channels produced by the first convolution.
        middle_blk_num (int): number of blocks at the bottleneck.
        enc_blk_nums (sequence of int): blocks per encoder stage; each stage
            is followed by a stride-2 convolution that doubles the channels.
        dec_blk_nums (sequence of int): blocks per decoder stage; each stage
            is preceded by a 1x1 convolution plus pixel shuffle that halves
            the channels and doubles the resolution.
    """

    def __init__(self,
                 img_channel=3,
                 width=16,
                 middle_blk_num=1,
                 enc_blk_nums=(),
                 dec_blk_nums=()):
        # Defaults are immutable tuples rather than shared list objects.
        super().__init__()

        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True)
        self.ending = nn.Conv2d(
            in_channels=width,
            out_channels=img_channel,
            kernel_size=3,
            padding=1,
            stride=1,
            groups=1,
            bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        # Placeholder, replaced below once the bottleneck width is known; it
        # is kept so the attribute is registered at the same position.
        self.middle_blks = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))
            self.downs.append(nn.Conv2d(chan, 2 * chan, 2, 2))
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan) for _ in range(middle_blk_num)])

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)))
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(*[NAFBlock(chan) for _ in range(num)]))

        # Spatial sizes must be a multiple of this so that every stride-2
        # downsampling divides evenly.
        self.padder_size = 2**len(self.encoders)

    def forward(self, inp):
        """Run the network on a 4-D batch and return a tensor of the same
        shape (the input is added back as a global residual)."""
        B, C, H, W = inp.shape
        inp = self.check_image_size(inp)

        x = self.intro(inp)

        encs = []
        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        # Skip connections are consumed deepest-first.
        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        # Drop the rows/columns added by `check_image_size`.
        return x[:, :, :H, :W]

    def check_image_size(self, x):
        """Pad `x` on the bottom and right so that its height and width are
        multiples of `self.padder_size`."""
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size
                     - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size
                     - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x
|
||||
|
||||
|
||||
class PSNRLoss(nn.Module):
    """Loss equal to `10 * log10(MSE + 1e-8)` averaged over the batch.

    Args:
        loss_weight (float): factor applied to the final value.
        reduction (str): only `'mean'` is accepted.
        toY (bool): when True, both inputs are first reduced to a single
            channel using the weights in `self.coef` plus an offset of 16,
            then divided by 255.
    """

    def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
        super(PSNRLoss, self).__init__()
        assert reduction == 'mean'
        self.loss_weight = loss_weight
        self.scale = 10 / np.log(10)  # converts natural log to 10*log10
        self.toY = toY
        self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
        self.first = True

    def forward(self, pred, target):
        """Return the weighted log-MSE between two 4-D batches."""
        assert len(pred.size()) == 4
        if self.toY:
            # Follow the device of the current batch on every call instead of
            # only the first one, so the loss keeps working if inputs later
            # arrive on a different device.
            if self.coef.device != pred.device:
                self.coef = self.coef.to(pred.device)
            self.first = False

            pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
            target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.

            pred, target = pred / 255., target / 255.

        # Per-sample MSE -> log -> mean over the batch.
        return self.loss_weight * self.scale * torch.log((
            (pred - target)**2).mean(dim=(1, 2, 3)) + 1e-8).mean()
|
||||
42
modelscope/models/cv/image_denoise/nafnet/arch_util.py
Normal file
42
modelscope/models/cv/image_denoise/nafnet/arch_util.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class LayerNormFunction(torch.autograd.Function):
    """Normalisation over dim 1 of a 4-D tensor with a hand-written
    backward pass."""

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        """Normalise `x` across dim 1, then scale and shift per channel.

        Args:
            x: 4-D input tensor.
            weight: 1-D scale with one entry per channel.
            bias: 1-D shift with one entry per channel.
            eps: value added to the variance before the square root.
        """
        ctx.eps = eps
        _, C, _, _ = x.size()  # also rejects inputs that are not 4-D
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        # The normalised (pre-affine) tensor is what backward needs.
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        """Return gradients for `(x, weight, bias, eps)`; `eps` gets None."""
        eps = ctx.eps

        _, C, _, _ = grad_output.size()
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated.
        y, var, weight = ctx.saved_tensors
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        grad_weight = (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0)
        grad_bias = grad_output.sum(dim=3).sum(dim=2).sum(dim=0)
        return gx, grad_weight, grad_bias, None
|
||||
|
||||
|
||||
class LayerNorm2d(nn.Module):
    """Module wrapper around `LayerNormFunction` with learnable per-channel
    `weight` (initialised to ones) and `bias` (initialised to zeros).

    Args:
        channels (int): length of `weight` and `bias`.
        eps (float): value forwarded to `LayerNormFunction`.
    """

    def __init__(self, channels, eps=1e-6):
        super().__init__()
        # Assigning an nn.Parameter registers it under the attribute name.
        self.weight = nn.Parameter(torch.ones(channels))
        self.bias = nn.Parameter(torch.zeros(channels))
        self.eps = eps

    def forward(self, x):
        scale, shift = self.weight, self.bias
        return LayerNormFunction.apply(x, scale, shift, self.eps)
|
||||
119
modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py
Normal file
119
modelscope/models/cv/image_denoise/nafnet_for_image_denoise.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .nafnet.NAFNet_arch import NAFNet, PSNRLoss
|
||||
|
||||
logger = get_logger()
|
||||
__all__ = ['NAFNetForImageDenoise']
|
||||
|
||||
|
||||
@MODELS.register_module(Tasks.image_denoise, module_name=Models.nafnet)
class NAFNetForImageDenoise(TorchModel):
    """NAFNet image-denoise model with train, evaluate and inference paths."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the image denoise model from the `model_dir` path.

        Args:
            model_dir (str): the model path.

        """
        super().__init__(model_dir, *args, **kwargs)
        self.model_dir = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        model_path = os.path.join(model_dir, ModelFile.TORCH_MODEL_FILE)
        self.model = NAFNet(**self.config.model.network_g)
        self.loss = PSNRLoss()

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')

        self.model = self.model.to(self._device)
        self.model = self._load_pretrained(self.model, model_path)

        if self.training:
            self.model.train()
        else:
            self.model.eval()

    def _load_pretrained(self,
                         net,
                         load_path,
                         strict=True,
                         param_key='params'):
        """Load a checkpoint file into `net` and return `net`.

        Args:
            net: network to load into; a (Distributed)DataParallel wrapper
                is unwrapped first.
            load_path (str): path of the checkpoint file.
            strict (bool): forwarded to `load_state_dict`.
            param_key (str | None): key of the checkpoint dict that holds
                the weights; the whole checkpoint is used when the key is
                None or absent.
        """
        if isinstance(net, (DataParallel, DistributedDataParallel)):
            net = net.module
        load_net = torch.load(
            load_path, map_location=lambda storage, loc: storage)
        if param_key is not None:
            if param_key not in load_net and 'params' in load_net:
                param_key = 'params'
                logger.info(
                    f'Loading: {param_key} does not exist, use params.')
            if param_key in load_net:
                load_net = load_net[param_key]
        logger.info(
            f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].'
        )
        # remove unnecessary 'module.'
        # Iterate over a snapshot of the items: the dict is modified in the
        # loop, and a shallow snapshot avoids deep-copying every tensor.
        for k, v in list(load_net.items()):
            if k.startswith('module.'):
                load_net[k[7:]] = v
                load_net.pop(k)
        net.load_state_dict(load_net, strict=strict)
        logger.info('load model done.')
        return net

    @staticmethod
    def _to_uint8_images(batch: Tensor) -> list:
        """Convert an (N, C, H, W) batch into a list of HWC uint8 arrays."""
        # Clamp before scaling: values outside [0, 1] would otherwise wrap
        # around in the uint8 cast. This mirrors `_inference_forward`.
        scaled = batch.data.clamp(0, 1) * 255.
        return [
            img.permute(1, 2, 0).cpu().numpy().astype(np.uint8)
            for img in scaled
        ]

    def _train_forward(self, input: Tensor,
                       target: Tensor) -> Dict[str, Tensor]:
        preds = self.model(input)
        return {'loss': self.loss(preds, target)}

    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
        return {'outputs': self.model(input).clamp(0, 1)}

    def _evaluate_postprocess(self, input: Tensor,
                              target: Tensor) -> Dict[str, list]:
        """Run the model and return per-image uint8 predictions and targets
        under the keys `pred` and `target`."""
        preds = self.model(input)
        return {
            'pred': self._to_uint8_images(preds),
            'target': self._to_uint8_images(target)
        }

    def forward(self, inputs: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
        """return the result by the model

        Args:
            inputs (Tensor): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
        """
        # Build a new mapping instead of overwriting the caller's dict, and
        # only move values that are tensors.
        inputs = {
            key: value.to(self._device) if torch.is_tensor(value) else value
            for key, value in inputs.items()
        }
        if self.training:
            return self._train_forward(**inputs)
        elif 'target' in inputs:
            return self._evaluate_postprocess(**inputs)
        else:
            return self._inference_forward(**inputs)
|
||||
152
modelscope/msdatasets/image_denoise_data/data_utils.py
Normal file
152
modelscope/msdatasets/image_denoise_data/data_utils.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
import os
|
||||
from os import path as osp
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .transforms import mod_crop
|
||||
|
||||
|
||||
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Convert HWC numpy image(s) to CHW torch tensor(s).

    Args:
        imgs (list[ndarray] | ndarray): Input image or list of images.
        bgr2rgb (bool): Swap BGR to RGB for 3-channel images.
        float32 (bool): Cast the resulting tensor with `.float()`.

    Returns:
        list[tensor] | tensor: a list when a list was passed, otherwise a
            single tensor.
    """

    def _convert(array):
        swap_channels = array.shape[2] == 3 and bgr2rgb
        if swap_channels:
            array = cv2.cvtColor(array, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(array.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if not isinstance(imgs, list):
        return _convert(imgs)
    return [_convert(item) for item in imgs]
|
||||
|
||||
|
||||
def scandir(dir_path, keyword=None, recursive=False, full_path=False):
    """Scan a directory to find the interested files.

    Files whose name starts with '.' are skipped.

    Args:
        dir_path (str): Path of the directory.
        keyword (str | tuple(str), optional): Substring(s) that the returned
            path must contain; with a tuple, any one of them is enough.
            Default: None (every file matches).
        recursive (bool, optional): If set to True, recursively scan the
            directory. Default: False.
        full_path (bool, optional): If set to True, include the dir_path.
            Default: False.

    Returns:
        A generator for all the interested files, as paths relative to
        `dir_path` unless `full_path` is set.

    Raises:
        TypeError: if `keyword` is neither None, a string nor a tuple.
    """

    if (keyword is not None) and not isinstance(keyword, (str, tuple)):
        raise TypeError('"keyword" must be a string or tuple of strings')

    root = dir_path
    # Normalise to a tuple so a single string and a tuple share one code path.
    keywords = (keyword, ) if isinstance(keyword, str) else keyword

    def _scandir(current_dir):
        # The context manager closes the directory handle when the scan ends.
        with os.scandir(current_dir) as entries:
            for entry in entries:
                if entry.is_file():
                    if entry.name.startswith('.'):
                        # Hidden file: skip it; it must never be treated as
                        # a directory to descend into.
                        continue
                    if full_path:
                        return_path = entry.path
                    else:
                        return_path = osp.relpath(entry.path, root)
                    if keywords is None or any(k in return_path
                                               for k in keywords):
                        yield return_path
                elif recursive and entry.is_dir():
                    yield from _scandir(entry.path)

    return _scandir(dir_path)
|
||||
|
||||
|
||||
def padding(img_lq, img_gt, gt_size):
    """Pad both images on the bottom/right with reflection so that height
    and width are at least `gt_size`.

    The amount of padding is derived from `img_lq` and applied to both
    images. When no padding is needed the inputs are returned unchanged.
    """
    height, width, _ = img_lq.shape
    pad_bottom = max(0, gt_size - height)
    pad_right = max(0, gt_size - width)

    if pad_bottom or pad_right:
        img_lq = cv2.copyMakeBorder(img_lq, 0, pad_bottom, 0, pad_right,
                                    cv2.BORDER_REFLECT)
        img_gt = cv2.copyMakeBorder(img_gt, 0, pad_bottom, 0, pad_right,
                                    cv2.BORDER_REFLECT)
    return img_lq, img_gt
|
||||
|
||||
|
||||
def read_img_seq(path, require_mod_crop=False, scale=1):
    """Read a sequence of images and stack them into one tensor.

    Args:
        path (list[str] | str): List of image paths, or a folder whose files
            are read in sorted order.
        require_mod_crop (bool): Apply `mod_crop` to each image.
            Default: False.
        scale (int): Scale factor for mod_crop. Default: 1.

    Returns:
        Tensor: size (t, c, h, w), RGB, [0, 1].
    """
    if isinstance(path, list):
        img_paths = path
    else:
        img_paths = sorted(list(scandir(path, full_path=True)))

    frames = []
    for img_path in img_paths:
        # Scale pixel values by 1/255 after casting to float32.
        frames.append(cv2.imread(img_path).astype(np.float32) / 255.)

    if require_mod_crop:
        frames = [mod_crop(frame, scale) for frame in frames]

    tensors = img2tensor(frames, bgr2rgb=True, float32=True)
    return torch.stack(tensors, dim=0)
|
||||
|
||||
|
||||
def paired_paths_from_folder(folders, keys, filename_tmpl):
    """Generate paired paths from folders.

    Ground-truth files are those whose relative path contains 'GT'; the
    matching input file is found by replacing 'GT' with 'NOISY' in that
    relative path.

    Args:
        folders (list[str]): A list of folder path. The order of list should
            be [input_folder, gt_folder].
        keys (list[str]): A list of keys identifying folders. The order should
            be in consistent with folders, e.g., ['lq', 'gt'].
        filename_tmpl (str): Template for each filename. Accepted for
            interface compatibility; it is not used by this function.

    Returns:
        list[dict]: one dict per pair with the keys `'<input_key>_path'` and
            `'<gt_key>_path'`.
    """
    assert len(folders) == 2, (
        'The len of folders should be 2 with [input_folder, gt_folder]. '
        f'But got {len(folders)}')
    assert len(keys) == 2, (
        'The len of keys should be 2 with [input_key, gt_key]. '
        f'But got {len(keys)}')
    input_folder, gt_folder = folders
    input_key, gt_key = keys

    input_paths = list(scandir(input_folder, keyword='NOISY', recursive=True))
    gt_paths = list(scandir(gt_folder, keyword='GT', recursive=True))
    assert len(input_paths) == len(gt_paths), (
        f'{input_key} and {gt_key} datasets have different number of images: '
        f'{len(input_paths)}, {len(gt_paths)}.')

    paths = []
    for gt_rel_path in gt_paths:
        # Derive the input name from the path RELATIVE to the gt folder, so
        # that neither the gt folder itself nor a 'GT' inside a parent
        # directory name leaks into the input path.
        input_rel_path = gt_rel_path.replace('GT', 'NOISY')
        paths.append({
            f'{input_key}_path': os.path.join(input_folder, input_rel_path),
            f'{gt_key}_path': os.path.join(gt_folder, gt_rel_path),
        })
    return paths
|
||||
@@ -0,0 +1,78 @@
|
||||
import os
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from torch.utils import data
|
||||
|
||||
from .data_utils import img2tensor, padding, paired_paths_from_folder
|
||||
from .transforms import augment, paired_random_crop
|
||||
|
||||
|
||||
def default_loader(path):
    """Read an image file unchanged and return it as float32 divided by 255.

    Raises:
        IOError: if OpenCV could not read the file.
    """
    img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise IOError(f'Failed to read image: {path}')
    return img.astype(np.float32) / 255.0
|
||||
|
||||
|
||||
class PairedImageDataset(data.Dataset):
    """Paired image dataset for image restoration.

    Each item is a dict with the low-quality tensor (`input`), the
    ground-truth tensor (`target`) and the two source paths.
    """

    def __init__(self, opt, root, is_train):
        super(PairedImageDataset, self).__init__()
        self.opt = opt
        self.is_train = is_train
        self.gt_folder = os.path.join(root, opt.dataroot_gt)
        self.lq_folder = os.path.join(root, opt.dataroot_lq)

        # Fall back to a pass-through template when none is configured.
        self.filename_tmpl = '{}'
        if opt.filename_tmpl is not None:
            self.filename_tmpl = opt.filename_tmpl

        self.paths = paired_paths_from_folder(
            [self.lq_folder, self.gt_folder], ['lq', 'gt'],
            self.filename_tmpl)

    def __getitem__(self, index):
        scale = self.opt.scale

        # Load gt and lq images. Dimension order: HWC; channel order: BGR;
        # image range: [0, 1], float32.
        record = self.paths[index]
        gt_path = record['gt_path']
        img_gt = default_loader(gt_path)
        lq_path = record['lq_path']
        img_lq = default_loader(lq_path)

        # NOTE: the steps below run for every item; `is_train` is stored but
        # not consulted here.
        gt_size = self.opt.gt_size

        # padding
        img_gt, img_lq = padding(img_gt, img_lq, gt_size)

        # random crop
        img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale)

        # flip, rotation
        img_gt, img_lq = augment([img_gt, img_lq], self.opt.use_flip,
                                 self.opt.use_rot)

        # BGR to RGB, HWC to CHW, numpy to tensor
        img_gt, img_lq = img2tensor([img_gt, img_lq],
                                    bgr2rgb=True,
                                    float32=True)

        sample = {
            'input': img_lq,
            'target': img_gt,
            'input_path': lq_path,
            'target_path': gt_path
        }
        return sample

    def __len__(self):
        return len(self.paths)

    def to_torch_dataset(
        self,
        columns: Union[str, List[str]] = None,
        preprocessors: Union[Callable, List[Callable]] = None,
        **format_kwargs,
    ):
        # Already a torch dataset: the arguments are accepted and ignored.
        return self
|
||||
96
modelscope/msdatasets/image_denoise_data/transforms.py
Normal file
96
modelscope/msdatasets/image_denoise_data/transforms.py
Normal file
@@ -0,0 +1,96 @@
|
||||
# Modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/data/transforms.py
|
||||
|
||||
import random
|
||||
|
||||
|
||||
def mod_crop(img, scale):
    """Crop a copy of the image so height and width are multiples of `scale`.

    Args:
        img (ndarray): Input image with 2 or 3 dimensions.
        scale (int): Scale factor.

    Returns:
        ndarray: Result image.

    Raises:
        ValueError: if `img` is neither 2-D nor 3-D.
    """
    img = img.copy()
    if img.ndim not in (2, 3):
        raise ValueError(f'Wrong img ndim: {img.ndim}.')
    height, width = img.shape[:2]
    kept_h = height - height % scale
    kept_w = width - width % scale
    return img[:kept_h, :kept_w, ...]
|
||||
|
||||
|
||||
def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale):
    """Paired random crop.

    Picks one random window on the LQ image(s) and cuts the corresponding
    window (scaled by `scale`) from the GT image(s).

    Args:
        img_gts (list[ndarray] | ndarray): GT images.
        img_lqs (list[ndarray] | ndarray): LQ images.
        gt_patch_size (int): GT patch size.
        scale (int): Scale factor.

    Returns:
        list[ndarray] | ndarray: GT images and LQ images. A one-element
            list is unwrapped to the single array.
    """
    gt_list = img_gts if isinstance(img_gts, list) else [img_gts]
    lq_list = img_lqs if isinstance(img_lqs, list) else [img_lqs]

    h_lq, w_lq, _ = lq_list[0].shape
    # Unpacked only so that a GT image that is not 3-D is rejected here.
    _h_gt, _w_gt, _ = gt_list[0].shape
    lq_patch_size = gt_patch_size // scale

    # randomly choose top and left coordinates for lq patch
    top = random.randint(0, h_lq - lq_patch_size)
    left = random.randint(0, w_lq - lq_patch_size)

    # crop lq patch
    lq_rows = slice(top, top + lq_patch_size)
    lq_cols = slice(left, left + lq_patch_size)
    lq_crops = [image[lq_rows, lq_cols, ...] for image in lq_list]

    # crop corresponding gt patch
    top_gt = int(top * scale)
    left_gt = int(left * scale)
    gt_rows = slice(top_gt, top_gt + gt_patch_size)
    gt_cols = slice(left_gt, left_gt + gt_patch_size)
    gt_crops = [image[gt_rows, gt_cols, ...] for image in gt_list]

    gt_result = gt_crops[0] if len(gt_crops) == 1 else gt_crops
    lq_result = lq_crops[0] if len(lq_crops) == 1 else lq_crops
    return gt_result, lq_result
|
||||
|
||||
|
||||
def augment(imgs, hflip=True, rotation=True, vflip=False):
    """Randomly flip and/or transpose image(s).

    The random decisions are drawn once and applied identically to every
    image in the list. A one-element list is unwrapped on return.
    """
    # Draw order matters for reproducibility: hflip, then vflip, then rot90.
    do_hflip = hflip and random.random() < 0.5
    do_vflip = vflip
    if vflip or rotation:
        do_vflip = random.random() < 0.5
    do_rot90 = rotation and random.random() < 0.5

    def _transform(image):
        if do_hflip:  # horizontal
            image = image[:, ::-1, :].copy()
        if do_vflip:  # vertical
            image = image[::-1, :, :].copy()
        if do_rot90:
            image = image.transpose(1, 0, 2)
        return image

    batch = imgs if isinstance(imgs, list) else [imgs]
    results = [_transform(image) for image in batch]
    return results[0] if len(results) == 1 else results
|
||||
@@ -74,6 +74,7 @@ TASK_OUTPUTS = {
|
||||
Tasks.image_editing: [OutputKeys.OUTPUT_IMG],
|
||||
Tasks.image_matting: [OutputKeys.OUTPUT_IMG],
|
||||
Tasks.image_generation: [OutputKeys.OUTPUT_IMG],
|
||||
Tasks.image_denoise: [OutputKeys.OUTPUT_IMG],
|
||||
Tasks.image_colorization: [OutputKeys.OUTPUT_IMG],
|
||||
Tasks.face_image_generation: [OutputKeys.OUTPUT_IMG],
|
||||
Tasks.image_super_resolution: [OutputKeys.OUTPUT_IMG],
|
||||
|
||||
@@ -35,6 +35,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
), # TODO: revise back after passing the pr
|
||||
Tasks.image_matting: (Pipelines.image_matting,
|
||||
'damo/cv_unet_image-matting'),
|
||||
Tasks.image_denoise: (Pipelines.image_denoise,
|
||||
'damo/cv_nafnet_image-denoise_sidd'),
|
||||
Tasks.text_classification: (Pipelines.sentiment_analysis,
|
||||
'damo/bert-base-sst2'),
|
||||
Tasks.text_generation: (Pipelines.text_generation,
|
||||
|
||||
@@ -6,6 +6,7 @@ try:
|
||||
from .action_recognition_pipeline import ActionRecognitionPipeline
|
||||
from .animal_recog_pipeline import AnimalRecogPipeline
|
||||
from .cmdssl_video_embedding_pipleline import CMDSSLVideoEmbeddingPipeline
|
||||
from .image_denoise_pipeline import ImageDenoisePipeline
|
||||
from .image_color_enhance_pipeline import ImageColorEnhancePipeline
|
||||
from .virtual_tryon_pipeline import VirtualTryonPipeline
|
||||
from .image_colorization_pipeline import ImageColorizationPipeline
|
||||
|
||||
111
modelscope/pipelines/cv/image_denoise_pipeline.py
Normal file
111
modelscope/pipelines/cv/image_denoise_pipeline.py
Normal file
@@ -0,0 +1,111 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision import transforms
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models import Model
|
||||
from modelscope.models.cv import NAFNetForImageDenoise
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input
|
||||
from modelscope.preprocessors import ImageDenoisePreprocessor, LoadImage
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from ..base import Pipeline
|
||||
from ..builder import PIPELINES
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = ['ImageDenoisePipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.image_denoise, module_name=Pipelines.image_denoise)
class ImageDenoisePipeline(Pipeline):
    """Image denoise pipeline that restores an image tile by tile."""

    def __init__(self,
                 model: Union[NAFNetForImageDenoise, str],
                 preprocessor: Optional[ImageDenoisePreprocessor] = None,
                 **kwargs):
        """
        use `model` and `preprocessor` to create a cv image denoise pipeline for prediction
        Args:
            model: a `NAFNetForImageDenoise` instance, or any value accepted
                by `Model.from_pretrained` (e.g. a model id on modelscope
                hub or a local model directory).
            preprocessor: optional preprocessor forwarded to the base
                `Pipeline` constructor.
        """
        model = model if isinstance(
            model, NAFNetForImageDenoise) else Model.from_pretrained(model)
        model.eval()
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        self.config = model.config

        # Inputs are moved to this device in `preprocess`.
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.model = model
        logger.info('load image denoise model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        """Load the input image and convert it to a batched tensor.

        Args:
            input: anything `LoadImage.convert_to_img` can turn into an image.

        Returns:
            Dict[str, Any]: `{'img': tensor}` where the tensor has a leading
                batch dimension of 1 and lives on the pipeline device.
        """
        img = LoadImage.convert_to_img(input)
        test_transforms = transforms.Compose([transforms.ToTensor()])
        img = test_transforms(img)
        result = {'img': img.unsqueeze(0).to(self._device)}
        return result

    def crop_process(self, input):
        """Run the model over overlapping tiles and stitch the results.

        The image is split into a grid of roughly 512-pixel tiles. Each tile
        is extended by up to `overlap` pixels on every side before being
        passed to the model, and only the un-extended centre of the model
        output is written back.

        Args:
            input: tensor whose last two dimensions are height and width
                (the pipeline passes a [1, C, H, W] tensor).

        Returns:
            A tensor with the same shape, dtype and device as `input`.
        """
        output = torch.zeros_like(input)  # [1, C, H, W]
        # determine crop_h and crop_w
        ih, iw = input.shape[-2:]
        crop_rows, crop_cols = max(ih // 512, 1), max(iw // 512, 1)
        overlap = 16

        step_h, step_w = ih // crop_rows, iw // crop_cols
        for y in range(crop_rows):
            crop_y = step_h * y
            # The last row/column absorbs the remainder of the division.
            crop_h = step_h if y < crop_rows - 1 else ih - crop_y
            top = max(0, crop_y - overlap)
            bottom = min(crop_y + crop_h + overlap, ih)
            # Offset of the tile's own pixels inside the padded crop.
            h_start = crop_y - top
            h_end = h_start + crop_h

            for x in range(crop_cols):
                crop_x = step_w * x
                crop_w = step_w if x < crop_cols - 1 else iw - crop_x
                left = max(0, crop_x - overlap)
                right = min(crop_x + crop_w + overlap, iw)
                w_start = crop_x - left
                w_end = w_start + crop_w

                crop_frames = input[:, :, top:bottom,
                                    left:right].contiguous()
                restored = self.model._inference_forward(
                    crop_frames)['outputs']
                output[:, :, crop_y:crop_y + crop_h,
                       crop_x:crop_x + crop_w] = restored[:, :,
                                                          h_start:h_end,
                                                          w_start:w_end]
        return output

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Denoise the preprocessed image without tracking gradients.

        Args:
            input: dict holding the image tensor under the key `'img'`.

        Returns:
            Dict[str, Any]: `{'output_tensor': tensor}` with the restored
                image produced by `crop_process`.
        """
        # Inference only: make sure the model is in eval mode.
        self.model.eval()
        with torch.no_grad():
            output = self.crop_process(input['img'])  # output Tensor

        return {'output_tensor': output}

    def postprocess(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Convert the restored tensor to an HxWxC uint8 image array.

        Args:
            input: dict holding the restored tensor under `'output_tensor'`.

        Returns:
            Dict[str, Any]: `{OutputKeys.OUTPUT_IMG: ndarray}`.
        """
        # Clamp before the uint8 cast: values outside [0, 1] would otherwise
        # wrap around and show up as speckle artifacts.
        restored = input['output_tensor'].squeeze(0).clamp(0, 1)
        output_img = (restored * 255).cpu().permute(
            1, 2, 0).numpy().astype('uint8')
        return {OutputKeys.OUTPUT_IMG: output_img}
|
||||
@@ -21,6 +21,7 @@ try:
|
||||
from .space.dialog_state_tracking_preprocessor import * # noqa F403
|
||||
from .image import ImageColorEnhanceFinetunePreprocessor
|
||||
from .image import ImageInstanceSegmentationPreprocessor
|
||||
from .image import ImageDenoisePreprocessor
|
||||
except ModuleNotFoundError as e:
|
||||
if str(e) == "No module named 'tensorflow'":
|
||||
print(TENSORFLOW_IMPORT_ERROR.format('tts'))
|
||||
|
||||
@@ -138,6 +138,31 @@ class ImageColorEnhanceFinetunePreprocessor(Preprocessor):
|
||||
return data
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
    Fields.cv, module_name=Preprocessors.image_denoie_preprocessor)
class ImageDenoisePreprocessor(Preprocessor):
    """Pass-through preprocessor registered for the image denoise task."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """Store the model directory on the instance.

        Args:
            model_dir (str): model path
        """
        super().__init__(*args, **kwargs)
        self.model_dir = model_dir

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Return the raw input data as is.

        Args:
            data (Dict[str, Any]): the raw input data

        Returns:
            Dict[str, Any]: the same object that was passed in
        """
        return data
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
Fields.cv,
|
||||
module_name=Preprocessors.image_instance_segmentation_preprocessor)
|
||||
|
||||
@@ -24,6 +24,7 @@ class CVTasks(object):
|
||||
image_editing = 'image-editing'
|
||||
image_generation = 'image-generation'
|
||||
image_matting = 'image-matting'
|
||||
image_denoise = 'image-denoise'
|
||||
ocr_detection = 'ocr-detection'
|
||||
action_recognition = 'action-recognition'
|
||||
video_embedding = 'video-embedding'
|
||||
|
||||
59
tests/pipelines/test_image_denoise.py
Normal file
59
tests/pipelines/test_image_denoise.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import unittest
|
||||
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import ImageDenoisePipeline, pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class ImageDenoiseTest(unittest.TestCase):
    """Smoke tests for the image denoise pipeline."""

    model_id = 'damo/cv_nafnet_image-denoise_sidd'
    demo_image_path = 'data/test/images/noisy-demo-1.png'

    def _denoise_and_report(self, denoiser):
        """Run `denoiser` on the demo image and print the output shape."""
        result = denoiser(input=self.demo_image_path)
        picture = Image.fromarray(result[OutputKeys.OUTPUT_IMG])
        width, height = picture.size
        print('pipeline: the shape of output_img is {}x{}'.format(
            height, width))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        local_dir = snapshot_download(self.model_id)
        self._denoise_and_report(ImageDenoisePipeline(local_dir))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        loaded_model = Model.from_pretrained(self.model_id)
        self._denoise_and_report(
            pipeline(task=Tasks.image_denoise, model=loaded_model))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_name(self):
        self._denoise_and_report(
            pipeline(task=Tasks.image_denoise, model=self.model_id))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        self._denoise_and_report(pipeline(task=Tasks.image_denoise))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
74
tests/trainers/test_image_denoise_trainer.py
Normal file
74
tests/trainers/test_image_denoise_trainer.py
Normal file
@@ -0,0 +1,74 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import NAFNetForImageDenoise
|
||||
from modelscope.msdatasets.image_denoise_data.image_denoise_dataset import \
|
||||
PairedImageDataset
|
||||
from modelscope.trainers import build_trainer
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class ImageDenoiseTrainerTest(unittest.TestCase):
    """Training tests for the image denoise model."""

    def setUp(self):
        """Create a work dir, download the model and build the datasets."""
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        # mkdtemp creates the directory and leaves its lifetime to us;
        # TemporaryDirectory().name would delete it again as soon as the
        # discarded manager object is finalized.
        self.tmp_dir = tempfile.mkdtemp()
        # tearDown is skipped when setUp fails, so register the removal here
        # in case the download or dataset construction below raises.
        self.addCleanup(shutil.rmtree, self.tmp_dir, ignore_errors=True)

        self.model_id = 'damo/cv_nafnet_image-denoise_sidd'
        self.cache_path = snapshot_download(self.model_id)
        self.config = Config.from_file(
            os.path.join(self.cache_path, ModelFile.CONFIGURATION))
        self.dataset_train = PairedImageDataset(
            self.config.dataset, self.cache_path, is_train=True)
        self.dataset_val = PairedImageDataset(
            self.config.dataset, self.cache_path, is_train=False)

    def tearDown(self):
        """Remove the work dir created in `setUp`."""
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    def _assert_training_artifacts(self, trainer):
        """Check that the json log and two epoch checkpoints were written."""
        results_files = os.listdir(self.tmp_dir)
        self.assertIn(f'{trainer.timestamp}.log.json', results_files)
        for i in range(2):
            self.assertIn(f'epoch_{i+1}.pth', results_files)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        self._assert_training_artifacts(trainer)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_trainer_with_model_and_args(self):
        model = NAFNetForImageDenoise.from_pretrained(self.cache_path)
        kwargs = dict(
            cfg_file=os.path.join(self.cache_path, ModelFile.CONFIGURATION),
            model=model,
            train_dataset=self.dataset_train,
            eval_dataset=self.dataset_val,
            max_epochs=2,
            work_dir=self.tmp_dir)
        trainer = build_trainer(default_args=kwargs)
        trainer.train()
        self._assert_training_artifacts(trainer)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user