add video deinterlace model

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11508847


* add video deinterlace model

* resolve conflict for video deinterlace

* fix CR problems
This commit is contained in:
ljh263654
2023-02-10 02:15:49 +00:00
committed by wenmeng.zwm
parent c0a92403c8
commit 29e47e5030
18 changed files with 816 additions and 5 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9941ac4a5dd0d9eea5d33ce0009da34d0c93c64ed062479e6c8efb4788e8ef7c
size 522972

View File

@@ -79,6 +79,7 @@ class Models(object):
video_human_matting = 'video-human-matting'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
quadtree_attention_image_matching = 'quadtree-attention-image-matching'
vision_middleware = 'vision-middleware'
video_stabilization = 'video-stabilization'
@@ -329,6 +330,7 @@ class Pipelines(object):
vision_middleware_multi_task = 'vision-middleware-multi-task'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_deinterlace = 'video-deinterlace'
image_matching = 'image-matching'
video_stabilization = 'video-stabilization'
video_super_resolution = 'realbasicvsr-video-super-resolution'
@@ -674,6 +676,8 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.video_frame_interpolation:
(Pipelines.video_frame_interpolation,
'damo/cv_raft_video-frame-interpolation'),
Tasks.video_deinterlace: (Pipelines.video_deinterlace,
'damo/cv_unet_video-deinterlace'),
Tasks.human_wholebody_keypoint:
(Pipelines.human_wholebody_keypoint,
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
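
With the task registered in Models, Pipelines and DEFAULT_MODEL_FOR_PIPELINE above, the deinterlacer becomes reachable through the generic pipeline API. A minimal usage sketch, mirroring the docstring example added in the pipeline file later in this change:

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline

# 'video-deinterlace' resolves to Pipelines.video_deinterlace and the
# default model 'damo/cv_unet_video-deinterlace' via the mapping above.
pipeline_ins = pipeline('video-deinterlace', model='damo/cv_unet_video-deinterlace')
video = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_deinterlace_test.mp4'
print(pipeline_ins(video)[OutputKeys.OUTPUT_VIDEO])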

View File

@@ -19,10 +19,10 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
referring_video_object_segmentation,
robust_image_classification, salient_detection,
shop_segmentation, stream_yolo, super_resolution,
video_frame_interpolation, video_object_segmentation,
video_panoptic_segmentation, video_single_object_tracking,
video_stabilization, video_summarization,
video_super_resolution, virual_tryon, vision_middleware,
vop_retrieval)
video_deinterlace, video_frame_interpolation,
video_object_segmentation, video_panoptic_segmentation,
video_single_object_tracking, video_stabilization,
video_summarization, video_super_resolution, virual_tryon,
vision_middleware, vop_retrieval)
# yapf: enable

View File

@@ -0,0 +1,89 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Dict, Union
import torch.cuda
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
from modelscope.models.cv.video_deinterlace.deinterlace_arch import \
DeinterlaceNet
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['UNetForVideoDeinterlace']
def convert(param):
    # Strip the 'module.' prefix that DataParallel / DistributedDataParallel
    # adds to parameter names when a checkpoint is saved from a wrapped model.
    return {
        k.replace('module.', ''): v
        for k, v in param.items() if 'module.' in k
    }
@MODELS.register_module(
Tasks.video_deinterlace, module_name=Models.video_deinterlace)
class UNetForVideoDeinterlace(TorchModel):
def __init__(self, model_dir: str, *args, **kwargs):
"""initialize the video deinterlace model from the `model_dir` path.
Args:
model_dir (str): the model path.
"""
super().__init__(model_dir, *args, **kwargs)
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.model_dir = model_dir
frenet_path = os.path.join(model_dir, 'deinterlace_fre.pth')
enhnet_path = os.path.join(model_dir, 'deinterlace_mf.pth')
self.model = DeinterlaceNet()
self._load_pretrained(frenet_path, enhnet_path)
def _load_pretrained(self, frenet_path, enhnet_path):
state_dict_frenet = torch.load(frenet_path, map_location=self._device)
state_dict_enhnet = torch.load(enhnet_path, map_location=self._device)
self.model.frenet.load_state_dict(state_dict_frenet, strict=True)
self.model.enhnet.load_state_dict(state_dict_enhnet, strict=True)
logger.info('load model done.')
def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
return {'output': self.model(input)}
def _evaluate_postprocess(self, input: Tensor,
target: Tensor) -> Dict[str, list]:
preds = self.model(input)
del input
torch.cuda.empty_cache()
return {'pred': preds, 'target': target}
def forward(self, inputs: Dict[str,
Tensor]) -> Dict[str, Union[list, Tensor]]:
"""return the result by the model
Args:
inputs (Tensor): the preprocessed data
Returns:
Dict[str, Tensor]: results
"""
if 'target' in inputs:
return self._evaluate_postprocess(**inputs)
else:
return self._inference_forward(**inputs)

View File

@@ -0,0 +1,21 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .deinterlace_arch import \
DeinterlaceNet
else:
_import_structure = {'deinterlace_arch': ['DeinterlaceNet']}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,27 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from modelscope.models.cv.video_deinterlace.models.enh import DeinterlaceEnh
from modelscope.models.cv.video_deinterlace.models.fre import DeinterlaceFre
class DeinterlaceNet(nn.Module):
def __init__(self):
super(DeinterlaceNet, self).__init__()
self.frenet = DeinterlaceFre()
self.enhnet = DeinterlaceEnh()
def forward(self, frames):
self.frenet.eval()
self.enhnet.eval()
with torch.no_grad():
frame1, frame2, frame3 = frames
F1_out = self.frenet(frame1)
F2_out = self.frenet(frame2)
F3_out = self.frenet(frame3)
out = self.enhnet([F1_out, F2_out, F3_out])
return out
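
For orientation: DeinterlaceNet runs the frequency sub-network (frenet) on each of three consecutive frames and fuses the three results with the enhancement sub-network (enhnet). A minimal smoke-test sketch, assuming the package layout in this commit is installed; note the warp helper in models/utils.py hard-codes .cuda() calls, so the full model only runs on a GPU:

import torch
from modelscope.models.cv.video_deinterlace.deinterlace_arch import DeinterlaceNet

if torch.cuda.is_available():
    net = DeinterlaceNet().cuda()
    # three consecutive frames, each [N, C, H, W] in [0, 1]
    frames = [torch.rand(1, 3, 64, 64).cuda() for _ in range(3)]
    out = net(frames)    # forward() already wraps the computation in torch.no_grad()
    print(out.shape)     # expected: torch.Size([1, 3, 64, 64])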

View File

@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
def forward(self, x):
return self.double_conv(x)
class TripleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 3"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.triple_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(mid_channels, mid_channels, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True))
def forward(self, x):
return self.triple_conv(x)
class DownConv(nn.Module):
"""Downscaling with avgpool then double/triple conv"""
def __init__(self, in_channels, out_channels, num_conv=2):
super().__init__()
if num_conv == 2:
self.pool_conv = nn.Sequential(
nn.AvgPool2d(2), DoubleConv(in_channels, out_channels))
else:
self.pool_conv = nn.Sequential(
nn.AvgPool2d(2), TripleConv(in_channels, out_channels))
def forward(self, x):
return self.pool_conv(x)
class UpCatConv(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=False):
super().__init__()
if bilinear:
self.up = nn.Upsample(
scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels)
else:
self.up = nn.Upsample(
scale_factor=2, mode='nearest', align_corners=None)
self.conv = DoubleConv(in_channels, out_channels)
self.subpixel = nn.PixelShuffle(2)
def interpolate(self, x):
tensor_temp = x
for i in range(3):
tensor_temp = torch.cat((tensor_temp, x), 1)
x = tensor_temp
x = self.subpixel(x)
return x
def forward(self, x1, x2):
x1 = self.interpolate(x1)
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(
x1,
[diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x = torch.cat([x2, x1], dim=1)
return self.conv(x)

View File

@@ -0,0 +1,47 @@
# The implementation is adopted from Deep Fourier Upsampling,
# made publicly available at https://github.com/manman1995/Deep-Fourier-Upsampling
import numpy as np
import torch
import torch.fft
import torch.nn as nn
import torch.nn.functional as F
class freup_Periodicpadding(nn.Module):
def __init__(self, channels):
super(freup_Periodicpadding, self).__init__()
self.amp_fuse = nn.Sequential(
nn.Conv2d(channels, channels, 1, 1, 0),
nn.LeakyReLU(0.1, inplace=False),
nn.Conv2d(channels, channels, 1, 1, 0))
self.pha_fuse = nn.Sequential(
nn.Conv2d(channels, channels, 1, 1, 0),
nn.LeakyReLU(0.1, inplace=False),
nn.Conv2d(channels, channels, 1, 1, 0))
self.post = nn.Conv2d(channels, channels, 1, 1, 0)
def forward(self, x):
N, C, H, W = x.shape
fft_x = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1)
mag_x = torch.abs(fft_x)
pha_x = torch.angle(fft_x).detach()
Mag = self.amp_fuse(mag_x)
Pha = self.pha_fuse(pha_x)
amp_fuse = Mag.repeat(1, 1, 2, 2)
pha_fuse = Pha.repeat(1, 1, 2, 2)
real = amp_fuse * torch.cos(pha_fuse)
imag = amp_fuse * torch.sin(pha_fuse)
out = torch.complex(real, imag)
output = torch.fft.ifft(torch.fft.ifft(out, dim=0), dim=1)
output = torch.abs(output)
return self.post(output)
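
A rough CPU sanity check of the block above (a sketch, assuming the package layout in this commit): the module splits the FFT of the input into magnitude and phase, processes each with 1x1 convolutions, tiles the result 2x2, and applies the inverse FFT, so the output has twice the spatial resolution of the input:

import torch
from modelscope.models.cv.video_deinterlace.models.deep_fourier_upsampling import \
    freup_Periodicpadding

up = freup_Periodicpadding(channels=8)
x = torch.rand(1, 8, 16, 16)
y = up(x)
print(y.shape)  # torch.Size([1, 8, 32, 32]) -- 2x spatial upsampling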

View File

@@ -0,0 +1,71 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.video_deinterlace.models.archs import (DoubleConv,
DownConv,
TripleConv,
UpCatConv)
from modelscope.models.cv.video_deinterlace.models.utils import warp
class DeinterlaceEnh(nn.Module):
"""Defines a U-Net video enhancement module
    Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_feat (int): Channel number of base intermediate features. Default: 64.
"""
def __init__(self, num_in_ch=3, num_feat=64):
super(DeinterlaceEnh, self).__init__()
self.channel = num_in_ch
# extra convolutions
self.inconv2_1 = DoubleConv(num_in_ch * 3, 48)
# downsample
self.down2_0 = DownConv(48, 80)
self.down2_1 = DownConv(80, 144)
self.down2_2 = DownConv(144, 256)
self.down2_3 = DownConv(256, 448, num_conv=3)
# upsample
self.up2_3 = UpCatConv(704, 256)
self.up2_2 = UpCatConv(400, 144)
self.up2_1 = UpCatConv(224, 80)
self.up2_0 = UpCatConv(128, 48)
# extra convolutions
self.outconv2_1 = nn.Conv2d(48, num_in_ch, 3, 1, 1, bias=False)
self.offset_conv1 = nn.Sequential(
nn.Conv2d(num_in_ch * 2, num_feat, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(num_feat, num_feat, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(num_feat, num_in_ch * 2, kernel_size=3, padding=1))
self.offset_conv2 = nn.Sequential(
nn.Conv2d(num_in_ch * 2, num_feat, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(num_feat, num_feat, kernel_size=3, padding=1),
nn.LeakyReLU(negative_slope=0.2, inplace=True),
nn.Conv2d(num_feat, num_in_ch * 2, kernel_size=3, padding=1))
def forward(self, frames):
frame1, frame2, frame3 = frames
flow1 = self.offset_conv1(torch.cat([frame1, frame2], 1))
warp1 = warp(frame1, flow1)
flow3 = self.offset_conv2(torch.cat([frame3, frame2], 1))
warp3 = warp(frame3, flow3)
x2_0 = self.inconv2_1(torch.cat((warp1, frame2, warp3), 1))
# downsample
x2_1 = self.down2_0(x2_0) # 1/2
x2_2 = self.down2_1(x2_1) # 1/4
x2_3 = self.down2_2(x2_2) # 1/8
x2_4 = self.down2_3(x2_3) # 1/16
x2_5 = self.up2_3(x2_4, x2_3) # 1/8
x2_5 = self.up2_2(x2_5, x2_2) # 1/4
x2_5 = self.up2_1(x2_5, x2_1) # 1/2
x2_5 = self.up2_0(x2_5, x2_0) # 1
out_final = self.outconv2_1(x2_5)
return out_final
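
A note on the channel bookkeeping in the decoder above: each UpCatConv consumes the concatenation of the upsampled decoder features and the matching encoder skip, so its in_channels is the sum of the two, i.e. 448 + 256 = 704, 256 + 144 = 400, 144 + 80 = 224 and 80 + 48 = 128, which matches the constructor arguments.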

View File

@@ -0,0 +1,93 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.video_deinterlace.models.archs import (DoubleConv,
DownConv,
TripleConv,
UpCatConv)
from modelscope.models.cv.video_deinterlace.models.deep_fourier_upsampling import \
freup_Periodicpadding
class DeinterlaceFre(nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=3, ngf=64):
"""Defines a video deinterlace module.
input a [b,c,h,w] tensor with range [0,1] as frame,
it will output a [b,c,h,w] tensor with range [0,1] whitout interlace.
Args:
num_in_ch (int): Channel number of inputs. Default: 3.
num_out_ch (int): Channel number of outputs. Default: 3.
ngf(int): Channel number of features. Default: 64.
"""
super(DeinterlaceFre, self).__init__()
self.inconv = DoubleConv(num_in_ch, 48)
self.down_0 = DownConv(48, 80)
self.down_1 = DownConv(80, 144)
self.opfre_0 = freup_Periodicpadding(80)
self.opfre_1 = freup_Periodicpadding(144)
self.conv_up1 = nn.Conv2d(80, ngf, 3, 1, 1)
self.conv_up2 = nn.Conv2d(144, 80, 3, 1, 1)
self.conv_hr = nn.Conv2d(ngf, ngf, 3, 1, 1)
self.conv_last = nn.Conv2d(ngf, num_out_ch, 3, 1, 1)
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
self.enh_inconv = DoubleConv(num_in_ch + num_out_ch, 48)
# downsample
self.enh_down_0 = DownConv(48, 80)
self.enh_down_1 = DownConv(80, 144)
self.enh_down_2 = DownConv(144, 256)
self.enh_down_3 = DownConv(256, 448, num_conv=3)
# upsample
self.enh_up_3 = UpCatConv(704, 256)
self.enh_up_2 = UpCatConv(400, 144)
self.enh_up_1 = UpCatConv(224, 80)
self.enh_up_0 = UpCatConv(128, 48)
# extra convolutions
self.enh_outconv = nn.Conv2d(48, num_out_ch, 3, 1, 1, bias=False)
def interpolate(self, feat, x2, fn):
x1f = fn(feat)
x1 = F.interpolate(feat, scale_factor=2, mode='nearest')
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1f = F.pad(
x1f,
[diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
x1 = F.pad(
x1,
[diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
return x1 + x1f
def forward(self, x):
x1_0 = self.inconv(x)
# downsample
x1_1 = self.down_0(x1_0) # 1/2
x1_2 = self.down_1(x1_1) # 1/4
feat = self.lrelu(
self.conv_up2(self.interpolate(x1_2, x1_1, self.opfre_1)))
feat = self.lrelu(
self.conv_up1(self.interpolate(feat, x1_0, self.opfre_0)))
x_new = self.conv_last(self.lrelu(self.conv_hr(feat)))
x2_0 = self.enh_inconv(torch.cat([x_new, x], 1))
# downsample
x2_1 = self.enh_down_0(x2_0) # 1/2
x2_2 = self.enh_down_1(x2_1) # 1/4
x2_3 = self.enh_down_2(x2_2) # 1/8
x2_4 = self.enh_down_3(x2_3) # 1/16
x2_5 = self.enh_up_3(x2_4, x2_3) # 1/8
x2_5 = self.enh_up_2(x2_5, x2_2) # 1/4
x2_5 = self.enh_up_1(x2_5, x2_1) # 1/2
x2_5 = self.enh_up_0(x2_5, x2_0) # 1
out = self.enh_outconv(x2_5)
return out
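
A quick shape check for the frequency branch alone (a sketch, assuming the package layout in this commit; unlike the full DeinterlaceNet, DeinterlaceFre has no CUDA-only dependency, so it runs on CPU):

import torch
from modelscope.models.cv.video_deinterlace.models.fre import DeinterlaceFre

net = DeinterlaceFre().eval()
x = torch.rand(1, 3, 64, 64)   # one frame in [0, 1]
with torch.no_grad():
    y = net(x)
print(y.shape)  # torch.Size([1, 3, 64, 64]) -- resolution is preserved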

View File

@@ -0,0 +1,107 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
def warp(im, flow):
    """Warp each channel of `im` with its matching (dx, dy) pair from `flow`
    using bilinear sampling; flow values are scaled by a factor of 64 pixels.
    """
def _repeat(x, n_repeats):
rep = torch.ones((1, n_repeats), dtype=torch.int32)
x = torch.matmul(x.view(-1, 1).int(), rep)
return x.view(-1)
def _repeat2(x, n_repeats):
rep = torch.ones((n_repeats, 1), dtype=torch.int32)
x = torch.matmul(rep, x.view(1, -1).int())
return x.view(-1)
def _interpolate(im, x, y):
num_batch, channels, height, width = im.shape
x = x.float()
y = y.float()
max_y = height - 1
max_x = width - 1
x = _repeat2(torch.arange(0, width),
height * num_batch).float().cuda() + x * 64
y = _repeat2(_repeat(torch.arange(0, height), width),
num_batch).float().cuda() + y * 64
# do sampling
x0 = (torch.floor(x.cpu())).int()
x1 = x0 + 1
y0 = (torch.floor(y.cpu())).int()
y1 = y0 + 1
x0 = torch.clamp(x0, 0, max_x)
x1 = torch.clamp(x1, 0, max_x)
y0 = torch.clamp(y0, 0, max_y)
y1 = torch.clamp(y1, 0, max_y)
dim2 = width
dim1 = width * height
base = _repeat(torch.arange(num_batch) * dim1, height * width)
base_y0 = base + y0 * dim2
base_y1 = base + y1 * dim2
idx_a = base_y0 + x0
idx_b = base_y1 + x0
idx_c = base_y0 + x1
idx_d = base_y1 + x1
# use indices to lookup pixels in the flat image and restore
im_flat = im.permute(0, 2, 3, 1)
im_flat = im_flat.reshape((-1, channels)).float()
Ia = torch.gather(
im_flat, dim=0, index=torch.unsqueeze(idx_a, 1).long().cuda())
Ib = torch.gather(im_flat, 0, torch.unsqueeze(idx_b, 1).long().cuda())
Ic = torch.gather(im_flat, 0, torch.unsqueeze(idx_c, 1).long().cuda())
Id = torch.gather(im_flat, 0, torch.unsqueeze(idx_d, 1).long().cuda())
# and finally calculate interpolated values
x0_f = x0.float().cuda()
x1_f = x1.float().cuda()
y0_f = y0.float().cuda()
y1_f = y1.float().cuda()
wa = torch.unsqueeze(((x1_f - x) * (y1_f - y)), 1)
wb = torch.unsqueeze(((x1_f - x) * (y - y0_f)), 1)
wc = torch.unsqueeze(((x - x0_f) * (y1_f - y)), 1)
wd = torch.unsqueeze(((x - x0_f) * (y - y0_f)), 1)
output = wa * Ia + wb * Ib + wc * Ic + wd * Id
return output
def _meshgrid(height, width):
x_t = torch.matmul(
torch.ones((height, 1)),
torch.unsqueeze(torch.linspace(-0.1, 0.1, width),
1).permute(1, 0)).cuda()
y_t = torch.matmul(
torch.unsqueeze(torch.linspace(-0.1, 0.1, height), 1),
torch.ones((1, width))).cuda()
x_t_flat = x_t.reshape((1, -1))
y_t_flat = y_t.reshape((1, -1))
ones = torch.ones_like(x_t_flat).cuda()
grid = torch.cat((x_t_flat, y_t_flat, ones), 0)
return grid
def _warp(x_s, y_s, input_dim):
num_batch, num_channels, height, width = input_dim.shape
# out_height, out_width = out_size
x_s_flat = x_s.reshape(-1)
y_s_flat = y_s.reshape(-1)
input_transformed = _interpolate(input_dim, x_s_flat, y_s_flat)
output = input_transformed.reshape(
(num_batch, num_channels, height, width))
return output
n_dims = int(flow.shape[1]) // 2
dx = flow[:, :n_dims, :, :]
dy = flow[:, n_dims:, :, :]
output = torch.cat([
_warp(dx[:, idx:idx + 1, :, :], dy[:, idx:idx + 1, :, :],
im[:, idx:idx + 1, :, :]) for idx in range(im.shape[1])
], 1)
return output

View File

@@ -359,6 +359,7 @@ TASK_OUTPUTS = {
# {"output_video": "path_to_rendered_video"}
Tasks.video_frame_interpolation: [OutputKeys.OUTPUT_VIDEO],
Tasks.video_super_resolution: [OutputKeys.OUTPUT_VIDEO],
Tasks.video_deinterlace: [OutputKeys.OUTPUT_VIDEO],
Tasks.nerf_recon_acc: [OutputKeys.OUTPUT_VIDEO],
Tasks.video_colorization: [OutputKeys.OUTPUT_VIDEO],

View File

@@ -84,6 +84,7 @@ if TYPE_CHECKING:
from .image_driving_perception_pipeline import ImageDrivingPerceptionPipeline
from .vop_retrieval_pipeline import VopRetrievalPipeline
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
from .video_deinterlace_pipeline import VideoDeinterlacePipeline
from .image_matching_pipeline import ImageMatchingPipeline
from .video_stabilization_pipeline import VideoStabilizationPipeline
from .video_super_resolution_pipeline import VideoSuperResolutionPipeline
@@ -220,6 +221,7 @@ else:
'video_object_segmentation_pipeline': [
'VideoObjectSegmentationPipeline'
],
'video_deinterlace_pipeline': ['VideoDeinterlacePipeline'],
'image_matching_pipeline': ['ImageMatchingPipeline'],
'video_stabilization_pipeline': ['VideoStabilizationPipeline'],
'video_super_resolution_pipeline': ['VideoSuperResolutionPipeline'],

View File

@@ -0,0 +1,186 @@
# The implementation here is modified based on RealBasicVSR,
# originally under the Apache 2.0 License and publicly available at
# https://github.com/ckkelvinchan/RealBasicVSR/blob/master/inference_realbasicvsr.py
import math
import os
import subprocess
import tempfile
from typing import Any, Dict, Optional, Union
import cv2
import numpy as np
import torch
from torchvision.utils import make_grid
from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_deinterlace.UNet_for_video_deinterlace import \
UNetForVideoDeinterlace
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors.cv import VideoReader
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
VIDEO_EXTENSIONS = ('.mp4', '.mov')
logger = get_logger()
def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
"""Convert torch Tensors into image numpy arrays.
After clamping to (min, max), image values will be normalized to [0, 1].
For different tensor shapes, this function will have different behaviors:
1. 4D mini-batch Tensor of shape (N x 3/1 x H x W):
Use `make_grid` to stitch images in the batch dimension, and then
convert it to numpy array.
2. 3D Tensor of shape (3/1 x H x W) and 2D Tensor of shape (H x W):
Directly change to numpy array.
Note that the image channel in input tensors should be RGB order. This
function will convert it to cv2 convention, i.e., (H x W x C) with BGR
order.
Args:
tensor (Tensor | list[Tensor]): Input tensors.
out_type (numpy type): Output types. If ``np.uint8``, transform outputs
to uint8 type with range [0, 255]; otherwise, float type with
range [0, 1]. Default: ``np.uint8``.
min_max (tuple): min and max values for clamp.
Returns:
        (ndarray | list[ndarray]): 3D ndarray of shape (H x W x C) or 2D ndarray
of shape (H x W).
"""
condition = torch.is_tensor(tensor) or (isinstance(tensor, list) and all(
torch.is_tensor(t) for t in tensor))
if not condition:
raise TypeError(
f'tensor or list of tensors expected, got {type(tensor)}')
if torch.is_tensor(tensor):
tensor = [tensor]
result = []
for _tensor in tensor:
# Squeeze two times so that:
        # 1. (1, 1, h, w) -> (h, w) or
        # 2. (1, 3, h, w) -> (3, h, w) or
        # 3. (n>1, 3/1, h, w) -> (n>1, 3/1, h, w)
_tensor = _tensor.squeeze(0).squeeze(0)
_tensor = _tensor.float().detach().cpu().clamp_(*min_max)
_tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
n_dim = _tensor.dim()
if n_dim == 4:
img_np = make_grid(
_tensor, nrow=int(math.sqrt(_tensor.size(0))),
normalize=False).numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
elif n_dim == 3:
img_np = _tensor.numpy()
img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
elif n_dim == 2:
img_np = _tensor.numpy()
else:
raise ValueError('Only support 4D, 3D or 2D tensor. '
f'But received with dimension: {n_dim}')
if out_type == np.uint8:
            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
img_np = (img_np * 255.0).round()
img_np = img_np.astype(out_type)
result.append(img_np)
result = result[0] if len(result) == 1 else result
return result
@PIPELINES.register_module(
Tasks.video_deinterlace, module_name=Pipelines.video_deinterlace)
class VideoDeinterlacePipeline(Pipeline):
def __init__(self,
model: Union[UNetForVideoDeinterlace, str],
preprocessor=None,
**kwargs):
"""The inference pipeline for all the video deinterlace sub-tasks.
Args:
model (`str` or `Model` or module instance): A model instance or a model local dir
or a model id in the model hub.
preprocessor (`Preprocessor`, `optional`): A Preprocessor instance.
kwargs (dict, `optional`):
Extra kwargs passed into the preprocessor's constructor.
Example:
>>> from modelscope.pipelines import pipeline
>>> pipeline_ins = pipeline('video-deinterlace',
model='damo/cv_unet_video-deinterlace')
>>> input = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_deinterlace_test.mp4'
>>> print(pipeline_ins(input)[OutputKeys.OUTPUT_VIDEO])
"""
super().__init__(model=model, preprocessor=preprocessor, **kwargs)
if torch.cuda.is_available():
self._device = torch.device('cuda')
else:
self._device = torch.device('cpu')
self.net = self.model.model
self.net.to(self._device)
self.net.eval()
logger.info('load video deinterlace model done')
    def preprocess(self, input: Input) -> Dict[str, Any]:
        # input is a video file
        video_reader = VideoReader(input)
        inputs = []
        for frame in video_reader:
            # reverse the channel order of each decoded frame
            inputs.append(np.flip(frame, axis=2))
        fps = video_reader.fps
        for i, img in enumerate(inputs):
            # convert to [0, 1] float tensors of shape (1, C, H, W)
            img = torch.from_numpy(img / 255.).permute(2, 0, 1).float()
            inputs[i] = img.unsqueeze(0)
        # stack along a new time axis -> (1, T, C, H, W)
        inputs = torch.stack(inputs, dim=1)
        return {'video': inputs, 'fps': fps}
    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        inputs = input['video'][0]
        frenet = self.net.frenet
        enhnet = self.net.enhnet
        with torch.no_grad():
            outputs = []
            frames = []
            # slide a 3-frame window [prev, cur, next] over the frenet outputs;
            # enhnet enhances the middle frame, and the first frame is
            # duplicated to pad the start of the sequence
            for i in range(0, inputs.size(0)):
                frames.append(frenet(inputs[i:i + 1, ...].to(self._device)))
                if i == 0:
                    frames = [frames[-1]] * 2
                    continue
                outputs.append(enhnet(frames).cpu().unsqueeze(1))
                frames = frames[1:]
            # process the final frame with its successor duplicated
            frames.append(frames[-1])
            outputs.append(enhnet(frames).cpu().unsqueeze(1))
            outputs = torch.cat(outputs, dim=1)
        return {'output': outputs, 'fps': input['fps']}
def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
output_video_path = kwargs.get('output_video', None)
demo_service = kwargs.get('demo_service', False)
if output_video_path is None:
output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
h, w = inputs['output'].shape[-2:]
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter(output_video_path, fourcc,
inputs['fps'], (w, h))
for i in range(0, inputs['output'].size(1)):
img = tensor2img(inputs['output'][:, i, :, :, :])
video_writer.write(img.astype(np.uint8))
video_writer.release()
if demo_service:
assert os.system(
'ffmpeg -version'
) == 0, 'ffmpeg is not installed correctly, please refer to https://trac.ffmpeg.org/wiki/CompilationGuide.'
output_video_path_for_web = output_video_path[:-4] + '_web.mp4'
convert_cmd = f'ffmpeg -i {output_video_path} -vcodec h264 -crf 5 {output_video_path_for_web}'
subprocess.call(convert_cmd, shell=True)
return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web}
else:
return {OutputKeys.OUTPUT_VIDEO: output_video_path}
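
For completeness, a small usage sketch of the tensor2img helper defined at the top of this pipeline file (import path as registered in the pipelines __init__ change above):

import torch
from modelscope.pipelines.cv.video_deinterlace_pipeline import tensor2img

t = torch.rand(1, 3, 120, 160)   # an RGB float tensor in [0, 1]
img = tensor2img(t)              # -> uint8 BGR array of shape (120, 160, 3)
print(img.shape, img.dtype)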

View File

@@ -113,6 +113,7 @@ class CVTasks(object):
video_frame_interpolation = 'video-frame-interpolation'
video_stabilization = 'video-stabilization'
video_super_resolution = 'video-super-resolution'
video_deinterlace = 'video-deinterlace'
video_colorization = 'video-colorization'
# reid and tracking

View File

@@ -0,0 +1,61 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.cv import VideoDeinterlacePipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VideoDeinterlaceTest(unittest.TestCase, DemoCompatibilityCheck):
def setUp(self) -> None:
self.task = Tasks.video_deinterlace
self.model_id = 'damo/cv_unet_video-deinterlace'
self.test_video = 'data/test/videos/video_deinterlace_test.mp4'
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_by_direct_model_download(self):
cache_path = snapshot_download(self.model_id)
pipeline = VideoDeinterlacePipeline(cache_path)
pipeline.group_key = self.task
out_video_path = pipeline(
input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
print('pipeline: the output video path is {}'.format(out_video_path))
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_run_with_model_from_pretrained(self):
cache_path = Model.from_pretrained(self.model_id)
pipeline = VideoDeinterlacePipeline(cache_path)
pipeline.group_key = self.task
out_video_path = pipeline(
input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
print('pipeline: the output video path is {}'.format(out_video_path))
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_run_with_model_from_modelhub(self):
pipeline_ins = pipeline(
task=Tasks.video_deinterlace, model=self.model_id)
out_video_path = pipeline_ins(
input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
print('pipeline: the output video path is {}'.format(out_video_path))
@unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
def test_run_with_default_model(self):
pipeline_ins = pipeline(task=Tasks.video_deinterlace)
out_video_path = pipeline_ins(
input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
print('pipeline: the output video path is {}'.format(out_video_path))
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):
self.compatibility_check()
if __name__ == '__main__':
unittest.main()

View File

@@ -55,6 +55,7 @@ isolated: # test cases that may require excessive amount of GPU memory or run
- test_image_deblur_trainer.py
- test_image_quality_assessment_mos.py
- test_image_restoration.py
- test_video_deinterlace.py
- test_image_inpainting_sdv2.py
envs: