Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
add video deinterlace model
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11508847

* add video deinterlace model
* resolve conflict for video deinterlace
* fix CR problems
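In practice the new task is driven through the standard pipeline API; a minimal usage sketch based on the docstring example and the test added in this commit (the input path comes from the new test data):

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

pipeline_ins = pipeline(
    'video-deinterlace', model='damo/cv_unet_video-deinterlace')
result = pipeline_ins('data/test/videos/video_deinterlace_test.mp4')
print(result[OutputKeys.OUTPUT_VIDEO])  # path to the deinterlaced video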
data/test/videos/video_deinterlace_test.mp4 (new file, 3 lines)
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9941ac4a5dd0d9eea5d33ce0009da34d0c93c64ed062479e6c8efb4788e8ef7c
size 522972
modelscope/metainfo.py

@@ -79,6 +79,7 @@ class Models(object):
     video_human_matting = 'video-human-matting'
     video_frame_interpolation = 'video-frame-interpolation'
     video_object_segmentation = 'video-object-segmentation'
+    video_deinterlace = 'video-deinterlace'
     quadtree_attention_image_matching = 'quadtree-attention-image-matching'
     vision_middleware = 'vision-middleware'
     video_stabilization = 'video-stabilization'
@@ -329,6 +330,7 @@ class Pipelines(object):
     vision_middleware_multi_task = 'vision-middleware-multi-task'
     video_frame_interpolation = 'video-frame-interpolation'
     video_object_segmentation = 'video-object-segmentation'
+    video_deinterlace = 'video-deinterlace'
     image_matching = 'image-matching'
     video_stabilization = 'video-stabilization'
     video_super_resolution = 'realbasicvsr-video-super-resolution'
modelscope/pipelines/builder.py

@@ -674,6 +676,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
     Tasks.video_frame_interpolation:
     (Pipelines.video_frame_interpolation,
      'damo/cv_raft_video-frame-interpolation'),
+    Tasks.video_deinterlace: (Pipelines.video_deinterlace,
+                              'damo/cv_unet_video-deinterlace'),
     Tasks.human_wholebody_keypoint:
     (Pipelines.human_wholebody_keypoint,
      'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
modelscope/models/cv/__init__.py

@@ -19,10 +19,10 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
                referring_video_object_segmentation,
                robust_image_classification, salient_detection,
                shop_segmentation, stream_yolo, super_resolution,
-               video_frame_interpolation, video_object_segmentation,
-               video_panoptic_segmentation, video_single_object_tracking,
-               video_stabilization, video_summarization,
-               video_super_resolution, virual_tryon, vision_middleware,
-               vop_retrieval)
+               video_deinterlace, video_frame_interpolation,
+               video_object_segmentation, video_panoptic_segmentation,
+               video_single_object_tracking, video_stabilization,
+               video_summarization, video_super_resolution, virual_tryon,
+               vision_middleware, vop_retrieval)

 # yapf: enable
modelscope/models/cv/video_deinterlace/UNet_for_video_deinterlace.py (new file, 89 lines)

@@ -0,0 +1,89 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Dict, Union

import torch.cuda
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel

from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.video_deinterlace.deinterlace_arch import \
    DeinterlaceNet
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()
__all__ = ['UNetForVideoDeinterlace']


def convert(param):
    # Strip the 'module.' prefix that (Distributed)DataParallel adds to
    # checkpoint keys, keeping only the wrapped entries.
    return {
        k.replace('module.', ''): v
        for k, v in param.items() if 'module.' in k
    }


@MODELS.register_module(
    Tasks.video_deinterlace, module_name=Models.video_deinterlace)
class UNetForVideoDeinterlace(TorchModel):

    def __init__(self, model_dir: str, *args, **kwargs):
        """Initialize the video deinterlace model from the `model_dir` path.

        Args:
            model_dir (str): the model path.
        """
        super().__init__(model_dir, *args, **kwargs)

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')

        self.model_dir = model_dir

        frenet_path = os.path.join(model_dir, 'deinterlace_fre.pth')
        enhnet_path = os.path.join(model_dir, 'deinterlace_mf.pth')

        self.model = DeinterlaceNet()
        self._load_pretrained(frenet_path, enhnet_path)

    def _load_pretrained(self, frenet_path, enhnet_path):
        state_dict_frenet = torch.load(frenet_path, map_location=self._device)
        state_dict_enhnet = torch.load(enhnet_path, map_location=self._device)

        self.model.frenet.load_state_dict(state_dict_frenet, strict=True)
        self.model.enhnet.load_state_dict(state_dict_enhnet, strict=True)
        logger.info('load model done.')

    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
        return {'output': self.model(input)}

    def _evaluate_postprocess(self, input: Tensor,
                              target: Tensor) -> Dict[str, list]:
        preds = self.model(input)
        del input
        torch.cuda.empty_cache()
        return {'pred': preds, 'target': target}

    def forward(self, inputs: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
        """Return the results of the model.

        Args:
            inputs (Dict[str, Tensor]): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
        """
        if 'target' in inputs:
            return self._evaluate_postprocess(**inputs)
        else:
            return self._inference_forward(**inputs)
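The `convert` helper above is not called anywhere in this file (presumably kept for loading older DataParallel checkpoints); a toy illustration of what it does, with a hypothetical state dict:

ckpt = {'module.frenet.inconv.weight': 0, 'module.enhnet.down.bias': 1}
print(convert(ckpt))
# -> {'frenet.inconv.weight': 0, 'enhnet.down.bias': 1}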
modelscope/models/cv/video_deinterlace/__init__.py (new file, 21 lines)
@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    from .deinterlace_arch import DeinterlaceNet

else:
    _import_structure = {'deinterlace_arch': ['DeinterlaceNet']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
modelscope/models/cv/video_deinterlace/deinterlace_arch.py (new file, 27 lines)
@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn

from modelscope.models.cv.video_deinterlace.models.enh import DeinterlaceEnh
from modelscope.models.cv.video_deinterlace.models.fre import DeinterlaceFre


class DeinterlaceNet(nn.Module):

    def __init__(self):
        super(DeinterlaceNet, self).__init__()
        self.frenet = DeinterlaceFre()
        self.enhnet = DeinterlaceEnh()

    def forward(self, frames):
        self.frenet.eval()
        self.enhnet.eval()
        with torch.no_grad():
            frame1, frame2, frame3 = frames

            F1_out = self.frenet(frame1)
            F2_out = self.frenet(frame2)
            F3_out = self.frenet(frame3)

            out = self.enhnet([F1_out, F2_out, F3_out])

        return out
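A minimal shape check of the two-stage design (frenet deinterlaces each frame, enhnet fuses the triplet). This sketch assumes random weights, toy sizes, and a CUDA device, since warp() in models/utils.py below hard-codes .cuda():

import torch
from modelscope.models.cv.video_deinterlace.deinterlace_arch import DeinterlaceNet

net = DeinterlaceNet().cuda()  # random weights; shape checking only
frames = [torch.rand(1, 3, 64, 64).cuda() for _ in range(3)]
print(net(frames).shape)  # torch.Size([1, 3, 64, 64])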
modelscope/models/cv/video_deinterlace/models/archs.py (new file, 97 lines)
@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => LeakyReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        return self.double_conv(x)


class TripleConv(nn.Module):
    """(convolution => LeakyReLU) * 3"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.triple_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True))

    def forward(self, x):
        return self.triple_conv(x)


class DownConv(nn.Module):
    """Downscaling with avgpool then double/triple conv"""

    def __init__(self, in_channels, out_channels, num_conv=2):
        super().__init__()
        if num_conv == 2:
            self.pool_conv = nn.Sequential(
                nn.AvgPool2d(2), DoubleConv(in_channels, out_channels))
        else:
            self.pool_conv = nn.Sequential(
                nn.AvgPool2d(2), TripleConv(in_channels, out_channels))

    def forward(self, x):
        return self.pool_conv(x)


class UpCatConv(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=False):
        super().__init__()

        if bilinear:
            self.up = nn.Upsample(
                scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(
                scale_factor=2, mode='nearest', align_corners=None)
            self.conv = DoubleConv(in_channels, out_channels)

        self.subpixel = nn.PixelShuffle(2)

    def interpolate(self, x):
        # Replicate the input to 4x channels, then let PixelShuffle(2)
        # trade the extra channels for a 2x larger spatial size.
        tensor_temp = x
        for i in range(3):
            tensor_temp = torch.cat((tensor_temp, x), 1)
        x = tensor_temp
        x = self.subpixel(x)
        return x

    def forward(self, x1, x2):
        x1 = self.interpolate(x1)

        # Pad to the skip tensor's size before concatenation.
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(
            x1,
            [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])

        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
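A quick shape walk-through of UpCatConv.interpolate under assumed toy sizes: three concatenations give 4*C channels, and PixelShuffle(2) converts [N, 4C, H, W] into [N, C, 2H, 2W]:

import torch
x = torch.rand(1, 4, 8, 8)
up = UpCatConv(8, 4)  # in_channels covers the later concat with the skip tensor
print(up.interpolate(x).shape)  # torch.Size([1, 4, 16, 16])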
modelscope/models/cv/video_deinterlace/models/deep_fourier_upsampling.py (new file, 47 lines)

@@ -0,0 +1,47 @@
# The implementation is adopted from Deep Fourier Upsampling,
# made publicly available at https://github.com/manman1995/Deep-Fourier-Upsampling
import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F


class freup_Periodicpadding(nn.Module):

    def __init__(self, channels):
        super(freup_Periodicpadding, self).__init__()

        self.amp_fuse = nn.Sequential(
            nn.Conv2d(channels, channels, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=False),
            nn.Conv2d(channels, channels, 1, 1, 0))
        self.pha_fuse = nn.Sequential(
            nn.Conv2d(channels, channels, 1, 1, 0),
            nn.LeakyReLU(0.1, inplace=False),
            nn.Conv2d(channels, channels, 1, 1, 0))

        self.post = nn.Conv2d(channels, channels, 1, 1, 0)

    def forward(self, x):

        N, C, H, W = x.shape

        fft_x = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
        mag_x = torch.abs(fft_x)
        pha_x = torch.angle(fft_x).detach()

        Mag = self.amp_fuse(mag_x)
        Pha = self.pha_fuse(pha_x)

        # Periodic padding in the frequency domain: tiling the filtered
        # amplitude/phase maps 2x2 doubles the size after the inverse FFT.
        amp_fuse = Mag.repeat(1, 1, 2, 2)
        pha_fuse = Pha.repeat(1, 1, 2, 2)

        real = amp_fuse * torch.cos(pha_fuse)
        imag = amp_fuse * torch.sin(pha_fuse)
        out = torch.complex(real, imag)

        output = torch.fft.ifft(torch.fft.ifft(out, dim=0), dim=1)
        output = torch.abs(output)

        return self.post(output)
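A toy shape check of the module (channel count and sizes are arbitrary assumptions):

import torch
m = freup_Periodicpadding(channels=8)
x = torch.rand(1, 8, 16, 16)
print(m(x).shape)  # torch.Size([1, 8, 32, 32])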
modelscope/models/cv/video_deinterlace/models/enh.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.cv.video_deinterlace.models.archs import (DoubleConv,
                                                                 DownConv,
                                                                 TripleConv,
                                                                 UpCatConv)
from modelscope.models.cv.video_deinterlace.models.utils import warp


class DeinterlaceEnh(nn.Module):
    """Defines a U-Net video enhancement module.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 3.
        num_feat (int): Channel number of base intermediate features. Default: 64.
    """

    def __init__(self, num_in_ch=3, num_feat=64):
        super(DeinterlaceEnh, self).__init__()
        self.channel = num_in_ch

        # extra convolutions
        self.inconv2_1 = DoubleConv(num_in_ch * 3, 48)
        # downsample
        self.down2_0 = DownConv(48, 80)
        self.down2_1 = DownConv(80, 144)
        self.down2_2 = DownConv(144, 256)
        self.down2_3 = DownConv(256, 448, num_conv=3)
        # upsample
        self.up2_3 = UpCatConv(704, 256)
        self.up2_2 = UpCatConv(400, 144)
        self.up2_1 = UpCatConv(224, 80)
        self.up2_0 = UpCatConv(128, 48)
        # extra convolutions
        self.outconv2_1 = nn.Conv2d(48, num_in_ch, 3, 1, 1, bias=False)

        self.offset_conv1 = nn.Sequential(
            nn.Conv2d(num_in_ch * 2, num_feat, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(num_feat, num_feat, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(num_feat, num_in_ch * 2, kernel_size=3, padding=1))
        self.offset_conv2 = nn.Sequential(
            nn.Conv2d(num_in_ch * 2, num_feat, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(num_feat, num_feat, kernel_size=3, padding=1),
            nn.LeakyReLU(negative_slope=0.2, inplace=True),
            nn.Conv2d(num_feat, num_in_ch * 2, kernel_size=3, padding=1))

    def forward(self, frames):
        # Align both neighbouring frames to the centre frame with
        # predicted per-pixel offsets, then fuse the triplet in a U-Net.
        frame1, frame2, frame3 = frames
        flow1 = self.offset_conv1(torch.cat([frame1, frame2], 1))
        warp1 = warp(frame1, flow1)
        flow3 = self.offset_conv2(torch.cat([frame3, frame2], 1))
        warp3 = warp(frame3, flow3)
        x2_0 = self.inconv2_1(torch.cat((warp1, frame2, warp3), 1))
        # downsample
        x2_1 = self.down2_0(x2_0)  # 1/2
        x2_2 = self.down2_1(x2_1)  # 1/4
        x2_3 = self.down2_2(x2_2)  # 1/8
        x2_4 = self.down2_3(x2_3)  # 1/16
        # upsample
        x2_5 = self.up2_3(x2_4, x2_3)  # 1/8
        x2_5 = self.up2_2(x2_5, x2_2)  # 1/4
        x2_5 = self.up2_1(x2_5, x2_1)  # 1/2
        x2_5 = self.up2_0(x2_5, x2_0)  # 1
        out_final = self.outconv2_1(x2_5)
        return out_final
modelscope/models/cv/video_deinterlace/models/fre.py (new file, 93 lines)
@@ -0,0 +1,93 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F

from modelscope.models.cv.video_deinterlace.models.archs import (DoubleConv,
                                                                 DownConv,
                                                                 TripleConv,
                                                                 UpCatConv)
from modelscope.models.cv.video_deinterlace.models.deep_fourier_upsampling import \
    freup_Periodicpadding


class DeinterlaceFre(nn.Module):

    def __init__(self, num_in_ch=3, num_out_ch=3, ngf=64):
        """Defines a video deinterlace module.

        It takes a [b, c, h, w] tensor with range [0, 1] as an input frame
        and outputs a [b, c, h, w] tensor with range [0, 1] without
        interlacing artifacts.

        Args:
            num_in_ch (int): Channel number of inputs. Default: 3.
            num_out_ch (int): Channel number of outputs. Default: 3.
            ngf (int): Channel number of features. Default: 64.
        """
        super(DeinterlaceFre, self).__init__()

        self.inconv = DoubleConv(num_in_ch, 48)
        self.down_0 = DownConv(48, 80)
        self.down_1 = DownConv(80, 144)

        self.opfre_0 = freup_Periodicpadding(80)
        self.opfre_1 = freup_Periodicpadding(144)

        self.conv_up1 = nn.Conv2d(80, ngf, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(144, 80, 3, 1, 1)

        self.conv_hr = nn.Conv2d(ngf, ngf, 3, 1, 1)
        self.conv_last = nn.Conv2d(ngf, num_out_ch, 3, 1, 1)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        self.enh_inconv = DoubleConv(num_in_ch + num_out_ch, 48)
        # downsample
        self.enh_down_0 = DownConv(48, 80)
        self.enh_down_1 = DownConv(80, 144)
        self.enh_down_2 = DownConv(144, 256)
        self.enh_down_3 = DownConv(256, 448, num_conv=3)
        # upsample
        self.enh_up_3 = UpCatConv(704, 256)
        self.enh_up_2 = UpCatConv(400, 144)
        self.enh_up_1 = UpCatConv(224, 80)
        self.enh_up_0 = UpCatConv(128, 48)
        # extra convolutions
        self.enh_outconv = nn.Conv2d(48, num_out_ch, 3, 1, 1, bias=False)

    def interpolate(self, feat, x2, fn):
        # Combine Fourier upsampling (fn) with nearest-neighbour
        # upsampling, padding both to the reference tensor's size.
        x1f = fn(feat)
        x1 = F.interpolate(feat, scale_factor=2, mode='nearest')
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1f = F.pad(
            x1f,
            [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        x1 = F.pad(
            x1,
            [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
        return x1 + x1f

    def forward(self, x):
        x1_0 = self.inconv(x)
        # downsample
        x1_1 = self.down_0(x1_0)  # 1/2
        x1_2 = self.down_1(x1_1)  # 1/4

        feat = self.lrelu(
            self.conv_up2(self.interpolate(x1_2, x1_1, self.opfre_1)))
        feat = self.lrelu(
            self.conv_up1(self.interpolate(feat, x1_0, self.opfre_0)))
        x_new = self.conv_last(self.lrelu(self.conv_hr(feat)))

        x2_0 = self.enh_inconv(torch.cat([x_new, x], 1))
        # downsample
        x2_1 = self.enh_down_0(x2_0)  # 1/2
        x2_2 = self.enh_down_1(x2_1)  # 1/4
        x2_3 = self.enh_down_2(x2_2)  # 1/8
        x2_4 = self.enh_down_3(x2_3)  # 1/16
        # upsample
        x2_5 = self.enh_up_3(x2_4, x2_3)  # 1/8
        x2_5 = self.enh_up_2(x2_5, x2_2)  # 1/4
        x2_5 = self.enh_up_1(x2_5, x2_1)  # 1/2
        x2_5 = self.enh_up_0(x2_5, x2_0)  # 1
        out = self.enh_outconv(x2_5)
        return out
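DeinterlaceFre has no CUDA-only dependency, so a quick random-weight shape check (toy sizes assumed; H and W must be divisible by 16) runs on CPU:

import torch
net = DeinterlaceFre()
x = torch.rand(1, 3, 64, 64)  # one frame, range [0, 1]
print(net(x).shape)  # torch.Size([1, 3, 64, 64])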
modelscope/models/cv/video_deinterlace/models/utils.py (new file, 107 lines)
@@ -0,0 +1,107 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch


def warp(im, flow):
    # Warp image `im` with per-pixel offsets `flow` ([dx; dy] stacked on
    # the channel axis) via bilinear sampling. Note that the helpers call
    # .cuda() directly, so this implementation requires a CUDA device.

    def _repeat(x, n_repeats):
        rep = torch.ones((1, n_repeats), dtype=torch.int32)
        x = torch.matmul(x.view(-1, 1).int(), rep)
        return x.view(-1)

    def _repeat2(x, n_repeats):
        rep = torch.ones((n_repeats, 1), dtype=torch.int32)
        x = torch.matmul(rep, x.view(1, -1).int())
        return x.view(-1)

    def _interpolate(im, x, y):
        num_batch, channels, height, width = im.shape

        x = x.float()
        y = y.float()
        max_y = height - 1
        max_x = width - 1

        x = _repeat2(torch.arange(0, width),
                     height * num_batch).float().cuda() + x * 64
        y = _repeat2(_repeat(torch.arange(0, height), width),
                     num_batch).float().cuda() + y * 64

        # do sampling
        x0 = (torch.floor(x.cpu())).int()
        x1 = x0 + 1
        y0 = (torch.floor(y.cpu())).int()
        y1 = y0 + 1

        x0 = torch.clamp(x0, 0, max_x)
        x1 = torch.clamp(x1, 0, max_x)
        y0 = torch.clamp(y0, 0, max_y)
        y1 = torch.clamp(y1, 0, max_y)
        dim2 = width
        dim1 = width * height
        base = _repeat(torch.arange(num_batch) * dim1, height * width)
        base_y0 = base + y0 * dim2
        base_y1 = base + y1 * dim2
        idx_a = base_y0 + x0
        idx_b = base_y1 + x0
        idx_c = base_y0 + x1
        idx_d = base_y1 + x1

        # use indices to lookup pixels in the flat image and restore
        im_flat = im.permute(0, 2, 3, 1)
        im_flat = im_flat.reshape((-1, channels)).float()
        Ia = torch.gather(
            im_flat, dim=0, index=torch.unsqueeze(idx_a, 1).long().cuda())
        Ib = torch.gather(im_flat, 0, torch.unsqueeze(idx_b, 1).long().cuda())
        Ic = torch.gather(im_flat, 0, torch.unsqueeze(idx_c, 1).long().cuda())
        Id = torch.gather(im_flat, 0, torch.unsqueeze(idx_d, 1).long().cuda())
        # and finally calculate interpolated values
        x0_f = x0.float().cuda()
        x1_f = x1.float().cuda()
        y0_f = y0.float().cuda()
        y1_f = y1.float().cuda()
        wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
        wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
        wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
        wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
        output = wa * Ia + wb * Ib + wc * Ic + wd * Id
        return output

    def _meshgrid(height, width):
        x_t = torch.matmul(
            torch.ones((height, 1)),
            torch.unsqueeze(torch.linspace(-0.1, 0.1, width),
                            1).permute(1, 0)).cuda()
        y_t = torch.matmul(
            torch.unsqueeze(torch.linspace(-0.1, 0.1, height), 1),
            torch.ones((1, width))).cuda()

        x_t_flat = x_t.reshape((1, -1))
        y_t_flat = y_t.reshape((1, -1))

        ones = torch.ones_like(x_t_flat).cuda()
        grid = torch.cat((x_t_flat, y_t_flat, ones), 0)

        return grid

    def _warp(x_s, y_s, input_dim):
        num_batch, num_channels, height, width = input_dim.shape
        # out_height, out_width = out_size

        x_s_flat = x_s.reshape(-1)
        y_s_flat = y_s.reshape(-1)

        input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat)
        output = input_transformed.reshape(
            (num_batch, num_channels, height, width))
        return output

    n_dims = int(flow.shape[1]) // 2
    dx = flow[:, :n_dims, :, :]
    dy = flow[:, n_dims:, :, :]

    output = torch.cat([
        _warp(dx[:, idx:idx + 1, :, :], dy[:, idx:idx + 1, :, :],
              im[:, idx:idx + 1, :, :]) for idx in range(im.shape[1])
    ], 1)
    return output
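A hedged sanity check (again CUDA-only): with zero flow the bilinear lookup reduces to an identity read of the frame, except at the bottom/right border where the clamped x1/y1 indices zero out the interpolation weights:

import torch
im = torch.rand(1, 1, 8, 8).cuda()
flow = torch.zeros(1, 2, 8, 8).cuda()
out = warp(im, flow)
print(torch.allclose(out[..., :-1, :-1], im[..., :-1, :-1]))  # True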
modelscope/outputs.py

@@ -359,6 +359,7 @@ TASK_OUTPUTS = {
     # {"output_video": "path_to_rendered_video"}
     Tasks.video_frame_interpolation: [OutputKeys.OUTPUT_VIDEO],
     Tasks.video_super_resolution: [OutputKeys.OUTPUT_VIDEO],
+    Tasks.video_deinterlace: [OutputKeys.OUTPUT_VIDEO],
     Tasks.nerf_recon_acc: [OutputKeys.OUTPUT_VIDEO],
     Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO],
modelscope/pipelines/cv/__init__.py

@@ -84,6 +84,7 @@ if TYPE_CHECKING:
     from .image_driving_perception_pipeline import ImageDrivingPerceptionPipeline
     from .vop_retrieval_pipeline import VopRetrievalPipeline
     from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
+    from .video_deinterlace_pipeline import VideoDeinterlacePipeline
     from .image_matching_pipeline import ImageMatchingPipeline
     from .video_stabilization_pipeline import VideoStabilizationPipeline
     from .video_super_resolution_pipeline import VideoSuperResolutionPipeline

@@ -220,6 +221,7 @@ else:
         'video_object_segmentation_pipeline': [
             'VideoObjectSegmentationPipeline'
         ],
+        'video_deinterlace_pipeline': ['VideoDeinterlacePipeline'],
         'image_matching_pipeline': ['ImageMatchingPipeline'],
         'video_stabilization_pipeline': ['VideoStabilizationPipeline'],
         'video_super_resolution_pipeline': ['VideoSuperResolutionPipeline'],
modelscope/pipelines/cv/video_deinterlace_pipeline.py (new file, 186 lines)
@@ -0,0 +1,186 @@
# The implementation here is modified based on RealBasicVSR,
# originally Apache 2.0 License and publicly available at
# https://github.com/ckkelvinchan/RealBasicVSR/blob/master/inference_realbasicvsr.py
import math
import os
import subprocess
import tempfile
from typing import Any, Dict, Optional, Union

import cv2
import numpy as np
import torch
from torchvision.utils import make_grid

from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_deinterlace.UNet_for_video_deinterlace import \
    UNetForVideoDeinterlace
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.cv import VideoReader
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

VIDEO_EXTENSIONS = ('.mp4', '.mov')

logger = get_logger()


def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
    """Convert torch Tensors into image numpy arrays.

    After clamping to (min, max), image values will be normalized to [0, 1].
    For different tensor shapes, this function will have different behaviors:

    1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
        Use `make_grid` to stitch images in the batch dimension, and then
        convert it to a numpy array.
    2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
        Directly change to a numpy array.

    Note that the image channel in input tensors should be RGB order. This
    function will convert it to cv2 convention, i.e., (H x W x C) with BGR
    order.

    Args:
        tensor (Tensor | list[Tensor]): Input tensors.
        out_type (numpy type): Output types. If ``np.uint8``, transform outputs
            to uint8 type with range [0, 255]; otherwise, float type with
            range [0, 1]. Default: ``np.uint8``.
        min_max (tuple): min and max values for clamp.

    Returns:
        (ndarray | list[ndarray]): 3D ndarray of shape (H x W x C) or 2D
            ndarray of shape (H x W).
    """
    condition = torch.is_tensor(tensor) or (isinstance(tensor, list) and all(
        torch.is_tensor(t) for t in tensor))
    if not condition:
        raise TypeError(
            f'tensor or list of tensors expected, got {type(tensor)}')

    if torch.is_tensor(tensor):
        tensor = [tensor]
    result = []
    for _tensor in tensor:
        # Squeeze two times so that:
        # 1. (1, 1, h, w) -> (h, w) or
        # 2. (1, 3, h, w) -> (3, h, w) or
        # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
        _tensor = _tensor.squeeze(0).squeeze(0)
        _tensor = _tensor.float().detach().cpu().clamp_(*min_max)
        _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
        n_dim = _tensor.dim()
        if n_dim == 4:
            img_np = make_grid(
                _tensor, nrow=int(math.sqrt(_tensor.size(0))),
                normalize=False).numpy()
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 3:
            img_np = _tensor.numpy()
            img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
        elif n_dim == 2:
            img_np = _tensor.numpy()
        else:
            raise ValueError('Only support 4D, 3D or 2D tensor. '
                             f'But received with dimension: {n_dim}')
        if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
            img_np = (img_np * 255.0).round()
        img_np = img_np.astype(out_type)
        result.append(img_np)
    result = result[0] if len(result) == 1 else result
    return result
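For instance, a hedged toy call (shapes are arbitrary):

import torch
t = torch.rand(1, 3, 32, 32)  # (N, C, H, W), RGB, range [0, 1]
img = tensor2img(t)
print(img.shape, img.dtype)  # (32, 32, 3) uint8, BGR for cv2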

@PIPELINES.register_module(
    Tasks.video_deinterlace, module_name=Pipelines.video_deinterlace)
class VideoDeinterlacePipeline(Pipeline):

    def __init__(self,
                 model: Union[UNetForVideoDeinterlace, str],
                 preprocessor=None,
                 **kwargs):
        """The inference pipeline for all the video deinterlace sub-tasks.

        Args:
            model (`str` or `Model` or module instance): A model instance or
                a model local dir or a model id in the model hub.
            preprocessor (`Preprocessor`, `optional`): A Preprocessor instance.
            kwargs (dict, `optional`):
                Extra kwargs passed into the preprocessor's constructor.

        Example:
            >>> from modelscope.pipelines import pipeline
            >>> pipeline_ins = pipeline('video-deinterlace',
                model='damo/cv_unet_video-deinterlace')
            >>> input = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_deinterlace_test.mp4'
            >>> print(pipeline_ins(input)[OutputKeys.OUTPUT_VIDEO])
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')

        self.net = self.model.model
        self.net.to(self._device)
        self.net.eval()

        logger.info('load video deinterlace model done')

    def preprocess(self, input: Input) -> Dict[str, Any]:
        # input is a video file
        video_reader = VideoReader(input)
        inputs = []
        for frame in video_reader:
            # flip the channel axis (BGR -> RGB)
            inputs.append(np.flip(frame, axis=2))
        fps = video_reader.fps

        for i, img in enumerate(inputs):
            img = torch.from_numpy(img / 255.).permute(2, 0, 1).float()
            inputs[i] = img.unsqueeze(0)
        inputs = torch.stack(inputs, dim=1)
        return {'video': inputs, 'fps': fps}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        inputs = input['video'][0]
        frenet = self.net.frenet
        enhnet = self.net.enhnet
        with torch.no_grad():
            outputs = []
            frames = []
            # Slide a 3-frame window of frenet outputs over the video; the
            # first and last frames are duplicated at the boundaries.
            for i in range(0, inputs.size(0)):
                frames.append(frenet(inputs[i:i + 1, ...].to(self._device)))
                if i == 0:
                    frames = [frames[-1]] * 2
                    continue
                outputs.append(enhnet(frames).cpu().unsqueeze(1))
                frames = frames[1:]

            frames.append(frames[-1])
            outputs.append(enhnet(frames).cpu().unsqueeze(1))
            outputs = torch.cat(outputs, dim=1)
        return {'output': outputs, 'fps': input['fps']}
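A pure-Python trace of the sliding-window bookkeeping in forward (a hypothetical helper; indices stand in for frenet outputs, and n >= 2 frames are assumed, as in a real video). Each emitted triple is the window enhnet sees when enhancing one output frame:

def window_trace(n):
    frames, out = [], []
    for i in range(n):
        frames.append(i)
        if i == 0:
            frames = [frames[-1]] * 2
            continue
        out.append(tuple(frames))
        frames = frames[1:]
    frames.append(frames[-1])
    out.append(tuple(frames))
    return out

print(window_trace(4))
# [(0, 0, 1), (0, 1, 2), (1, 2, 3), (2, 3, 3)]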

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        output_video_path = kwargs.get('output_video', None)
        demo_service = kwargs.get('demo_service', False)
        if output_video_path is None:
            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name

        h, w = inputs['output'].shape[-2:]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(output_video_path, fourcc,
                                       inputs['fps'], (w, h))
        for i in range(0, inputs['output'].size(1)):
            img = tensor2img(inputs['output'][:, i, :, :, :])
            video_writer.write(img.astype(np.uint8))
        video_writer.release()

        if demo_service:
            assert os.system(
                'ffmpeg -version'
            ) == 0, 'ffmpeg is not installed correctly, please refer to https://trac.ffmpeg.org/wiki/CompilationGuide.'
            output_video_path_for_web = output_video_path[:-4] + '_web.mp4'
            convert_cmd = f'ffmpeg -i {output_video_path} -vcodec h264 -crf 5 {output_video_path_for_web}'
            subprocess.call(convert_cmd, shell=True)
            return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}
modelscope/utils/constant.py

@@ -113,6 +113,7 @@ class CVTasks(object):
     video_frame_interpolation = 'video-frame-interpolation'
     video_stabilization = 'video-stabilization'
     video_super_resolution = 'video-super-resolution'
+    video_deinterlace = 'video-deinterlace'
     video_colorization = 'video-colorization'

     # reid and tracking
tests/pipelines/test_video_deinterlace.py (new file, 61 lines)
@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
import unittest

from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.cv import VideoDeinterlacePipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level


class VideoDeinterlaceTest(unittest.TestCase, DemoCompatibilityCheck):

    def setUp(self) -> None:
        self.task = Tasks.video_deinterlace
        self.model_id = 'damo/cv_unet_video-deinterlace'
        self.test_video = 'data/test/videos/video_deinterlace_test.mp4'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        cache_path = snapshot_download(self.model_id)
        pipeline = VideoDeinterlacePipeline(cache_path)
        pipeline.group_key = self.task
        out_video_path = pipeline(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_model_from_pretrained(self):
        cache_path = Model.from_pretrained(self.model_id)
        pipeline = VideoDeinterlacePipeline(cache_path)
        pipeline.group_key = self.task
        out_video_path = pipeline(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        pipeline_ins = pipeline(
            task=Tasks.video_deinterlace, model=self.model_id)
        out_video_path = pipeline_ins(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        pipeline_ins = pipeline(task=Tasks.video_deinterlace)
        out_video_path = pipeline_ins(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()
tests/run_config.yaml

@@ -55,6 +55,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
  - test_image_deblur_trainer.py
  - test_image_quality_assessment_mos.py
  - test_image_restoration.py
+ - test_video_deinterlace.py
  - test_image_inpainting_sdv2.py

 envs: