diff --git a/data/test/videos/video_frame_interpolation_test.mp4 b/data/test/videos/video_frame_interpolation_test.mp4 new file mode 100644 index 00000000..4085a88f --- /dev/null +++ b/data/test/videos/video_frame_interpolation_test.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e97ff88d0af12f7dd3ef04ce50b87b51ffbb9a57dce81d2d518df4abd2fdb826 +size 3231793 diff --git a/modelscope/metainfo.py b/modelscope/metainfo.py index 4ad383bd..209de8e5 100644 --- a/modelscope/metainfo.py +++ b/modelscope/metainfo.py @@ -63,6 +63,7 @@ class Models(object): image_body_reshaping = 'image-body-reshaping' image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' + video_frame_interpolation = 'video-frame-interpolation' video_object_segmentation = 'video-object-segmentation' video_stabilization = 'video-stabilization' real_basicvsr = 'real-basicvsr' @@ -270,6 +271,7 @@ class Pipelines(object): referring_video_object_segmentation = 'referring-video-object-segmentation' image_skychange = 'image-skychange' video_human_matting = 'video-human-matting' + video_frame_interpolation = 'video-frame-interpolation' video_object_segmentation = 'video-object-segmentation' video_stabilization = 'video-stabilization' video_super_resolution = 'realbasicvsr-video-super-resolution' @@ -507,6 +509,8 @@ class Metrics(object): # metrics for image denoise task image_denoise_metric = 'image-denoise-metric' + # metrics for video frame-interpolation task + video_frame_interpolation_metric = 'video-frame-interpolation-metric' # metrics for real-world video super-resolution task video_super_resolution_metric = 'video-super-resolution-metric' diff --git a/modelscope/metrics/__init__.py b/modelscope/metrics/__init__.py index 7192ed5a..f814cf4d 100644 --- a/modelscope/metrics/__init__.py +++ b/modelscope/metrics/__init__.py @@ -21,6 +21,7 @@ if TYPE_CHECKING: from .bleu_metric import BleuMetric from .image_inpainting_metric import ImageInpaintingMetric from 
.referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric + from .video_frame_interpolation_metric import VideoFrameInterpolationMetric from .video_stabilization_metric import VideoStabilizationMetric from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric from .ppl_metric import PplMetric @@ -46,6 +47,7 @@ else: 'bleu_metric': ['BleuMetric'], 'referring_video_object_segmentation_metric': ['ReferringVideoObjectSegmentationMetric'], + 'video_frame_interpolation_metric': ['VideoFrameInterpolationMetric'], 'video_stabilization_metric': ['VideoStabilizationMetric'], 'ppl_metric': ['PplMetric'], } diff --git a/modelscope/metrics/builder.py b/modelscope/metrics/builder.py index f683bc36..97979d20 100644 --- a/modelscope/metrics/builder.py +++ b/modelscope/metrics/builder.py @@ -16,6 +16,7 @@ class MetricKeys(object): RECALL = 'recall' PSNR = 'psnr' SSIM = 'ssim' + LPIPS = 'lpips' NIQE = 'niqe' AVERAGE_LOSS = 'avg_loss' FScore = 'fscore' @@ -53,6 +54,8 @@ task_default_metrics = { Tasks.image_inpainting: [Metrics.image_inpainting_metric], Tasks.referring_video_object_segmentation: [Metrics.referring_video_object_segmentation_metric], + Tasks.video_frame_interpolation: + [Metrics.video_frame_interpolation_metric], Tasks.video_stabilization: [Metrics.video_stabilization_metric], } diff --git a/modelscope/metrics/video_frame_interpolation_metric.py b/modelscope/metrics/video_frame_interpolation_metric.py new file mode 100644 index 00000000..d2ecf9fb --- /dev/null +++ b/modelscope/metrics/video_frame_interpolation_metric.py @@ -0,0 +1,172 @@ +# ------------------------------------------------------------------------ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+# ------------------------------------------------------------------------ +import math +from math import exp +from typing import Dict + +import lpips +import numpy as np +import torch +import torch.nn.functional as F + +from modelscope.metainfo import Metrics +from modelscope.metrics.base import Metric +from modelscope.metrics.builder import METRICS, MetricKeys +from modelscope.utils.registry import default_group + + +@METRICS.register_module( + group_key=default_group, + module_name=Metrics.video_frame_interpolation_metric) +class VideoFrameInterpolationMetric(Metric): + """The metric computation class for video frame interpolation, + which will return PSNR, SSIM and LPIPS. + """ + pred_name = 'pred' + label_name = 'target' + + def __init__(self): + super(VideoFrameInterpolationMetric, self).__init__() + self.preds = [] + self.labels = [] + self.loss_fn_alex = lpips.LPIPS(net='alex').cuda() + + def add(self, outputs: Dict, inputs: Dict): + ground_truths = outputs[VideoFrameInterpolationMetric.label_name] + eval_results = outputs[VideoFrameInterpolationMetric.pred_name] + self.preds.append(eval_results) + self.labels.append(ground_truths) + + def evaluate(self): + psnr_list, ssim_list, lpips_list = [], [], [] + with torch.no_grad(): + for (pred, label) in zip(self.preds, self.labels): + # norm to 0-1 + height, width = label.size(2), label.size(3) + pred = pred[:, :, 0:height, 0:width] + + psnr_list.append(calculate_psnr(label, pred)) + ssim_list.append(calculate_ssim(label, pred)) + lpips_list.append( + calculate_lpips(label, pred, self.loss_fn_alex)) + + return { + MetricKeys.PSNR: np.mean(psnr_list), + MetricKeys.SSIM: np.mean(ssim_list), + MetricKeys.LPIPS: np.mean(lpips_list) + } + + +def gaussian(window_size, sigma): + gauss = torch.Tensor([ + exp(-(x - window_size // 2)**2 / float(2 * sigma**2)) + for x in range(window_size) + ]) + return gauss / gauss.sum() + + +def create_window_3d(window_size, channel=1, device=None): + _1D_window = gaussian(window_size, 
1.5).unsqueeze(1) + _2D_window = _1D_window.mm(_1D_window.t()) + _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) + window = _3D_window.expand(1, channel, window_size, window_size, + window_size).contiguous().to(device) + return window + + +def calculate_psnr(img1, img2): + psnr = -10 * math.log10( + torch.mean((img1[0] - img2[0]) * (img1[0] - img2[0])).cpu().data) + return psnr + + +def calculate_ssim(img1, + img2, + window_size=11, + window=None, + size_average=True, + full=False, + val_range=None): + # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). + if val_range is None: + if torch.max(img1) > 128: + max_val = 255 + else: + max_val = 1 + + if torch.min(img1) < -0.5: + min_val = -1 + else: + min_val = 0 + L = max_val - min_val + else: + L = val_range + + padd = 0 + (_, _, height, width) = img1.size() + if window is None: + real_size = min(window_size, height, width) + window = create_window_3d( + real_size, channel=1, device=img1.device).to(img1.device) + # Channel is set to 1 since we consider color images as volumetric images + + img1 = img1.unsqueeze(1) + img2 = img2.unsqueeze(1) + + mu1 = F.conv3d( + F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'), + window, + padding=padd, + groups=1) + mu2 = F.conv3d( + F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'), + window, + padding=padd, + groups=1) + + mu1_sq = mu1.pow(2) + mu2_sq = mu2.pow(2) + mu1_mu2 = mu1 * mu2 + + sigma1_sq = F.conv3d( + F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'), + window, + padding=padd, + groups=1) - mu1_sq + sigma2_sq = F.conv3d( + F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), + window, + padding=padd, + groups=1) - mu2_sq + sigma12 = F.conv3d( + F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'), + window, + padding=padd, + groups=1) - mu1_mu2 + + C1 = (0.01 * L)**2 + C2 = (0.03 * L)**2 + + v1 = 2.0 * sigma12 + C2 + v2 = sigma1_sq + sigma2_sq + C2 + cs = torch.mean(v1 / v2) # contrast sensitivity + + ssim_map = ((2 * 
mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) + + if size_average: + ret = ssim_map.mean() + else: + ret = ssim_map.mean(1).mean(1).mean(1) + + if full: + return ret, cs + return ret.cpu() + + +def calculate_lpips(img1, img2, loss_fn_alex): + img1 = img1 * 2 - 1 + img2 = img2 * 2 - 1 + + d = loss_fn_alex(img1, img2) + return d.cpu().item() diff --git a/modelscope/models/cv/__init__.py b/modelscope/models/cv/__init__.py index a86420e8..e5fa9818 100644 --- a/modelscope/models/cv/__init__.py +++ b/modelscope/models/cv/__init__.py @@ -15,8 +15,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints, pointcloud_sceneflow_estimation, product_retrieval_embedding, realtime_object_detection, referring_video_object_segmentation, salient_detection, shop_segmentation, super_resolution, - video_object_segmentation, video_single_object_tracking, - video_stabilization, video_summarization, - video_super_resolution, virual_tryon) + video_frame_interpolation, video_object_segmentation, + video_single_object_tracking, video_stabilization, + video_summarization, video_super_resolution, virual_tryon) # yapf: enable diff --git a/modelscope/models/cv/video_frame_interpolation/VFINet_arch.py b/modelscope/models/cv/video_frame_interpolation/VFINet_arch.py new file mode 100644 index 00000000..33486cf9 --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/VFINet_arch.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import torch +import torch.nn as nn + +from modelscope.models.cv.video_frame_interpolation.flow_model.raft import RAFT +from modelscope.models.cv.video_frame_interpolation.interp_model.IFNet_swin import \ + IFNet +from modelscope.models.cv.video_frame_interpolation.interp_model.refinenet_arch import ( + InterpNet, InterpNetDs) + + +class VFINet(nn.Module): + + def __init__(self, args, Ds_flag=False): + super(VFINet, self).__init__() + self.flownet = RAFT(args) + self.internet = InterpNet() + if Ds_flag: + self.internet_Ds = InterpNetDs() + + def img_trans(self, img_tensor): # in format of RGB + img_tensor = img_tensor / 255.0 + mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, + 1).type_as(img_tensor) + img_tensor -= mean + return img_tensor + + def add_mean(self, x): + mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, 1).type_as(x) + return x + mean + + def forward(self, imgs, timestep=0.5): + self.flownet.eval() + self.internet.eval() + with torch.no_grad(): + img0 = imgs[:, :3] + img1 = imgs[:, 3:6] + img2 = imgs[:, 6:9] + img3 = imgs[:, 9:12] + + _, F10_up = self.flownet(img1, img0, iters=12, test_mode=True) + _, F12_up = self.flownet(img1, img2, iters=12, test_mode=True) + _, F21_up = self.flownet(img2, img1, iters=12, test_mode=True) + _, F23_up = self.flownet(img2, img3, iters=12, test_mode=True) + + img1 = self.img_trans(img1.clone()) + img2 = self.img_trans(img2.clone()) + + It_warp = self.internet( + img1, img2, F10_up, F12_up, F21_up, F23_up, timestep=timestep) + It_warp = self.add_mean(It_warp) + + return It_warp diff --git a/modelscope/models/cv/video_frame_interpolation/VFINet_for_video_frame_interpolation.py b/modelscope/models/cv/video_frame_interpolation/VFINet_for_video_frame_interpolation.py new file mode 100644 index 00000000..a7ea00e1 --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/VFINet_for_video_frame_interpolation.py @@ -0,0 +1,98 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import os +from copy import deepcopy +from typing import Any, Dict, Union + +import torch.cuda +import torch.nn.functional as F +from torch.nn.parallel import DataParallel, DistributedDataParallel + +from modelscope.metainfo import Models +from modelscope.models.base import Tensor +from modelscope.models.base.base_torch_model import TorchModel +from modelscope.models.builder import MODELS +# from modelscope.models.cv.video_super_resolution.common import charbonnier_loss +from modelscope.models.cv.video_frame_interpolation.VFINet_arch import VFINet +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() +__all__ = ['VFINetForVideoFrameInterpolation'] + + +def convert(param): + return { + k.replace('module.', ''): v + for k, v in param.items() if 'module.' in k + } + + +@MODELS.register_module( + Tasks.video_frame_interpolation, + module_name=Models.video_frame_interpolation) +class VFINetForVideoFrameInterpolation(TorchModel): + + def __init__(self, model_dir: str, *args, **kwargs): + """initialize the video frame-interpolation model from the `model_dir` path. + + Args: + model_dir (str): the model path. 
+ + """ + super().__init__(model_dir, *args, **kwargs) + + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + + self.model_dir = model_dir + self.config = Config.from_file( + os.path.join(self.model_dir, ModelFile.CONFIGURATION)) + + flownet_path = os.path.join(model_dir, 'raft-sintel.pt') + internet_path = os.path.join(model_dir, 'interpnet.pt') + + self.model = VFINet(self.config.model.network, Ds_flag=True) + self._load_pretrained(flownet_path, internet_path) + + def _load_pretrained(self, flownet_path, internet_path): + state_dict_flownet = torch.load( + flownet_path, map_location=self._device) + state_dict_internet = torch.load( + internet_path, map_location=self._device) + + self.model.flownet.load_state_dict( + convert(state_dict_flownet), strict=True) + self.model.internet.load_state_dict( + convert(state_dict_internet), strict=True) + self.model.internet_Ds.load_state_dict( + convert(state_dict_internet), strict=True) + logger.info('load model done.') + + def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]: + return {'output': self.model(input)} + + def _evaluate_postprocess(self, input: Tensor, + target: Tensor) -> Dict[str, list]: + preds = self.model(input) + del input + torch.cuda.empty_cache() + return {'pred': preds, 'target': target} + + def forward(self, inputs: Dict[str, + Tensor]) -> Dict[str, Union[list, Tensor]]: + """return the result by the model + + Args: + inputs (Tensor): the preprocessed data + + Returns: + Dict[str, Tensor]: results + """ + + if 'target' in inputs: + return self._evaluate_postprocess(**inputs) + else: + return self._inference_forward(**inputs) diff --git a/modelscope/models/cv/video_frame_interpolation/__init__.py b/modelscope/models/cv/video_frame_interpolation/__init__.py new file mode 100644 index 00000000..657a375a --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/__init__.py @@ -0,0 +1,20 @@ +# Copyright (c) Alibaba, Inc. 
and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .VFINet_arch import VFINet + +else: + _import_structure = {'VFINet_arch': ['VFINet']} + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/models/cv/video_frame_interpolation/flow_model/__init__.py b/modelscope/models/cv/video_frame_interpolation/flow_model/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_frame_interpolation/flow_model/corr.py b/modelscope/models/cv/video_frame_interpolation/flow_model/corr.py new file mode 100644 index 00000000..86009405 --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/flow_model/corr.py @@ -0,0 +1,92 @@ +# The implementation is adopted from RAFT, +# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT + +import torch +import torch.nn.functional as F + +from modelscope.models.cv.video_frame_interpolation.utils.utils import ( + bilinear_sampler, coords_grid) + + +class CorrBlock: + + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch * h1 * w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels - 1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device) + dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device) 
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) + + centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht * wd) + fmap2 = fmap2.view(batch, dim, ht * wd) + + corr = torch.matmul(fmap1.transpose(1, 2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class AlternateCorrBlock: + + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def __call__(self, coords): + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + dim = self.pyramid[0][0].shape[1] + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / torch.sqrt(torch.tensor(dim).float()) diff --git a/modelscope/models/cv/video_frame_interpolation/flow_model/extractor.py b/modelscope/models/cv/video_frame_interpolation/flow_model/extractor.py new file mode 100644 index 00000000..c0ebef47 --- /dev/null +++ 
b/modelscope/models/cv/video_frame_interpolation/flow_model/extractor.py @@ -0,0 +1,288 @@ +# The implementation is adopted from RAFT, +# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm3) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BottleneckBlock(nn.Module): + + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d( + in_planes, 
planes // 4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d( + planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes // 4) + self.norm2 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes // 4) + self.norm3 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm( + num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes // 4) + self.norm2 = nn.BatchNorm2d(planes // 4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes // 4) + self.norm2 = nn.InstanceNorm2d(planes // 4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), + self.norm4) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Module): + + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + 
elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, + (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock( + self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + 
elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_( + m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, + (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock( + self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/modelscope/models/cv/video_frame_interpolation/flow_model/raft.py b/modelscope/models/cv/video_frame_interpolation/flow_model/raft.py new file mode 100644 index 00000000..87b7a2ed --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/flow_model/raft.py @@ -0,0 +1,157 @@ +# The implementation is adopted from RAFT, +# made publicly 
available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from modelscope.models.cv.video_frame_interpolation.flow_model.corr import ( + AlternateCorrBlock, CorrBlock) +from modelscope.models.cv.video_frame_interpolation.flow_model.extractor import ( + BasicEncoder, SmallEncoder) +from modelscope.models.cv.video_frame_interpolation.flow_model.update import ( + BasicUpdateBlock, SmallUpdateBlock) +from modelscope.models.cv.video_frame_interpolation.utils.utils import ( + bilinear_sampler, coords_grid, upflow8) + +autocast = torch.cuda.amp.autocast + + +class RAFT(nn.Module): + + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + self.args.corr_levels = 4 + self.args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + self.args.corr_levels = 4 + self.args.corr_radius = 4 + + if 'dropout' not in self.args: + self.args.dropout = 0 + + if 'alternate_corr' not in self.args: + self.args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder( + output_dim=128, norm_fn='instance', dropout=self.args.dropout) + self.cnet = SmallEncoder( + output_dim=hdim + cdim, + norm_fn='none', + dropout=self.args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder( + output_dim=256, norm_fn='instance', dropout=self.args.dropout) + self.cnet = BasicEncoder( + output_dim=hdim + cdim, + norm_fn='batch', + dropout=self.args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = 
coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H // 8, W // 8, device=img.device) + coords1 = coords_grid(N, H // 8, W // 8, device=img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3, 3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8 * H, 8 * W) + + def forward(self, + image1, + image2, + iters=12, + flow_init=None, + upsample=True, + test_mode=False): + """ Estimate optical flow between pair of frames """ + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock( + fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, 
delta_flow = self.update_block( + net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/modelscope/models/cv/video_frame_interpolation/flow_model/update.py b/modelscope/models/cv/video_frame_interpolation/flow_model/update.py new file mode 100644 index 00000000..29a20db1 --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/flow_model/update.py @@ -0,0 +1,160 @@ +# The implementation is adopted from RAFT, +# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class ConvGRU(nn.Module): + + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d( + hidden_dim + input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) + + h = (1 - z) * h + z * q + return h + + +class SepConvGRU(nn.Module): + + def __init__(self, hidden_dim=128, input_dim=192 + 128): + super(SepConvGRU, self).__init__() + self.convz1 = 
class SepConvGRU(nn.Module):
    """ConvGRU variant with separable (1x5 then 5x1) gating convolutions."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        combined = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(combined, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(combined, hidden_dim, (5, 1), padding=(2, 0))

    def forward(self, h, x):
        # First GRU update with horizontal (1x5) kernels ...
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz1(hx))
        r = torch.sigmoid(self.convr1(hx))
        q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q

        # ... then a second update with vertical (5x1) kernels.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(self.convz2(hx))
        r = torch.sigmoid(self.convr2(hx))
        q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1)))
        h = (1 - z) * h + z * q
        return h


class SmallMotionEncoder(nn.Module):
    """Encodes correlation volume and current flow into motion features."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # The raw 2-channel flow is appended so later layers see it directly.
        return torch.cat([fused, flow], dim=1)
class BasicMotionEncoder(nn.Module):
    """Larger motion encoder used by the full RAFT update block."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 output planes so appending the 2-channel flow gives 128.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        return torch.cat([fused, flow], dim=1)


class SmallUpdateBlock(nn.Module):
    """GRU update step for the small RAFT configuration (no upsample mask)."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion_features], dim=1))
        # The small model has no learned upsampling mask, hence the None.
        return net, None, self.flow_head(net)


class BasicUpdateBlock(nn.Module):
    """GRU update step for the full RAFT model, with a learned upsample mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        net = self.gru(net, torch.cat([inp, motion_features], dim=1))
        delta_flow = self.flow_head(net)
        # Scale the mask to balance gradients, as in the original RAFT code.
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` by the dense flow field ``tenFlow``.

    Args:
        tenInput: (N, C, H, W) source tensor.
        tenFlow: (N, 2, H, W) flow in pixels; channel 0 is horizontal,
            channel 1 vertical.

    Returns:
        (N, C, H, W) tensor sampled at the flow-displaced positions.
    """
    # Cache the base sampling grid per (device, size): rebuilding it on every
    # call would dominate the cost on video-length inputs.
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Fix: build the grid on the *flow's* device rather than the
        # module-global default device, so inputs on any device work.
        tenHorizontal = torch.linspace(
            -1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
                1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1,
                                                  tenFlow.shape[2], -1)
        tenVertical = torch.linspace(
            -1.0, 1.0, tenFlow.shape[2],
            device=tenFlow.device).view(1, 1, tenFlow.shape[2],
                                        1).expand(tenFlow.shape[0], -1, -1,
                                                  tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Convert pixel offsets to grid_sample's normalized [-1, 1] coordinates.
    tmp1 = tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0)
    tmp2 = tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)
    tenFlow = torch.cat([tmp1, tmp2], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=torch.clamp(g, -1, 1),
        mode='bilinear',
        padding_mode='border',
        align_corners=True)


def conv_wo_act(in_planes,
                out_planes,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1):
    """Conv2d wrapped in a Sequential, without an activation."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True), )


def conv(in_planes,
         out_planes,
         kernel_size=3,
         stride=1,
         padding=1,
         dilation=1):
    """Conv2d followed by a PReLU activation."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True), nn.PReLU(out_planes))
def conv_bn(in_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1):
    """Bias-free Conv2d followed by BatchNorm and PReLU."""
    return nn.Sequential(
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False), nn.BatchNorm2d(out_planes), nn.PReLU(out_planes))


class TransModel(nn.Module):
    """Swin-style transformer trunk applied to (N, embed_dim, H, W) features.

    Wraps patch (un)embedding around a stack of RTFL layers; input and output
    have the same shape.

    NOTE(review): ``img_size`` is indexed (``img_size[0]``) when the RTFL
    layers are built, so callers must pass a (height, width) tuple despite
    the scalar-looking default of 64 — confirm before changing the default.
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 embed_dim=64,
                 depths=[[3, 3]],
                 num_heads=[[2, 2]],
                 window_size=4,
                 mlp_ratio=2,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 resi_connection='1conv',
                 use_crossattn=[[[False, False, False, False],
                                 [True, True, True, True]]]):
        super(TransModel, self).__init__()
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # Split the feature map into (possibly 1x1) patches ...
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # ... and merge them back into an image at the end.
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # Optional absolute position embedding.
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # Stochastic-depth decay rule: linearly increasing drop rates.
        dpr0 = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[0]))
        ]

        self.layers0 = nn.ModuleList()
        num_layers = len(depths[0])
        for i_layer in range(num_layers):
            layer = RTFL(
                dim=embed_dim,
                input_resolution=(patches_resolution[0],
                                  patches_resolution[1]),
                depth=depths[0][i_layer],
                num_heads=num_heads[0][i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr0[sum(depths[0][:i_layer]):sum(depths[0][:i_layer
                                                                     + 1])],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=(img_size[0], img_size[1]),
                patch_size=patch_size,
                resi_connection=resi_connection,
                use_crossattn=use_crossattn[0][i_layer])
            self.layers0.append(layer)

        self.norm = norm_layer(self.num_features)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for Linear layers, unit/zero for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x, layers):
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        if isinstance(layers, nn.ModuleList):
            for layer in layers:
                x = layer(x, x_size)
        else:
            x = layers(x, x_size)

        x = self.norm(x)  # B L C
        return self.patch_unembed(x, x_size)

    def forward(self, x):
        return self.forward_features(x, self.layers0)
class IFBlock(nn.Module):
    """One flow-refinement stage (transformer trunk variant).

    Takes concatenated images/warps plus the two current flows, refines at
    1/scale resolution, and returns residual flows at the input resolution.
    """

    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock, self).__init__()
        self.scale = scale
        # Two stride-2 convs: features are computed at 1/4 of the working
        # resolution before the transformer trunk.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
            conv(c, c, 3, 1, 1),
        )

        self.trans = TransModel(
            img_size=(128 // scale, 128 // scale),
            patch_size=1,
            embed_dim=c,
            depths=[[3, 3]],
            num_heads=[[2, 2]])

        self.conv1 = nn.Sequential(
            conv(c, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )

        self.up = nn.ConvTranspose2d(c, 4, 4, 2, 1)

        self.conv2 = nn.Conv2d(4, 4, 3, 1, 1)

    def forward(self, x, flow0, flow1):
        # Move to the working resolution; flow magnitudes scale with size.
        if self.scale != 1:
            inv = 1. / self.scale
            x = F.interpolate(
                x, scale_factor=inv, mode='bilinear', align_corners=False)
            flow0 = F.interpolate(
                flow0, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
            flow1 = F.interpolate(
                flow1, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv

        feat = torch.cat((x, flow0, flow1), 1)
        feat = self.conv0(feat)
        feat = self.trans(feat)
        feat = self.conv1(feat) + feat  # residual around the conv pair

        # Learned 2x upsample to 4 channels, then bilinear 2x back to the
        # working resolution (conv0 downsampled by 4 overall).
        feat = self.up(feat)
        feat = self.conv2(feat)
        flow = F.interpolate(
            feat, scale_factor=2.0, mode='bilinear', align_corners=False) * 2.0

        if self.scale != 1:
            flow = F.interpolate(
                flow,
                scale_factor=self.scale,
                mode='bilinear',
                align_corners=False) * self.scale

        # Channels 0-1: residual flow img0->img1; channels 2-3: img1->img0.
        return flow[:, :2, :, :], flow[:, 2:, :, :]
class IFBlock_wo_Swin(nn.Module):
    """Flow-refinement stage without the Swin trunk (pure conv residuals)."""

    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock_wo_Swin, self).__init__()
        self.scale = scale
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )

        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c), conv(c, c))

        self.up = nn.ConvTranspose2d(c, 4, 4, 2, 1)

        self.conv2 = nn.Conv2d(4, 4, 3, 1, 1)

    def forward(self, x, flow0, flow1):
        # Move to the working resolution; flow magnitudes scale with size.
        if self.scale != 1:
            inv = 1. / self.scale
            x = F.interpolate(
                x, scale_factor=inv, mode='bilinear', align_corners=False)
            flow0 = F.interpolate(
                flow0, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
            flow1 = F.interpolate(
                flow1, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv

        feat = torch.cat((x, flow0, flow1), 1)
        feat = self.conv0(feat)
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat

        # Learned 2x upsample, then bilinear 2x back to working resolution.
        feat = self.up(feat)
        feat = self.conv2(feat)
        flow = F.interpolate(
            feat, scale_factor=2.0, mode='bilinear', align_corners=False) * 2.0

        if self.scale != 1:
            flow = F.interpolate(
                flow,
                scale_factor=self.scale,
                mode='bilinear',
                align_corners=False) * self.scale

        # Channels 0-1: residual flow img0->img1; channels 2-3: img1->img0.
        return flow[:, :2, :, :], flow[:, 2:, :, :]
class IFNet(nn.Module):
    """Coarse-to-fine residual refinement of a bidirectional flow pair."""

    def __init__(self):
        super(IFNet, self).__init__()
        self.block1 = IFBlock_wo_Swin(16, scale=4, c=128)
        self.block2 = IFBlock(16, scale=2, c=64)
        self.block3 = IFBlock(16, scale=1, c=32)

    # flow0: flow from img0 to img1
    # flow1: flow from img1 to img0
    def forward(self, img0, img1, flow0, flow1, sc_mode=2):
        # sc_mode selects the global working scale (0 -> 1/4, 1 -> 1/2).
        if sc_mode == 0:
            sc = 0.25
        elif sc_mode == 1:
            sc = 0.5
        else:
            sc = 1

        if sc != 1:
            img0_sc = F.interpolate(
                img0, scale_factor=sc, mode='bilinear', align_corners=False)
            img1_sc = F.interpolate(
                img1, scale_factor=sc, mode='bilinear', align_corners=False)
            flow0_sc = F.interpolate(
                flow0, scale_factor=sc, mode='bilinear',
                align_corners=False) * sc
            flow1_sc = F.interpolate(
                flow1, scale_factor=sc, mode='bilinear',
                align_corners=False) * sc
        else:
            img0_sc = img0
            img1_sc = img1
            flow0_sc = flow0
            flow1_sc = flow1

        # Stage 1: warp each image towards the other, refine the flows.
        warped_img0 = warp(img1_sc, flow0_sc)  # -> img0
        warped_img1 = warp(img0_sc, flow1_sc)  # -> img1
        flow0_1, flow1_1 = self.block1(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), 1),
            flow0_sc, flow1_sc)
        F0_2 = flow0_sc + flow0_1
        F1_2 = flow1_sc + flow1_1

        # Stage 2 at a finer working scale.
        warped_img0 = warp(img1_sc, F0_2)  # -> img0
        warped_img1 = warp(img0_sc, F1_2)  # -> img1
        flow0_2, flow1_2 = self.block2(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), 1), F0_2,
            F1_2)
        F0_3 = F0_2 + flow0_2
        F1_3 = F1_2 + flow1_2

        # Stage 3 at full working scale; accumulate all residuals.
        warped_img0 = warp(img1_sc, F0_3)  # -> img0
        warped_img1 = warp(img0_sc, F1_3)  # -> img1
        flow0_3, flow1_3 = self.block3(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), dim=1),
            F0_3, F1_3)
        flow_res_0 = flow0_1 + flow0_2 + flow0_3
        flow_res_1 = flow1_1 + flow1_2 + flow1_3

        if sc != 1:
            flow_res_0 = F.interpolate(
                flow_res_0,
                scale_factor=1 / sc,
                mode='bilinear',
                align_corners=False) / sc
            flow_res_1 = F.interpolate(
                flow_res_1,
                scale_factor=1 / sc,
                mode='bilinear',
                align_corners=False) / sc

        # Add the residuals onto the *original-resolution* input flows.
        return flow0 + flow_res_0, flow1 + flow_res_1


class down(nn.Module):
    """U-Net encoder stage: 2x average-pool then two convs with LeakyReLU."""

    def __init__(self, inChannels, outChannels, filterSize):
        super(down, self).__init__()
        pad = int((filterSize - 1) / 2)
        self.conv1 = nn.Conv2d(
            inChannels, outChannels, filterSize, stride=1, padding=pad)
        self.conv2 = nn.Conv2d(
            outChannels, outChannels, filterSize, stride=1, padding=pad)

    def forward(self, x):
        x = F.avg_pool2d(x, 2)
        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        return F.leaky_relu(self.conv2(x), negative_slope=0.1)
class up(nn.Module):
    """U-Net decoder stage: resize to the skip's size, conv, fuse with skip."""

    def __init__(self, inChannels, outChannels):
        super(up, self).__init__()
        self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(
            2 * outChannels, outChannels, 3, stride=1, padding=1)

    def forward(self, x, skpCn):
        # Match the skip connection's spatial size exactly (handles odd dims).
        x = F.interpolate(
            x,
            size=[skpCn.size(2), skpCn.size(3)],
            mode='bilinear',
            align_corners=False)
        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        fused = torch.cat((x, skpCn), 1)
        return F.leaky_relu(self.conv2(fused), negative_slope=0.1)


class Small_UNet(nn.Module):
    """Compact 3-level U-Net returning both a prediction and a feature map."""

    def __init__(self, inChannels, outChannels):
        super(Small_UNet, self).__init__()
        self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3)
        self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3)
        self.down1 = down(32, 64, 5)
        self.down2 = down(64, 128, 3)
        self.down3 = down(128, 128, 3)
        self.up1 = up(128, 128)
        self.up2 = up(128, 64)
        self.up3 = up(64, 32)
        self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        s1 = F.leaky_relu(self.conv2(x), negative_slope=0.1)
        s2 = self.down1(s1)
        s3 = self.down2(s2)
        bottleneck = self.down3(s3)
        d = self.up1(bottleneck, s3)
        d = self.up2(d, s2)
        x1 = self.up3(d, s1)  # feature
        return self.conv3(x1), x1  # (flow, feature)
class Small_UNet_Ds(nn.Module):
    """Small U-Net that internally works at half resolution.

    The input is encoded, downsampled 2x, run through the U-Net, then the
    feature map is upsampled back so outputs match the input resolution.
    Returns (flow, feature).
    """

    def __init__(self, inChannels, outChannels):
        super(Small_UNet_Ds, self).__init__()
        self.conv1_1 = nn.Conv2d(inChannels, 32, 5, stride=1, padding=2)
        self.conv1_2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)

        self.down1 = down(32, 64, 5)
        self.down2 = down(64, 128, 3)
        self.down3 = down(128, 128, 3)
        self.up1 = up(128, 128)
        self.up2 = up(128, 64)
        self.up3 = up(64, 32)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)

    def forward(self, x):
        # Full-resolution stem; x0 is kept to recover the output size.
        x0 = F.leaky_relu(self.conv1_1(x), negative_slope=0.1)
        x0 = F.leaky_relu(self.conv1_2(x0), negative_slope=0.1)

        # Work at half resolution inside the U-Net.
        h = F.interpolate(
            x0,
            size=[x0.size(2) // 2, x0.size(3) // 2],
            mode='bilinear',
            align_corners=False)
        h = F.leaky_relu(self.conv2_1(h), negative_slope=0.1)
        s1 = F.leaky_relu(self.conv2_2(h), negative_slope=0.1)

        s2 = self.down1(s1)
        s3 = self.down2(s2)
        h = self.down3(s3)

        h = self.up1(h, s3)
        h = self.up2(h, s2)
        x1 = self.up3(h, s1)

        # Back to the input resolution.
        x1 = F.interpolate(
            x1,
            size=[x0.size(2), x0.size(3)],
            mode='bilinear',
            align_corners=False)

        x1 = F.leaky_relu(self.conv3(x1), negative_slope=0.1)  # feature
        return self.conv4(x1), x1  # (flow, feature)
class FlowReversal(nn.Module):
    """Forward-warp (splat) a tensor along a flow field with Gaussian weights.

    Used to reverse t->1 flows: each source pixel scatters its value to the
    four integer neighbors of its flow target, weighted by a Gaussian of the
    sub-pixel distance. Returns the accumulated image and the weight map.
    """

    def __init__(self, ):
        super(FlowReversal, self).__init__()

    def forward(self, img, flo):
        """
        -img: image (N, C, H, W)
        -flo: optical flow (N, 2, H, W)
        elements of flo is in [0, H] and [0, W] for dx, dy

        Returns (imgw, o): the splatted image and the per-pixel sum of
        scattered weights (zero where no source pixel lands).
        """

        N, C, _, _ = img.size()

        # translate start-point optical flow to end-point optical flow
        # Fix: normalized the malformed slice `flo[:, 0:1:, :]` (same result,
        # but the stray colon hid the intent).
        y = flo[:, 0:1, :, :]
        x = flo[:, 1:2, :, :]

        x = x.repeat(1, C, 1, 1)
        y = y.repeat(1, C, 1, 1)

        # Integer corners bracketing each (x, y) target.
        x1 = torch.floor(x)
        x2 = x1 + 1
        y1 = torch.floor(y)
        y2 = y1 + 1

        # firstly, get gaussian weights
        w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)

        # secondly, sample each weighted corner
        img11, o11 = self.sample_one(img, x1, y1, w11)
        img12, o12 = self.sample_one(img, x1, y2, w12)
        img21, o21 = self.sample_one(img, x2, y1, w21)
        img22, o22 = self.sample_one(img, x2, y2, w22)

        imgw = img11 + img12 + img21 + img22
        o = o11 + o12 + o21 + o22

        return imgw, o

    def get_gaussian_weights(self, x, y, x1, x2, y1, y2):
        # exp(-d^2) of the distance from the real target to each corner.
        w11 = torch.exp(-((x - x1)**2 + (y - y1)**2))
        w12 = torch.exp(-((x - x1)**2 + (y - y2)**2))
        w21 = torch.exp(-((x - x2)**2 + (y - y1)**2))
        w22 = torch.exp(-((x - x2)**2 + (y - y2)**2))

        return w11, w12, w21, w22

    def sample_one(self, img, shiftx, shifty, weight):
        """
        Input:
            -img (N, C, H, W)
            -shiftx, shifty (N, c, H, W)
        """

        N, C, H, W = img.size()
        # Fix: derive the device from the input instead of hard-coded
        # .cuda() calls, so the module also runs on CPU / non-default GPUs.
        device = img.device

        # flatten all (all restored as Tensors)
        flat_shiftx = shiftx.view(-1)
        flat_shifty = shifty.view(-1)
        flat_basex = torch.arange(
            0, H, requires_grad=False).view(-1, 1)[None, None].to(
                device).long().repeat(N, C, 1, W).view(-1)
        flat_basey = torch.arange(
            0, W, requires_grad=False).view(1, -1)[None, None].to(
                device).long().repeat(N, C, H, 1).view(-1)
        flat_weight = weight.view(-1)
        flat_img = img.view(-1)

        # Per-element batch and channel indices for flat addressing.
        idxn = torch.arange(
            0, N, requires_grad=False).view(N, 1, 1, 1).long().to(
                device).repeat(1, C, H, W).view(-1)
        idxc = torch.arange(
            0, C, requires_grad=False).view(1, C, 1, 1).long().to(
                device).repeat(N, 1, H, W).view(-1)
        idxx = flat_shiftx.long() + flat_basex
        idxy = flat_shifty.long() + flat_basey

        # Drop targets that fall outside the image.
        mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)

        ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
        ids_mask = torch.masked_select(ids, mask).clone()

        # Scatter-accumulate weighted values and the weights themselves.
        img_warp = torch.zeros([
            N * C * H * W,
        ], device=device)
        img_warp.put_(
            ids_mask,
            torch.masked_select(flat_img * flat_weight, mask),
            accumulate=True)

        one_warp = torch.zeros([
            N * C * H * W,
        ], device=device)
        one_warp.put_(
            ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)

        return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)


class AcFusionLayer(nn.Module):
    """Quadratic flow prediction from two flow pairs (QVI acceleration fusion)."""

    def __init__(self, ):
        super(AcFusionLayer, self).__init__()

    def forward(self, flo10, flo12, flo21, flo23, t=0.5):
        # Quadratic motion model: combine forward and backward flows with
        # time-dependent weights to predict flows towards time t.
        return 0.5 * ((t + t**2) * flo12 - (t - t**2) * flo10), \
            0.5 * (((1 - t) + (1 - t)**2) * flo21 - ((1 - t) - (1 - t)**2) * flo23)
        # return 0.375 * flo12 - 0.125 * flo10, 0.375 * flo21 - 0.125 * flo23
class Get_gradient(nn.Module):
    """Per-channel gradient magnitude of a 3-channel image (fixed kernels)."""

    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
        kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # Fixed (non-trainable) filters, registered so they follow .to(device).
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def _magnitude(self, channel):
        # Gradient magnitude of one (N, 1, H, W) channel; the 1e-6 epsilon
        # keeps sqrt differentiable at zero.
        gv = F.conv2d(channel, self.weight_v, padding=1)
        gh = F.conv2d(channel, self.weight_h, padding=1)
        return torch.sqrt(torch.pow(gv, 2) + torch.pow(gh, 2) + 1e-6)

    def forward(self, x):
        # Process R, G, B independently and re-stack.
        mags = [self._magnitude(x[:, i].unsqueeze(1)) for i in range(3)]
        return torch.cat(mags, dim=1)


class LowPassFilter(nn.Module):
    """7x7 box (mean) filter applied to each of the two flow channels."""

    def __init__(self):
        super(LowPassFilter, self).__init__()
        # Uniform averaging kernel; fixed, not trained.
        kernel_lpf = torch.FloatTensor([[1.0] * 7 for _ in range(7)])
        kernel_lpf = kernel_lpf.unsqueeze(0).unsqueeze(0) / 49
        self.weight_lpf = nn.Parameter(data=kernel_lpf, requires_grad=False)

    def forward(self, x):
        smoothed = [
            F.conv2d(x[:, i].unsqueeze(1), self.weight_lpf, padding=3)
            for i in range(2)
        ]
        return torch.cat(smoothed, dim=1)
def backwarp(img, flow):
    """Backward-warp ``img`` by ``flow`` using bilinear grid sampling.

    Args:
        img: (N, C, H, W) source tensor.
        flow: (N, 2, H, W) flow in pixels; channel 0 horizontal (u),
            channel 1 vertical (v).

    Returns:
        (N, C, H, W) tensor sampled at the flow-displaced positions.
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))

    # Fix: build the base grid on the input's device instead of hard-coded
    # .cuda() calls, so this also works on CPU / non-default GPUs.
    gridX = torch.tensor(gridX, requires_grad=False, device=img.device)
    gridY = torch.tensor(gridY, requires_grad=False, device=img.device)
    x = gridX.unsqueeze(0).expand_as(u).float() + u
    y = gridY.unsqueeze(0).expand_as(v).float() + v

    # Normalize to grid_sample's [-1, 1] coordinate space.
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)

    grid = torch.stack((x, y), dim=3)

    imgOut = torch.nn.functional.grid_sample(
        img, grid, padding_mode='border', align_corners=True)

    return imgOut


class SmallMaskNet(nn.Module):
    """A three-layer network for predicting mask"""

    def __init__(self, input, output):
        super(SmallMaskNet, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3 = nn.Conv2d(16, output, 3, padding=1)

    def forward(self, x):
        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        x = F.leaky_relu(self.conv2(x), negative_slope=0.1)
        # No final activation: callers apply sigmoid themselves.
        x = self.conv3(x)
        return x
class StaticMaskNet(nn.Module):
    """Predicts a [0, 1] static-region mask from warped/static features.

    Layer stack is identical to the original append-built Sequential
    (input->32->32->16->16->output, 3x3 convs, LeakyReLU(0.1), Sigmoid),
    so checkpoint keys (body.0 ... body.8) are unchanged.
    """

    def __init__(self, input, output):
        super(StaticMaskNet, self).__init__()
        self.body = nn.Sequential(
            nn.Conv2d(
                in_channels=input,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            nn.Conv2d(
                in_channels=32,
                out_channels=32,
                kernel_size=3,
                stride=1,
                padding=1),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            nn.Conv2d(
                in_channels=32,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            nn.Conv2d(
                in_channels=16,
                out_channels=16,
                kernel_size=3,
                stride=1,
                padding=1),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            nn.Conv2d(
                in_channels=16,
                out_channels=output,
                kernel_size=3,
                stride=1,
                padding=1),
            nn.Sigmoid())

    def forward(self, x):
        return self.body(x)


def tensor_erode(bin_img, ksize=5):
    """Sliding-window max filter over a ksize x ksize neighborhood.

    NOTE(review): despite the name, taking the window *max* is a morphological
    dilation (callers even bind the result to ``M_static_dilate``); erosion
    would use min. Behavior is kept as-is because callers depend on it.
    """
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    # Zero-pad so the output keeps the input's spatial size.
    padded = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)

    windows = padded.unfold(dimension=2, size=ksize, step=1)
    windows = windows.unfold(dimension=3, size=ksize, step=1)

    out, _ = windows.reshape(B, C, H, W, -1).max(dim=-1)
    return out
F2t_Ds2) + Ft2_Ds2 = -Ft2_Ds2 + + Ft1_Ds2[norm1_Ds2 > 0] \ + = Ft1_Ds2[norm1_Ds2 > 0] / norm1_Ds2[norm1_Ds2 > 0].clone() + Ft2_Ds2[norm2_Ds2 > 0] \ + = Ft2_Ds2[norm2_Ds2 > 0] / norm2_Ds2[norm2_Ds2 > 0].clone() + if 1: + Ft1_Ds2_fill = -F1t_Ds2 + Ft2_Ds2_fill = -F2t_Ds2 + Ft1_Ds2 = self.fill_flow_hole(Ft1_Ds2, norm1_Ds2, Ft1_Ds2_fill) + Ft2_Ds2 = self.fill_flow_hole(Ft2_Ds2, norm2_Ds2, Ft2_Ds2_fill) + + Ft1_Ds = F.interpolate( + Ft1_Ds2, size=[F1t_Ds.size(2), F1t_Ds.size(3)], mode='nearest') * 3 + Ft2_Ds = F.interpolate( + Ft2_Ds2, size=[F2t_Ds.size(2), F2t_Ds.size(3)], mode='nearest') * 3 + + I1t_Ds = backwarp(I1_Ds, Ft1_Ds) + I2t_Ds = backwarp(I2_Ds, Ft2_Ds) + + output_Ds, feature_Ds = self.refinenet( + torch.cat( + [I1_Ds, I2_Ds, I1t_Ds, I2t_Ds, F12_Ds, F21_Ds, Ft1_Ds, Ft2_Ds], + dim=1)) + + # Adaptive filtering + Ft1r_Ds = backwarp( + Ft1_Ds, 10 * torch.tanh(output_Ds[:, 4:6])) + output_Ds[:, :2] + Ft2r_Ds = backwarp( + Ft2_Ds, 10 * torch.tanh(output_Ds[:, 6:8])) + output_Ds[:, 2:4] + + # warping and fusing + I1tf_Ds = backwarp(I1_Ds, Ft1r_Ds) + I2tf_Ds = backwarp(I2_Ds, Ft2r_Ds) + + G1_Ds = self.get_grad(I1_Ds) + G2_Ds = self.get_grad(I2_Ds) + G1tf_Ds = backwarp(G1_Ds, Ft1r_Ds) + G2tf_Ds = backwarp(G2_Ds, Ft2r_Ds) + + M_Ds = torch.sigmoid( + self.masknet(torch.cat([I1tf_Ds, I2tf_Ds, feature_Ds], + dim=1))).repeat(1, 3, 1, 1) + + Ft1r = F.interpolate( + Ft1r_Ds * 2, scale_factor=2, mode='bilinear', align_corners=False) + Ft2r = F.interpolate( + Ft2r_Ds * 2, scale_factor=2, mode='bilinear', align_corners=False) + + I1tf = backwarp(I1, Ft1r) + I2tf = backwarp(I2, Ft2r) + + M = F.interpolate( + M_Ds, scale_factor=2, mode='bilinear', align_corners=False) + + # fuse + It_warp = ((1 - t) * M * I1tf + t * (1 - M) * I2tf) \ + / ((1 - t) * M + t * (1 - M)).clone() + + # static blending + It_static = (1 - t) * I1 + t * I2 + tmp = torch.cat((I1tf_Ds, I2tf_Ds, G1tf_Ds, G2tf_Ds, I1_Ds, I2_Ds, + G1_Ds, G2_Ds, feature_Ds), + dim=1) + M_static_Ds = self.staticnet(tmp) + 
M_static_dilate = tensor_erode(M_static_Ds) + M_static_dilate = tensor_erode(M_static_dilate) + M_static = F.interpolate( + M_static_dilate, + scale_factor=2, + mode='bilinear', + align_corners=False) + + It_warp = (1 - M_static) * It_warp + M_static * It_static + + if self.is_training: + return It_warp, Ft1r, Ft2r + else: + if self.debug_en: + return It_warp, M, M_static, I1tf, I2tf, Ft1r, Ft2r + else: + return It_warp + + +class QVI_inter(nn.Module): + """Given flow, implement Quadratic Video Interpolation""" + + def __init__(self, debug_en=False, is_training=False): + super(QVI_inter, self).__init__() + self.acc = AcFusionLayer() + self.fwarp = FlowReversal() + self.refinenet = Small_UNet_Ds(20, 8) + self.masknet = SmallMaskNet(38, 1) + + self.staticnet = StaticMaskNet(56, 1) + self.lpfilter = LowPassFilter() + + self.get_grad = Get_gradient() + self.debug_en = debug_en + self.is_training = is_training + + def fill_flow_hole(self, ft, norm, ft_fill): + (N, C, H, W) = ft.shape + ft[norm == 0] = ft_fill[norm == 0] + + ft_1 = self.lpfilter(ft.clone()) + ft_ds = torch.nn.functional.interpolate( + input=ft_1, + size=(H // 4, W // 4), + mode='bilinear', + align_corners=False) + ft_up = torch.nn.functional.interpolate( + input=ft_ds, size=(H, W), mode='bilinear', align_corners=False) + + ft[norm == 0] = ft_up[norm == 0] + + return ft + + def forward(self, F10, F12, F21, F23, I1, I2, t): + if F12 is None or F21 is None: + return I1 + + if F10 is not None and F23 is not None: + F1t, F2t = self.acc(F10, F12, F21, F23, t) + + else: + F1t = t * F12 + F2t = (1 - t) * F21 + + # Flow Reversal + F1t_Ds = F.interpolate(F1t, scale_factor=1.0 / 3, mode='nearest') / 3 + F2t_Ds = F.interpolate(F2t, scale_factor=1.0 / 3, mode='nearest') / 3 + Ft1_Ds, norm1_Ds = self.fwarp(F1t_Ds, F1t_Ds) + Ft1_Ds = -Ft1_Ds + Ft2_Ds, norm2_Ds = self.fwarp(F2t_Ds, F2t_Ds) + Ft2_Ds = -Ft2_Ds + + Ft1_Ds[norm1_Ds > 0] \ + = Ft1_Ds[norm1_Ds > 0] / norm1_Ds[norm1_Ds > 0].clone() + Ft2_Ds[norm2_Ds > 0] \ + = 
Ft2_Ds[norm2_Ds > 0] / norm2_Ds[norm2_Ds > 0].clone() + if 1: + Ft1_fill = -F1t_Ds + Ft2_fill = -F2t_Ds + Ft1_Ds = self.fill_flow_hole(Ft1_Ds, norm1_Ds, Ft1_fill) + Ft2_Ds = self.fill_flow_hole(Ft2_Ds, norm2_Ds, Ft2_fill) + + Ft1 = F.interpolate( + Ft1_Ds, size=[F1t.size(2), F1t.size(3)], mode='nearest') * 3 + Ft2 = F.interpolate( + Ft2_Ds, size=[F2t.size(2), F2t.size(3)], mode='nearest') * 3 + + I1t = backwarp(I1, Ft1) + I2t = backwarp(I2, Ft2) + + output, feature = self.refinenet( + torch.cat([I1, I2, I1t, I2t, F12, F21, Ft1, Ft2], dim=1)) + + # Adaptive filtering + Ft1r = backwarp(Ft1, 10 * torch.tanh(output[:, 4:6])) + output[:, :2] + Ft2r = backwarp(Ft2, 10 * torch.tanh(output[:, 6:8])) + output[:, 2:4] + + # warping and fusing + I1tf = backwarp(I1, Ft1r) + I2tf = backwarp(I2, Ft2r) + + M = torch.sigmoid( + self.masknet(torch.cat([I1tf, I2tf, feature], + dim=1))).repeat(1, 3, 1, 1) + + It_warp = ((1 - t) * M * I1tf + t * (1 - M) * I2tf) \ + / ((1 - t) * M + t * (1 - M)).clone() + + G1 = self.get_grad(I1) + G2 = self.get_grad(I2) + G1tf = backwarp(G1, Ft1r) + G2tf = backwarp(G2, Ft2r) + + # static blending + It_static = (1 - t) * I1 + t * I2 + M_static = self.staticnet( + torch.cat([I1tf, I2tf, G1tf, G2tf, I1, I2, G1, G2, feature], + dim=1)) + M_static_dilate = tensor_erode(M_static) + M_static_dilate = tensor_erode(M_static_dilate) + It_warp = (1 - M_static_dilate) * It_warp + M_static_dilate * It_static + + if self.is_training: + return It_warp, Ft1r, Ft2r + else: + if self.debug_en: + return It_warp, M, M_static, I1tf, I2tf, Ft1r, Ft2r + else: + return It_warp + + +class InterpNetDs(nn.Module): + + def __init__(self, debug_en=False, is_training=False): + super(InterpNetDs, self).__init__() + self.ifnet = IFNet() + self.internet = QVI_inter_Ds( + debug_en=debug_en, is_training=is_training) + + def forward(self, + img1, + img2, + F10_up, + F12_up, + F21_up, + F23_up, + UHD=2, + timestep=0.5): + F12, F21 = self.ifnet(img1, img2, F12_up, F21_up, UHD) + It_warp = 
self.internet(F10_up, F12, F21, F23_up, img1, img2, timestep) + + return It_warp + + +class InterpNet(nn.Module): + + def __init__(self, debug_en=False, is_training=False): + super(InterpNet, self).__init__() + self.ifnet = IFNet() + self.internet = QVI_inter(debug_en=debug_en, is_training=is_training) + + def forward(self, + img1, + img2, + F10_up, + F12_up, + F21_up, + F23_up, + UHD=2, + timestep=0.5): + F12, F21 = self.ifnet(img1, img2, F12_up, F21_up, UHD) + It_warp = self.internet(F10_up, F12, F21, F23_up, img1, img2, timestep) + + return It_warp diff --git a/modelscope/models/cv/video_frame_interpolation/interp_model/transformer_layers.py b/modelscope/models/cv/video_frame_interpolation/interp_model/transformer_layers.py new file mode 100644 index 00000000..81ce114b --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/interp_model/transformer_layers.py @@ -0,0 +1,989 @@ +# The implementation is adopted from VFIformer, +# made publicly available at https://github.com/dvlab-research/VFIformer + +# ----------------------------------------------------------------------------------- +# modified from: +# SwinIR: Image Restoration Using Swin Transformer, https://github.com/JingyunLiang/SwinIR +# ----------------------------------------------------------------------------------- + +import functools +import math +import sys + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + + +class Mlp(nn.Module): + + def __init__(self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = 
class Mlp(nn.Module):
    """Transformer feed-forward block: Linear -> act -> drop -> Linear -> drop."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # fc1 -> act -> drop -> fc2 -> drop, applied in sequence
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x


def window_partition(x, window_size):
    """Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C) tensor; H and W must be divisible by window_size.
        window_size (int): window side length.

    Returns:
        (num_windows*B, window_size, window_size, C) tensor.
    """
    B, H, W, C = x.shape
    n_h, n_w = H // window_size, W // window_size
    tiles = x.reshape(B, n_h, window_size, n_w, window_size, C)
    tiles = tiles.transpose(2, 3)  # B, n_h, n_w, ws, ws, C
    return tiles.reshape(-1, window_size, window_size, C)


def window_reverse(windows, window_size, H, W):
    """Inverse of window_partition: reassemble windows into a (B, H, W, C) map.

    Args:
        windows: (num_windows*B, window_size, window_size, C) tensor.
        window_size (int): window side length.
        H (int): height of the original feature map.
        W (int): width of the original feature map.

    Returns:
        (B, H, W, C) tensor.
    """
    n_h, n_w = H // window_size, W // window_size
    B = windows.shape[0] // (n_h * n_w)
    tiles = windows.reshape(B, n_h, n_w, window_size, window_size, -1)
    tiles = tiles.transpose(2, 3)  # B, n_h, ws, n_w, ws, C
    return tiles.reshape(B, H, W, -1)


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """
Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy 
(cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + + relative_position_bias = self.relative_position_bias_table[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class WindowCrossAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. 
Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, + dim, + window_size, + num_heads, + qkv_bias=True, + qk_scale=None, + attn_drop=0., + proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim**-0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table_x = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + self.relative_position_bias_table_y = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), + num_heads)) # 2*Wh-1 * 2*Ww-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = coords_flatten[:, :, + None] - coords_flatten[:, + None, :] # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, + 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer('relative_position_index', + relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.merge1 = nn.Linear(dim * 2, dim) + self.merge2 = nn.Linear(dim, dim) + self.act = 
nn.LeakyReLU(negative_slope=0.2, inplace=True) + self.proj = nn.Linear(dim, dim) + + self.proj_drop = nn.Dropout(proj_drop) + + trunc_normal_(self.relative_position_bias_table_x, std=.02) + trunc_normal_(self.relative_position_bias_table_y, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, y, mask_x=None, mask_y=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[ + 2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1).contiguous()) + + relative_position_bias = self.relative_position_bias_table_x[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask_x is not None: + nW = mask_x.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask_x.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + + B_, N, C = y.shape + kv = self.kv(y).reshape(B_, N, 2, self.num_heads, + C // self.num_heads).permute(2, 0, 3, 1, 4) + k, v = kv[0], kv[ + 1] # make torchscript happy (cannot use tensor as tuple) + + attn = (q @ k.transpose(-2, -1).contiguous()) + + relative_position_bias = self.relative_position_bias_table_y[ + self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1) # Wh*Ww,Wh*Ww,nH + 
relative_position_bias = relative_position_bias.permute( + 2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + attn = attn + relative_position_bias.unsqueeze(0) + + if mask_y is not None: + nW = mask_y.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, + N) + mask_y.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + y = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous() + + x = self.merge2(self.act(self.merge1(torch.cat([x, y], dim=-1)))) + x + + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class TFL(nn.Module): + + def __init__(self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + use_crossattn=False): + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + self.use_crossattn = use_crossattn + + if min(self.input_resolution) <= self.window_size: + # if window size is larger than input resolution, we don't partition windows + self.shift_size = 0 + self.window_size = min(self.input_resolution) + assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size' + + self.norm1 = 
norm_layer(dim) + if not use_crossattn: + self.attn = WindowAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + else: + self.attn = WindowCrossAttention( + dim, + window_size=to_2tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop) + + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop) + + if self.shift_size > 0: + if not use_crossattn: + attn_mask = self.calculate_mask(self.input_resolution) + self.register_buffer('attn_mask', attn_mask) + else: + attn_mask_x = self.calculate_mask(self.input_resolution) + attn_mask_y = self.calculate_mask2(self.input_resolution) + self.register_buffer('attn_mask_x', attn_mask_x) + self.register_buffer('attn_mask_y', attn_mask_y) + + else: + if not use_crossattn: + attn_mask = None + self.register_buffer('attn_mask', attn_mask) + else: + attn_mask_x = None + attn_mask_y = None + self.register_buffer('attn_mask_x', attn_mask_x) + self.register_buffer('attn_mask_y', attn_mask_y) + + def calculate_mask(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + attn_mask = 
mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + return attn_mask + + def calculate_mask2(self, x_size): + # calculate attention mask for SW-MSA + H, W = x_size + img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + h_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + w_slices = (slice(0, -self.window_size), + slice(-self.window_size, + -self.shift_size), slice(-self.shift_size, None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, + self.window_size * self.window_size) + + # downscale + img_mask_down = F.interpolate( + img_mask.permute(0, 3, 1, 2).contiguous(), + scale_factor=0.5, + mode='bilinear', + align_corners=False) + img_mask_down = F.pad( + img_mask_down, (self.window_size // 4, self.window_size // 4, + self.window_size // 4, self.window_size // 4), + mode='reflect') + mask_windows_down = F.unfold( + img_mask_down, + kernel_size=self.window_size, + dilation=1, + padding=0, + stride=self.window_size // 2) + mask_windows_down = mask_windows_down.view( + self.window_size * self.window_size, + -1).permute(1, 0).contiguous() # nW, window_size*window_size + + attn_mask = mask_windows_down.unsqueeze(1) - mask_windows.unsqueeze( + 2) # nW, window_size*window_size, window_size*window_size + attn_mask = attn_mask.masked_fill(attn_mask != 0, + float(-100.0)).masked_fill( + attn_mask == 0, float(0.0)) + + return attn_mask + + def forward(self, x, x_size): + H, W = x_size + B, L, C = x.shape + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, C) + + # cyclic shift + if self.shift_size > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) + else: + shifted_x = 
x + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + if not self.use_crossattn: + if self.input_resolution == x_size: + attn_windows = self.attn( + x_windows, + mask=self.attn_mask) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn( + x_windows, mask=self.calculate_mask(x_size).to(x.device)) + else: + shifted_x_down = F.interpolate( + shifted_x.permute(0, 3, 1, 2).contiguous(), + scale_factor=0.5, + mode='bilinear', + align_corners=False) + shifted_x_down = F.pad( + shifted_x_down, (self.window_size // 4, self.window_size // 4, + self.window_size // 4, self.window_size // 4), + mode='reflect') + x_windows_down = F.unfold( + shifted_x_down, + kernel_size=self.window_size, + dilation=1, + padding=0, + stride=self.window_size // 2) + x_windows_down = x_windows_down.view( + B, C, self.window_size * self.window_size, -1) + x_windows_down = x_windows_down.permute( + 0, 3, 2, + 1).contiguous().view(-1, self.window_size * self.window_size, + C) # nW*B, window_size*window_size, C + + if self.input_resolution == x_size: + attn_windows = self.attn( + x_windows, + x_windows_down, + mask_x=self.attn_mask_x, + mask_y=self.attn_mask_y + ) # nW*B, window_size*window_size, C + else: + attn_windows = self.attn( + x_windows, + x_windows_down, + mask_x=self.calculate_mask(x_size).to(x.device), + mask_y=self.calculate_mask2(x_size).to(x.device)) + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size, + self.window_size, C) + shifted_x = window_reverse(attn_windows, self.window_size, H, + W) # B H' W' C + + # reverse cyclic shift + if self.shift_size > 0: + x = torch.roll( + shifted_x, + shifts=(self.shift_size, self.shift_size), + dims=(1, 2)) + else: + x = shifted_x + + x = x.view(B, H * W, C) + # FFN + x = shortcut + self.drop_path(x) + x = x + 
self.drop_path(self.mlp(self.norm2(x))) + + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \ + f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}' + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.input_resolution = input_resolution + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(4 * dim) + + def forward(self, x): + """ + x: B, H*W, C + """ + H, W = self.input_resolution + B, L, C = x.shape + assert L == H * W, 'input feature has wrong size' + assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.' 
+ + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.norm(x) + x = self.reduction(x) + + return x + + def extra_repr(self) -> str: + return f'input_resolution={self.input_resolution}, dim={self.dim}' + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.dim + flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + return flops + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. 
+ """ + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + use_crossattn=None): + + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.depth = depth + self.use_checkpoint = use_checkpoint + + if use_crossattn is None: + use_crossattn = [False for i in range(depth)] + + # build blocks + self.blocks = nn.ModuleList([ + TFL(dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=window_size, + shift_size=0 if (i % 2 == 0) else window_size // 2, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] + if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + use_crossattn=use_crossattn[i]) for i in range(depth) + ]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample( + input_resolution, dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, x_size): + for blk in self.blocks: + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, x_size) + else: + x = blk(x, x_size) + if self.downsample is not None: + x = self.downsample(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}' + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + +class RTFL(nn.Module): + + def __init__(self, + dim, + input_resolution, + depth, + num_heads, + window_size, + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + img_size=224, + patch_size=4, + resi_connection='1conv', + 
use_crossattn=None): + super(RTFL, self).__init__() + + self.dim = dim + self.input_resolution = input_resolution + self.use_crossattn = use_crossattn + + self.residual_group = BasicLayer( + dim=dim, + input_resolution=input_resolution, + depth=depth, + num_heads=num_heads, + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path, + norm_layer=norm_layer, + downsample=downsample, + use_checkpoint=use_checkpoint, + use_crossattn=use_crossattn) + + if resi_connection == '1conv': + self.conv = nn.Conv2d(dim, dim, 3, 1, 1) + elif resi_connection == '3conv': + # to save parameters and memory + self.conv = nn.Sequential( + nn.Conv2d(dim, dim // 4, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(dim // 4, dim, 3, 1, 1)) + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=0, + embed_dim=dim, + norm_layer=None) + + self.patch_unembed = PatchUnEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=0, + embed_dim=dim, + norm_layer=None) + + def forward(self, x, x_size): + return self.patch_embed( + self.conv( + self.patch_unembed(self.residual_group(x, x_size), + x_size))) + x + + def flops(self): + flops = 0 + flops += self.residual_group.flops() + H, W = self.input_resolution + flops += H * W * self.dim * self.dim * 9 + flops += self.patch_embed.flops() + flops += self.patch_unembed.flops() + + return flops + + +class PatchEmbed(nn.Module): + r""" Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + x = x.flatten(2).transpose(1, 2).contiguous() # B Ph*Pw C + if self.norm is not None: + x = self.norm(x) + return x + + def flops(self): + flops = 0 + H, W = self.img_size + if self.norm is not None: + flops += H * W * self.embed_dim + return flops + + +class PatchUnEmbed(nn.Module): + r""" Image to Patch Unembedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. 
Default: None + """ + + def __init__(self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + norm_layer=None): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], img_size[1] // patch_size[1] + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + def forward(self, x, x_size): + B, HW, C = x.shape + x = x.transpose(1, 2).contiguous().view(B, self.embed_dim, x_size[0], + x_size[1]) # B Ph*Pw C + return x + + def flops(self): + flops = 0 + return flops + + +class Upsample(nn.Sequential): + """Upsample module. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. + """ + + def __init__(self, scale, num_feat): + m = [] + if (scale & (scale - 1)) == 0: # scale = 2^n + for _ in range(int(math.log(scale, 2))): + m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(2)) + elif scale == 3: + m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1)) + m.append(nn.PixelShuffle(3)) + else: + raise ValueError(f'scale {scale} is not supported. ' + 'Supported scales: 2^n and 3.') + super(Upsample, self).__init__(*m) + + +class UpsampleOneStep(nn.Sequential): + """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle) + Used in lightweight SR to save parameters. + + Args: + scale (int): Scale factor. Supported scales: 2^n and 3. + num_feat (int): Channel number of intermediate features. 
+ + """ + + def __init__(self, scale, num_feat, num_out_ch, input_resolution=None): + self.num_feat = num_feat + self.input_resolution = input_resolution + m = [] + m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)) + m.append(nn.PixelShuffle(scale)) + super(UpsampleOneStep, self).__init__(*m) + + def flops(self): + H, W = self.input_resolution + flops = H * W * self.num_feat * 3 * 9 + return flops diff --git a/modelscope/models/cv/video_frame_interpolation/utils/__init__.py b/modelscope/models/cv/video_frame_interpolation/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/cv/video_frame_interpolation/utils/scene_change_detection.py b/modelscope/models/cv/video_frame_interpolation/utils/scene_change_detection.py new file mode 100644 index 00000000..4cbe60a7 --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/utils/scene_change_detection.py @@ -0,0 +1,97 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def calc_hist(img_tensor): + hist = torch.histc(img_tensor, bins=64, min=0, max=255) + return hist / img_tensor.numel() + + +def do_scene_detect(F01_tensor, F10_tensor, img0_tensor, img1_tensor): + device = img0_tensor.device + scene_change = False + img0_tensor = img0_tensor.clone() + img1_tensor = img1_tensor.clone() + + img0_gray = 0.299 * img0_tensor[:, 0: + 1] + 0.587 * img0_tensor[:, 1: + 2] + 0.114 * img0_tensor[:, + 2: + 3] + img1_gray = 0.299 * img1_tensor[:, 0: + 1] + 0.587 * img1_tensor[:, 1: + 2] + 0.114 * img1_tensor[:, + 2: + 3] + img0_gray = torch.clamp(img0_gray, 0, 255).byte().float().cpu() + img1_gray = torch.clamp(img1_gray, 0, 255).byte().float().cpu() + + hist0 = calc_hist(img0_gray) + hist1 = calc_hist(img1_gray) + diff = torch.abs(hist0 - hist1) + diff[diff < 0.01] = 0 + if torch.sum(diff) > 0.8 or diff.max() > 0.4: + return True + img0_gray = img0_gray.to(device) + img1_gray 
= img1_gray.to(device) + + # second stage: detect mv and pix mismatch + + (n, c, h, w) = F01_tensor.size() + scale_x = w / 1920 + scale_y = h / 1080 + + # compare mv + (y, x) = torch.meshgrid(torch.arange(h), torch.arange(w)) + (y_grid, x_grid) = torch.meshgrid( + torch.arange(64, h - 64, 8), torch.arange(64, w - 64, 8)) + x = x.to(device) + y = y.to(device) + y_grid = y_grid.to(device) + x_grid = x_grid.to(device) + fx = F01_tensor[0, 0] + fy = F01_tensor[0, 1] + x_ = x.float() + fx + y_ = y.float() + fy + x_ = torch.clamp(x_ + 0.5, 0, w - 1).long() + y_ = torch.clamp(y_ + 0.5, 0, h - 1).long() + + grid_fx = fx[y_grid, x_grid] + grid_fy = fy[y_grid, x_grid] + + x_grid_ = x_[y_grid, x_grid] + y_grid_ = y_[y_grid, x_grid] + + grid_fx_ = F10_tensor[0, 0, y_grid_, x_grid_] + grid_fy_ = F10_tensor[0, 1, y_grid_, x_grid_] + + sum_x = grid_fx + grid_fx_ + sum_y = grid_fy + grid_fy_ + distance = torch.sqrt(sum_x**2 + sum_y**2) + + fx_len = torch.abs(grid_fx) * scale_x + fy_len = torch.abs(grid_fy) * scale_y + ori_len = torch.where(fx_len > fy_len, fx_len, fy_len) + + thres = torch.clamp(0.1 * ori_len + 4, 5, 14) + + # compare pix diff + ori_img = img0_gray + ref_img = img1_gray[:, :, y_, x_] + + img_diff = ori_img.float() - ref_img.float() + img_diff = torch.abs(img_diff) + + kernel = np.ones([8, 8], float) / 64 + kernel = torch.FloatTensor(kernel).to(device).unsqueeze(0).unsqueeze(0) + diff = F.conv2d(img_diff, kernel, padding=4) + + diff = diff[0, 0, y_grid, x_grid] + + index = (distance > thres) * (diff > 5) + if index.sum().float() / distance.numel() > 0.5: + scene_change = True + return scene_change diff --git a/modelscope/models/cv/video_frame_interpolation/utils/utils.py b/modelscope/models/cv/video_frame_interpolation/utils/utils.py new file mode 100644 index 00000000..68a8b99d --- /dev/null +++ b/modelscope/models/cv/video_frame_interpolation/utils/utils.py @@ -0,0 +1,96 @@ +# The implementation is adopted from RAFT, +# made publicly available under the
BSD-3-Clause license at https://github.com/princeton-vl/RAFT + +import numpy as np +import torch +import torch.nn.functional as F +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + + def __init__(self, dims, mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [ + pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, + pad_ht - pad_ht // 2 + ] + else: + self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self, x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata((x1, y1), + dx, (x0, y0), + method='nearest', + fill_value=0) + + flow_y = interpolate.griddata((x1, y1), + dy, (x0, y0), + method='nearest', + fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1, 1], dim=-1) + xgrid = 2 * xgrid / (W - 1) - 1 + ygrid = 2 * ygrid / (H - 1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid 
< 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid( + torch.arange(ht, device=device), torch.arange(wd, device=device)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate( + flow, size=new_size, mode=mode, align_corners=True) diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py b/modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py new file mode 100644 index 00000000..b9a338c1 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/video_frame_interpolation/__init__.py @@ -0,0 +1,23 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .video_frame_interpolation_dataset import VideoFrameInterpolationDataset + +else: + _import_structure = { + 'video_frame_interpolation_dataset': + ['VideoFrameInterpolationDataset'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py b/modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py new file mode 100644 index 00000000..ae876b18 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/video_frame_interpolation/data_utils.py @@ -0,0 +1,41 @@ +# ------------------------------------------------------------------------ +# Modified from BasicSR (https://github.com/xinntao/BasicSR) +# Copyright 2018-2020 BasicSR Authors +# ------------------------------------------------------------------------ + +import cv2 +import torch +import torch.nn.functional as F + + +def img2tensor(imgs, bgr2rgb=True, 
float32=True): + """Numpy array to tensor. + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def img_padding(img_tensor, height, width, pad_num=32): + ph = ((height - 1) // pad_num + 1) * pad_num + pw = ((width - 1) // pad_num + 1) * pad_num + padding = (0, pw - width, 0, ph - height) + img_tensor = F.pad(img_tensor, padding) + return img_tensor diff --git a/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py b/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py new file mode 100644 index 00000000..44b965a7 --- /dev/null +++ b/modelscope/msdatasets/task_datasets/video_frame_interpolation/video_frame_interpolation_dataset.py @@ -0,0 +1,54 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+ +from collections import defaultdict + +import cv2 +import numpy as np +import torch + +from modelscope.metainfo import Models +from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS +from modelscope.msdatasets.task_datasets.torch_base_dataset import \ + TorchTaskDataset +from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import ( + img2tensor, img_padding) +from modelscope.utils.constant import Tasks + + +def default_loader(path): + return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32) + + +@TASK_DATASETS.register_module( + Tasks.video_frame_interpolation, + module_name=Models.video_frame_interpolation) +class VideoFrameInterpolationDataset(TorchTaskDataset): + """Dataset for video frame-interpolation. + """ + + def __init__(self, dataset, opt): + self.dataset = dataset + self.opt = opt + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + + # Load frames. Dimension order: HWC; channel order: BGR; + # image range: [0, 1], float32 + item_dict = self.dataset[index] + img0 = default_loader(item_dict['Input1:FILE']) + img1 = default_loader(item_dict['Input2:FILE']) + img2 = default_loader(item_dict['Input3:FILE']) + img3 = default_loader(item_dict['Input4:FILE']) + gt = default_loader(item_dict['Output:FILE']) + + img0, img1, img2, img3, gt = img2tensor([img0, img1, img2, img3, gt], + bgr2rgb=False, + float32=True) + + imgs = torch.cat((img0, img1, img2, img3), dim=0) + height, width = imgs.size(1), imgs.size(2) + imgs = img_padding(imgs, height, width, pad_num=32) + return {'input': imgs, 'target': gt / 255.0} diff --git a/modelscope/outputs/outputs.py b/modelscope/outputs/outputs.py index 2cd098bb..7a71d789 100644 --- a/modelscope/outputs/outputs.py +++ b/modelscope/outputs/outputs.py @@ -305,6 +305,7 @@ TASK_OUTPUTS = { # video editing task result for a single video # {"output_video": "path_to_rendered_video"} + Tasks.video_frame_interpolation: [OutputKeys.OUTPUT_VIDEO], 
Tasks.video_super_resolution: [OutputKeys.OUTPUT_VIDEO], # live category recognition result for single video diff --git a/modelscope/pipelines/builder.py b/modelscope/pipelines/builder.py index f3629469..05859683 100644 --- a/modelscope/pipelines/builder.py +++ b/modelscope/pipelines/builder.py @@ -226,6 +226,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { 'damo/cv_video-inpainting'), Tasks.video_human_matting: (Pipelines.video_human_matting, 'damo/cv_effnetv2_video-human-matting'), + Tasks.video_frame_interpolation: + (Pipelines.video_frame_interpolation, + 'damo/cv_raft_video-frame-interpolation'), Tasks.human_wholebody_keypoint: (Pipelines.human_wholebody_keypoint, 'damo/cv_hrnetw48_human-wholebody-keypoint_image'), @@ -247,9 +250,9 @@ DEFAULT_MODEL_FOR_PIPELINE = { Tasks.translation_evaluation: (Pipelines.translation_evaluation, 'damo/nlp_unite_mup_translation_evaluation_multilingual_large'), - Tasks.video_object_segmentation: - (Pipelines.video_object_segmentation, - 'damo/cv_rdevos_video-object-segmentation'), + Tasks.video_object_segmentation: ( + Pipelines.video_object_segmentation, + 'damo/cv_rdevos_video-object-segmentation'), Tasks.video_multi_object_tracking: ( Pipelines.video_multi_object_tracking, 'damo/cv_yolov5_video-multi-object-tracking_fairmot'), diff --git a/modelscope/pipelines/cv/__init__.py b/modelscope/pipelines/cv/__init__.py index 0a77d364..3358c961 100644 --- a/modelscope/pipelines/cv/__init__.py +++ b/modelscope/pipelines/cv/__init__.py @@ -68,6 +68,7 @@ if TYPE_CHECKING: from .hand_static_pipeline import HandStaticPipeline from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline + from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline from .image_skychange_pipeline import ImageSkychangePipeline from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline from 
.video_stabilization_pipeline import VideoStabilizationPipeline @@ -164,6 +165,9 @@ else: 'language_guided_video_summarization_pipeline': [ 'LanguageGuidedVideoSummarizationPipeline' ], + 'video_frame_interpolation_pipeline': [ + 'VideoFrameInterpolationPipeline' + ], 'image_skychange_pipeline': ['ImageSkychangePipeline'], 'video_object_segmentation_pipeline': [ 'VideoObjectSegmentationPipeline' diff --git a/modelscope/pipelines/cv/video_frame_interpolation_pipeline.py b/modelscope/pipelines/cv/video_frame_interpolation_pipeline.py new file mode 100644 index 00000000..d47340cb --- /dev/null +++ b/modelscope/pipelines/cv/video_frame_interpolation_pipeline.py @@ -0,0 +1,613 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import glob +import math +import os +import os.path as osp +import subprocess +import tempfile +from typing import Any, Dict, Optional, Union + +import cv2 +import numpy as np +import torch +import torch.nn.functional as F +from torchvision.utils import make_grid + +from modelscope.metainfo import Pipelines +from modelscope.models.cv.video_frame_interpolation.utils.scene_change_detection import \ + do_scene_detect +from modelscope.models.cv.video_frame_interpolation.VFINet_for_video_frame_interpolation import \ + VFINetForVideoFrameInterpolation +from modelscope.outputs import OutputKeys +from modelscope.pipelines.base import Input, Pipeline +from modelscope.pipelines.builder import PIPELINES +from modelscope.preprocessors import LoadImage +from modelscope.preprocessors.cv import VideoReader +from modelscope.utils.config import Config +from modelscope.utils.constant import ModelFile, Tasks +from modelscope.utils.logger import get_logger + +VIDEO_EXTENSIONS = ('.mp4', '.mov') +logger = get_logger() + + +def img_trans(img_tensor): # in format of RGB + img_tensor = img_tensor / 255.0 + mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, + 1).type_as(img_tensor) + img_tensor -= mean + return img_tensor + + +def add_mean(x): + mean = 
torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, 1).type_as(x) + return x + mean + + +def img_padding(img_tensor, height, width, pad_num=32): + ph = ((height - 1) // pad_num + 1) * pad_num + pw = ((width - 1) // pad_num + 1) * pad_num + padding = (0, pw - width, 0, ph - height) + img_tensor = F.pad(img_tensor, padding) + return img_tensor + + +def do_inference_lowers(flow_10, + flow_12, + flow_21, + flow_23, + img1, + img2, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=False): + # given frame1, frame2 and optical flow, predict frame_t + if start_end_flag: + read_count -= 1 + else: + read_count -= 2 + while inter_count <= read_count: + t = inter_count + 1 - read_count + t = round(t, 2) + if (t - 0) < delta / 2: + output = img1 + elif (1 - t) < delta / 2: + output = img2 + else: + output = inter_model(flow_10, flow_12, flow_21, flow_23, img1, + img2, t) + + output = 255 * add_mean(output) + outputs.append(output) + inter_count += delta + + return outputs, inter_count + + +def do_inference_highers(flow_10, + flow_12, + flow_21, + flow_23, + img1, + img2, + img1_up, + img2_up, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=False): + # given frame1, frame2 and optical flow, predict frame_t. 
For videos with a resolution of 2k and above + if start_end_flag: + read_count -= 1 + else: + read_count -= 2 + while inter_count <= read_count: + t = inter_count + 1 - read_count + t = round(t, 2) + if (t - 0) < delta / 2: + output = img1_up + elif (1 - t) < delta / 2: + output = img2_up + else: + output = inter_model(flow_10, flow_12, flow_21, flow_23, img1, + img2, img1_up, img2_up, t) + + output = 255 * add_mean(output) + outputs.append(output) + inter_count += delta + + return outputs, inter_count + + +def inference_lowers(flow_model, refine_model, inter_model, video_len, + read_count, inter_count, delta, scene_change_flag, + img_tensor_list, img_ori_list, inputs, outputs): + # given a video with a resolution less than 2k and output fps, execute the video frame interpolation function. + height, width = inputs[read_count].size(2), inputs[read_count].size(3) + # We use four consecutive frames to do frame interpolation. flow_10 represents + # optical flow from frame0 to frame1. The similar goes for flow_12, flow_21 and + # flow_23. 
+ flow_10 = None + flow_12 = None + flow_21 = None + flow_23 = None + with torch.no_grad(): + while (read_count < video_len): + img = inputs[read_count] + img = img_padding(img, height, width) + img_ori_list.append(img) + img_tensor_list.append(img_trans(img)) + read_count += 1 + if len(img_tensor_list) == 2: + img0 = img_tensor_list[0] + img1 = img_tensor_list[1] + img0_ori = img_ori_list[0] + img1_ori = img_ori_list[1] + _, flow_01_up = flow_model( + img0_ori, img1_ori, iters=12, test_mode=True) + _, flow_10_up = flow_model( + img1_ori, img0_ori, iters=12, test_mode=True) + flow_01, flow_10 = refine_model(img0, img1, flow_01_up, + flow_10_up, 2) + scene_change_flag[0] = do_scene_detect( + flow_01[:, :, 0:height, 0:width], flow_10[:, :, 0:height, + 0:width], + img_ori_list[0][:, :, 0:height, 0:width], + img_ori_list[1][:, :, 0:height, 0:width]) + if scene_change_flag[0]: + outputs, inter_count = do_inference_lowers( + None, + None, + None, + None, + img0, + img1, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + else: + outputs, inter_count = do_inference_lowers( + None, + flow_01, + flow_10, + None, + img0, + img1, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + + if len(img_tensor_list) == 4: + if flow_12 is None or flow_21 is None: + img2 = img_tensor_list[2] + img2_ori = img_ori_list[2] + _, flow_12_up = flow_model( + img1_ori, img2_ori, iters=12, test_mode=True) + _, flow_21_up = flow_model( + img2_ori, img1_ori, iters=12, test_mode=True) + flow_12, flow_21 = refine_model(img1, img2, flow_12_up, + flow_21_up, 2) + scene_change_flag[1] = do_scene_detect( + flow_12[:, :, 0:height, + 0:width], flow_21[:, :, 0:height, 0:width], + img_ori_list[1][:, :, 0:height, 0:width], + img_ori_list[2][:, :, 0:height, 0:width]) + + img3 = img_tensor_list[3] + img3_ori = img_ori_list[3] + _, flow_23_up = flow_model( + img2_ori, img3_ori, iters=12, test_mode=True) + _, flow_32_up = flow_model( + 
img3_ori, img2_ori, iters=12, test_mode=True) + flow_23, flow_32 = refine_model(img2, img3, flow_23_up, + flow_32_up, 2) + scene_change_flag[2] = do_scene_detect( + flow_23[:, :, 0:height, 0:width], flow_32[:, :, 0:height, + 0:width], + img_ori_list[2][:, :, 0:height, 0:width], + img_ori_list[3][:, :, 0:height, 0:width]) + + if scene_change_flag[1]: + outputs, inter_count = do_inference_lowers( + None, None, None, None, img1, img2, inter_model, + read_count, inter_count, delta, outputs) + elif scene_change_flag[0] or scene_change_flag[2]: + outputs, inter_count = do_inference_lowers( + None, flow_12, flow_21, None, img1, img2, inter_model, + read_count, inter_count, delta, outputs) + else: + outputs, inter_count = do_inference_lowers( + flow_10_up, flow_12, flow_21, flow_23_up, img1, img2, + inter_model, read_count, inter_count, delta, outputs) + + img_tensor_list.pop(0) + img_ori_list.pop(0) + + # for next group + img1 = img2 + img2 = img3 + img1_ori = img2_ori + img2_ori = img3_ori + flow_10 = flow_21 + flow_12 = flow_23 + flow_21 = flow_32 + + flow_10_up = flow_21_up + flow_12_up = flow_23_up + flow_21_up = flow_32_up + + # save scene change flag for next group + scene_change_flag[0] = scene_change_flag[1] + scene_change_flag[1] = scene_change_flag[2] + scene_change_flag[2] = False + + if read_count > 0: # the last remaining 3 images + img_ori_list.pop(0) + img_tensor_list.pop(0) + assert (len(img_tensor_list) == 2) + + if scene_change_flag[1]: + outputs, inter_count = do_inference_lowers( + None, + None, + None, + None, + img1, + img2, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + else: + outputs, inter_count = do_inference_lowers( + None, + flow_12, + flow_21, + None, + img1, + img2, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + + return outputs + + +def inference_highers(flow_model, refine_model, inter_model, video_len, + read_count, inter_count, delta, scene_change_flag, + 
img_tensor_list, img_ori_list, inputs, outputs): + # given a video with a resolution of 2k or above and output fps, execute the video frame interpolation function. + if inputs[read_count].size(2) % 2 != 0 or inputs[read_count].size( + 3) % 2 != 0: + raise RuntimeError('Video width and height must be even') + + height, width = inputs[read_count].size(2) // 2, inputs[read_count].size( + 3) // 2 + # We use four consecutive frames to do frame interpolation. flow_10 represents + # optical flow from frame0 to frame1. The similar goes for flow_12, flow_21 and + # flow_23. + flow_10 = None + flow_12 = None + flow_21 = None + flow_23 = None + img_up_list = [] + with torch.no_grad(): + while (read_count < video_len): + img_up = inputs[read_count] + img_up = img_padding(img_up, height * 2, width * 2, pad_num=64) + img = F.interpolate( + img_up, scale_factor=0.5, mode='bilinear', align_corners=False) + + img_up_list.append(img_trans(img_up)) + img_ori_list.append(img) + img_tensor_list.append(img_trans(img)) + read_count += 1 + if len(img_tensor_list) == 2: + img0 = img_tensor_list[0] + img1 = img_tensor_list[1] + img0_ori = img_ori_list[0] + img1_ori = img_ori_list[1] + img0_up = img_up_list[0] + img1_up = img_up_list[1] + _, flow_01_up = flow_model( + img0_ori, img1_ori, iters=12, test_mode=True) + _, flow_10_up = flow_model( + img1_ori, img0_ori, iters=12, test_mode=True) + flow_01, flow_10 = refine_model(img0, img1, flow_01_up, + flow_10_up, 2) + scene_change_flag[0] = do_scene_detect( + flow_01[:, :, 0:height, 0:width], flow_10[:, :, 0:height, + 0:width], + img_ori_list[0][:, :, 0:height, 0:width], + img_ori_list[1][:, :, 0:height, 0:width]) + if scene_change_flag[0]: + outputs, inter_count = do_inference_highers( + None, + None, + None, + None, + img0, + img1, + img0_up, + img1_up, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + else: + outputs, inter_count = do_inference_highers( + None, + flow_01, + flow_10, + None, + img0, + 
img1, + img0_up, + img1_up, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + + if len(img_tensor_list) == 4: + if flow_12 is None or flow_21 is None: + img2 = img_tensor_list[2] + img2_ori = img_ori_list[2] + img2_up = img_up_list[2] + _, flow_12_up = flow_model( + img1_ori, img2_ori, iters=12, test_mode=True) + _, flow_21_up = flow_model( + img2_ori, img1_ori, iters=12, test_mode=True) + flow_12, flow_21 = refine_model(img1, img2, flow_12_up, + flow_21_up, 2) + scene_change_flag[1] = do_scene_detect( + flow_12[:, :, 0:height, + 0:width], flow_21[:, :, 0:height, 0:width], + img_ori_list[1][:, :, 0:height, 0:width], + img_ori_list[2][:, :, 0:height, 0:width]) + + img3 = img_tensor_list[3] + img3_ori = img_ori_list[3] + img3_up = img_up_list[3] + _, flow_23_up = flow_model( + img2_ori, img3_ori, iters=12, test_mode=True) + _, flow_32_up = flow_model( + img3_ori, img2_ori, iters=12, test_mode=True) + flow_23, flow_32 = refine_model(img2, img3, flow_23_up, + flow_32_up, 2) + scene_change_flag[2] = do_scene_detect( + flow_23[:, :, 0:height, 0:width], flow_32[:, :, 0:height, + 0:width], + img_ori_list[2][:, :, 0:height, 0:width], + img_ori_list[3][:, :, 0:height, 0:width]) + + if scene_change_flag[1]: + outputs, inter_count = do_inference_highers( + None, None, None, None, img1, img2, img1_up, img2_up, + inter_model, read_count, inter_count, delta, outputs) + elif scene_change_flag[0] or scene_change_flag[2]: + outputs, inter_count = do_inference_highers( + None, flow_12, flow_21, None, img1, img2, img1_up, + img2_up, inter_model, read_count, inter_count, delta, + outputs) + else: + outputs, inter_count = do_inference_highers( + flow_10_up, flow_12, flow_21, flow_23_up, img1, img2, + img1_up, img2_up, inter_model, read_count, inter_count, + delta, outputs) + + img_up_list.pop(0) + img_tensor_list.pop(0) + img_ori_list.pop(0) + + # for next group + img1 = img2 + img2 = img3 + img1_ori = img2_ori + img2_ori = img3_ori + img1_up = 
img2_up + img2_up = img3_up + flow_10 = flow_21 + flow_12 = flow_23 + flow_21 = flow_32 + + flow_10_up = flow_21_up + flow_12_up = flow_23_up + flow_21_up = flow_32_up + + # save scene change flag for next group + scene_change_flag[0] = scene_change_flag[1] + scene_change_flag[1] = scene_change_flag[2] + scene_change_flag[2] = False + + if read_count > 0: # the last remaining 3 images + img_ori_list.pop(0) + img_tensor_list.pop(0) + assert (len(img_tensor_list) == 2) + + if scene_change_flag[1]: + outputs, inter_count = do_inference_highers( + None, + None, + None, + None, + img1, + img2, + img1_up, + img2_up, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + else: + outputs, inter_count = do_inference_highers( + None, + flow_12, + flow_21, + None, + img1, + img2, + img1_up, + img2_up, + inter_model, + read_count, + inter_count, + delta, + outputs, + start_end_flag=True) + + return outputs + + +def convert(param): + return { + k.replace('module.', ''): v + for k, v in param.items() if 'module.' in k + } + + +__all__ = ['VideoFrameInterpolationPipeline'] + + +@PIPELINES.register_module( + Tasks.video_frame_interpolation, + module_name=Pipelines.video_frame_interpolation) +class VideoFrameInterpolationPipeline(Pipeline): + """ Video Frame Interpolation Pipeline. 
+ Example: + ```python + >>> from modelscope.pipelines import pipeline + >>> from modelscope.utils.constant import Tasks + >>> from modelscope.outputs import OutputKeys + + >>> video = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_frame_interpolation_test.mp4' + >>> video_frame_interpolation_pipeline = pipeline(Tasks.video_frame_interpolation, + 'damo/cv_raft_video-frame-interpolation') + >>> result = video_frame_interpolation_pipeline(video)[OutputKeys.OUTPUT_VIDEO] + >>> print('pipeline: the output video path is {}'.format(result)) + ``` + """ + + def __init__(self, + model: Union[VFINetForVideoFrameInterpolation, str], + preprocessor=None, + **kwargs): + super().__init__(model=model, preprocessor=preprocessor, **kwargs) + if torch.cuda.is_available(): + self._device = torch.device('cuda') + else: + self._device = torch.device('cpu') + self.net = self.model.model + self.net.to(self._device) + self.net.eval() + logger.info('load video frame-interpolation done') + + def preprocess(self, input: Input, out_fps: float = 0) -> Dict[str, Any]: + # read images + file_extension = os.path.splitext(input)[1] + if file_extension in VIDEO_EXTENSIONS: # input is a video file + video_reader = VideoReader(input) + inputs = [] + for frame in video_reader: + inputs.append(frame) + fps = video_reader.fps + elif file_extension == '': # input is a directory + inputs = [] + input_paths = sorted(glob.glob(f'{input}/*')) + for input_path in input_paths: + img = LoadImage(input_path, mode='rgb') + inputs.append(img) + fps = 25 # default fps + else: + raise ValueError('"input" can only be a video or a directory.') + + for i, img in enumerate(inputs): + img = torch.from_numpy(img.copy()).permute(2, 0, 1).float() + inputs[i] = img.unsqueeze(0) + + if out_fps == 0: + out_fps = 2 * fps + return {'video': inputs, 'fps': fps, 'out_fps': out_fps} + + def forward(self, input: Dict[str, Any]) -> Dict[str, Any]: + inputs = input['video'] + fps = input['fps'] + out_fps = 
input['out_fps'] + video_len = len(inputs) + + flow_model = self.net.flownet + refine_model = self.net.internet.ifnet + + read_count = 0 + inter_count = 0 + delta = fps / out_fps + scene_change_flag = [False, False, False] + img_tensor_list = [] + img_ori_list = [] + outputs = [] + height, width = inputs[read_count].size(2), inputs[read_count].size(3) + if height >= 1440 or width >= 2560: + inter_model = self.net.internet_Ds.internet + outputs = inference_highers(flow_model, refine_model, inter_model, + video_len, read_count, inter_count, + delta, scene_change_flag, + img_tensor_list, img_ori_list, inputs, + outputs) + else: + inter_model = self.net.internet.internet + outputs = inference_lowers(flow_model, refine_model, inter_model, + video_len, read_count, inter_count, + delta, scene_change_flag, + img_tensor_list, img_ori_list, inputs, + outputs) + + for i in range(len(outputs)): + outputs[i] = outputs[i][:, :, 0:height, 0:width] + return {'output': outputs, 'fps': out_fps} + + def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]: + output_video_path = kwargs.get('output_video', None) + demo_service = kwargs.get('demo_service', True) + if output_video_path is None: + output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name + h, w = inputs['output'][0].shape[-2:] + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + video_writer = cv2.VideoWriter(output_video_path, fourcc, + inputs['fps'], (w, h)) + for i in range(len(inputs['output'])): + img = inputs['output'][i] + img = img[0].permute(1, 2, 0).byte().cpu().numpy() + video_writer.write(img.astype(np.uint8)) + + video_writer.release() + if demo_service: + assert os.system( + 'ffmpeg -version') == 0, 'ffmpeg is not installed correctly!' 
+ output_video_path_for_web = output_video_path[:-4] + '_web.mp4' + convert_cmd = f'ffmpeg -i {output_video_path} -vcodec h264 -crf 5 {output_video_path_for_web}' + subprocess.call(convert_cmd, shell=True) + return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web} + else: + return {OutputKeys.OUTPUT_VIDEO: output_video_path} diff --git a/modelscope/utils/constant.py b/modelscope/utils/constant.py index 6bb9a142..368f0b3e 100644 --- a/modelscope/utils/constant.py +++ b/modelscope/utils/constant.py @@ -98,6 +98,7 @@ class CVTasks(object): # video editing video_inpainting = 'video-inpainting' + video_frame_interpolation = 'video-frame-interpolation' video_stabilization = 'video-stabilization' video_super_resolution = 'video-super-resolution' diff --git a/tests/pipelines/test_video_frame_interpolation.py b/tests/pipelines/test_video_frame_interpolation.py new file mode 100644 index 00000000..951da2b9 --- /dev/null +++ b/tests/pipelines/test_video_frame_interpolation.py @@ -0,0 +1,52 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import sys +import unittest + +from modelscope.hub.snapshot_download import snapshot_download +from modelscope.models import Model +from modelscope.outputs import OutputKeys +from modelscope.pipelines import pipeline +from modelscope.pipelines.cv import VideoFrameInterpolationPipeline +from modelscope.utils.constant import Tasks +from modelscope.utils.demo_utils import DemoCompatibilityCheck +from modelscope.utils.test_utils import test_level + + +class VideoFrameInterpolationTest(unittest.TestCase, DemoCompatibilityCheck): + + def setUp(self) -> None: + self.task = Tasks.video_frame_interpolation + self.model_id = 'damo/cv_raft_video-frame-interpolation' + self.test_video = 'data/test/videos/video_frame_interpolation_test.mp4' + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_by_direct_model_download(self): + cache_path = snapshot_download(self.model_id) + pipeline = VideoFrameInterpolationPipeline(cache_path) + pipeline.group_key = self.task + out_video_path = pipeline( + input=self.test_video)[OutputKeys.OUTPUT_VIDEO] + print('pipeline: the output video path is {}'.format(out_video_path)) + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_run_with_model_from_modelhub(self): + pipeline_ins = pipeline( + task=Tasks.video_frame_interpolation, model=self.model_id) + out_video_path = pipeline_ins( + input=self.test_video)[OutputKeys.OUTPUT_VIDEO] + print('pipeline: the output video path is {}'.format(out_video_path)) + + @unittest.skipUnless(test_level() >= 2, 'skip test in current test level') + def test_run_with_default_model(self): + pipeline_ins = pipeline(task=Tasks.video_frame_interpolation) + out_video_path = pipeline_ins( + input=self.test_video)[OutputKeys.OUTPUT_VIDEO] + print('pipeline: the output video path is {}'.format(out_video_path)) + + @unittest.skip('demo compatibility test is only enabled on a needed-basis') + def test_demo_compatibility(self): + 
self.compatibility_check() + + +if __name__ == '__main__': + unittest.main()