submit video frame interpolation model

增加视频插帧模型
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11188339
This commit is contained in:
liaojie.laj
2023-01-10 06:57:19 +08:00
committed by yingda.chen
parent 338a5a4994
commit fcf6e6431f
32 changed files with 4196 additions and 6 deletions

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e97ff88d0af12f7dd3ef04ce50b87b51ffbb9a57dce81d2d518df4abd2fdb826
size 3231793

View File

@@ -63,6 +63,7 @@ class Models(object):
image_body_reshaping = 'image-body-reshaping'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_stabilization = 'video-stabilization'
real_basicvsr = 'real-basicvsr'
@@ -270,6 +271,7 @@ class Pipelines(object):
referring_video_object_segmentation = 'referring-video-object-segmentation'
image_skychange = 'image-skychange'
video_human_matting = 'video-human-matting'
video_frame_interpolation = 'video-frame-interpolation'
video_object_segmentation = 'video-object-segmentation'
video_stabilization = 'video-stabilization'
video_super_resolution = 'realbasicvsr-video-super-resolution'
@@ -507,6 +509,8 @@ class Metrics(object):
# metrics for image denoise task
image_denoise_metric = 'image-denoise-metric'
# metrics for video frame-interpolation task
video_frame_interpolation_metric = 'video-frame-interpolation-metric'
# metrics for real-world video super-resolution task
video_super_resolution_metric = 'video-super-resolution-metric'

View File

@@ -21,6 +21,7 @@ if TYPE_CHECKING:
from .bleu_metric import BleuMetric
from .image_inpainting_metric import ImageInpaintingMetric
from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric
from .video_frame_interpolation_metric import VideoFrameInterpolationMetric
from .video_stabilization_metric import VideoStabilizationMetric
from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
from .ppl_metric import PplMetric
@@ -46,6 +47,7 @@ else:
'bleu_metric': ['BleuMetric'],
'referring_video_object_segmentation_metric':
['ReferringVideoObjectSegmentationMetric'],
'video_frame_interpolation_metric': ['VideoFrameInterpolationMetric'],
'video_stabilization_metric': ['VideoStabilizationMetric'],
'ppl_metric': ['PplMetric'],
}

View File

@@ -16,6 +16,7 @@ class MetricKeys(object):
RECALL = 'recall'
PSNR = 'psnr'
SSIM = 'ssim'
LPIPS = 'lpips'
NIQE = 'niqe'
AVERAGE_LOSS = 'avg_loss'
FScore = 'fscore'
@@ -53,6 +54,8 @@ task_default_metrics = {
Tasks.image_inpainting: [Metrics.image_inpainting_metric],
Tasks.referring_video_object_segmentation:
[Metrics.referring_video_object_segmentation_metric],
Tasks.video_frame_interpolation:
[Metrics.video_frame_interpolation_metric],
Tasks.video_stabilization: [Metrics.video_stabilization_metric],
}

View File

@@ -0,0 +1,172 @@
# ------------------------------------------------------------------------
# Copyright (c) Alibaba, Inc. and its affiliates.
# ------------------------------------------------------------------------
import math
from math import exp
from typing import Dict
import lpips
import numpy as np
import torch
import torch.nn.functional as F
from modelscope.metainfo import Metrics
from modelscope.metrics.base import Metric
from modelscope.metrics.builder import METRICS, MetricKeys
from modelscope.utils.registry import default_group
@METRICS.register_module(
    group_key=default_group,
    module_name=Metrics.video_frame_interpolation_metric)
class VideoFrameInterpolationMetric(Metric):
    """The metric computation class for video frame interpolation,
    which will return PSNR, SSIM and LPIPS.
    """
    # Keys under which predictions / ground truths arrive in `outputs`.
    pred_name = 'pred'
    label_name = 'target'

    def __init__(self):
        super(VideoFrameInterpolationMetric, self).__init__()
        self.preds = []
        self.labels = []
        # Fix: the original unconditionally called .cuda(), which crashes on
        # CPU-only machines. Move the LPIPS network to GPU only if available.
        self.loss_fn_alex = lpips.LPIPS(net='alex')
        if torch.cuda.is_available():
            self.loss_fn_alex = self.loss_fn_alex.cuda()

    def add(self, outputs: Dict, inputs: Dict):
        """Accumulate one batch of predictions and ground truths."""
        ground_truths = outputs[VideoFrameInterpolationMetric.label_name]
        eval_results = outputs[VideoFrameInterpolationMetric.pred_name]
        self.preds.append(eval_results)
        self.labels.append(ground_truths)

    def evaluate(self):
        """Return mean PSNR / SSIM / LPIPS over all accumulated pairs."""
        psnr_list, ssim_list, lpips_list = [], [], []
        with torch.no_grad():
            for (pred, label) in zip(self.preds, self.labels):
                # Crop the prediction to the label's spatial size; models may
                # pad the output to a stride-aligned resolution.
                height, width = label.size(2), label.size(3)
                pred = pred[:, :, 0:height, 0:width]
                psnr_list.append(calculate_psnr(label, pred))
                ssim_list.append(calculate_ssim(label, pred))
                lpips_list.append(
                    calculate_lpips(label, pred, self.loss_fn_alex))
        return {
            MetricKeys.PSNR: np.mean(psnr_list),
            MetricKeys.SSIM: np.mean(ssim_list),
            MetricKeys.LPIPS: np.mean(lpips_list)
        }
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length `window_size`, normalized to sum to 1."""
    center = window_size // 2
    denom = float(2 * sigma**2)
    weights = torch.Tensor(
        [exp(-(i - center)**2 / denom) for i in range(window_size)])
    return weights / weights.sum()
def create_window_3d(window_size, channel=1, device=None):
    """Build a separable 3-D Gaussian window of shape (1, channel, ws, ws, ws).

    The window is the outer product of three identical 1-D Gaussians
    (sigma=1.5) and therefore sums to 1.
    """
    kern_1d = gaussian(window_size, 1.5).unsqueeze(1)
    kern_2d = kern_1d.mm(kern_1d.t())
    kern_3d = kern_2d.unsqueeze(2) @ (kern_1d.t())
    expanded = kern_3d.expand(1, channel, window_size, window_size,
                              window_size)
    return expanded.contiguous().to(device)
def calculate_psnr(img1, img2):
    """Compute PSNR between the first items of two batches.

    Assumes inputs normalized so the peak signal value is 1.
    """
    diff = img1[0] - img2[0]
    mse = torch.mean(diff * diff).cpu().data
    return -10 * math.log10(mse)
def calculate_ssim(img1,
                   img2,
                   window_size=11,
                   window=None,
                   size_average=True,
                   full=False,
                   val_range=None):
    """Compute SSIM between two image batches using a 3-D Gaussian window.

    Args:
        img1, img2: tensors of shape (N, C, H, W); both get an extra dim so
            the color channel is treated as a volumetric (depth) axis.
        window_size: Gaussian window size (default 11).
        window: optional precomputed window from `create_window_3d`.
        size_average: if True return the scalar mean, else per-sample means.
        full: if True additionally return the contrast-sensitivity term.
        val_range: dynamic range L; inferred from the data when None.

    Returns:
        SSIM value(s) moved to CPU (plus `cs` when `full` is True).
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, _, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window_3d(
            real_size, channel=1, device=img1.device).to(img1.device)
        # Channel is set to 1 since we consider color images as volumetric images

    img1 = img1.unsqueeze(1)
    img2 = img2.unsqueeze(1)

    # NOTE(review): the replicate padding of 5 on every side matches the
    # default window_size of 11; if real_size < 11 (tiny inputs) the padding
    # no longer matches the window — confirm inputs are always >= 11x11.
    mu1 = F.conv3d(
        F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'),
        window,
        padding=padd,
        groups=1)
    mu2 = F.conv3d(
        F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'),
        window,
        padding=padd,
        groups=1)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[x^2] - E[x]^2 under the Gaussian window.
    sigma1_sq = F.conv3d(
        F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'),
        window,
        padding=padd,
        groups=1) - mu1_sq
    sigma2_sq = F.conv3d(
        F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'),
        window,
        padding=padd,
        groups=1) - mu2_sq
    sigma12 = F.conv3d(
        F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'),
        window,
        padding=padd,
        groups=1) - mu1_mu2

    # Stabilizing constants from the SSIM paper, scaled by the range L.
    C1 = (0.01 * L)**2
    C2 = (0.03 * L)**2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret.cpu()
def calculate_lpips(img1, img2, loss_fn_alex):
    """Compute the LPIPS distance between two images.

    Inputs are expected in [0, 1] and are rescaled to [-1, 1], the range
    the LPIPS network expects.
    """
    scaled1 = img1 * 2 - 1
    scaled2 = img2 * 2 - 1
    return loss_fn_alex(scaled1, scaled2).cpu().item()

View File

@@ -15,8 +15,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
pointcloud_sceneflow_estimation, product_retrieval_embedding,
realtime_object_detection, referring_video_object_segmentation,
salient_detection, shop_segmentation, super_resolution,
video_object_segmentation, video_single_object_tracking,
video_stabilization, video_summarization,
video_super_resolution, virual_tryon)
video_frame_interpolation, video_object_segmentation,
video_single_object_tracking, video_stabilization,
video_summarization, video_super_resolution, virual_tryon)
# yapf: enable

View File

@@ -0,0 +1,53 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
from modelscope.models.cv.video_frame_interpolation.flow_model.raft import RAFT
from modelscope.models.cv.video_frame_interpolation.interp_model.IFNet_swin import \
IFNet
from modelscope.models.cv.video_frame_interpolation.interp_model.refinenet_arch import (
InterpNet, InterpNetDs)
class VFINet(nn.Module):
    """Video frame interpolation network: RAFT optical flow + interpolation net.

    Given four consecutive frames, estimates flows around the middle pair
    and synthesizes an intermediate frame at `timestep` between them.
    """

    def __init__(self, args, Ds_flag=False):
        # args: RAFT configuration namespace; Ds_flag additionally builds the
        # downsampled-variant interpolation net (internet_Ds), which is not
        # used by forward() — presumably loaded for other entry points.
        super(VFINet, self).__init__()
        self.flownet = RAFT(args)
        self.internet = InterpNet()
        if Ds_flag:
            self.internet_Ds = InterpNetDs()

    def img_trans(self, img_tensor):  # in format of RGB
        """Scale to [0, 1] and subtract the per-channel mean.

        NOTE(review): [0.429, 0.431, 0.397] looks like a dataset RGB mean —
        confirm against the training pipeline.
        """
        img_tensor = img_tensor / 255.0
        mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1,
                                                        1).type_as(img_tensor)
        img_tensor -= mean
        return img_tensor

    def add_mean(self, x):
        """Inverse of img_trans' mean subtraction (output stays in ~[0, 1])."""
        mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, 1).type_as(x)
        return x + mean

    def forward(self, imgs, timestep=0.5):
        """Interpolate a frame between the 2nd and 3rd of four stacked frames.

        Args:
            imgs: tensor (N, 12, H, W) — four RGB frames concatenated on the
                channel axis, pixel range 0-255.
            timestep: temporal position of the synthesized frame in (0, 1).

        Returns:
            The interpolated frame with the channel mean added back.
        """
        # Inference-only path: submodules forced to eval and no gradients.
        self.flownet.eval()
        self.internet.eval()
        with torch.no_grad():
            img0 = imgs[:, :3]
            img1 = imgs[:, 3:6]
            img2 = imgs[:, 6:9]
            img3 = imgs[:, 9:12]
            # Bidirectional flows around the middle pair (12 RAFT iterations).
            _, F10_up = self.flownet(img1, img0, iters=12, test_mode=True)
            _, F12_up = self.flownet(img1, img2, iters=12, test_mode=True)
            _, F21_up = self.flownet(img2, img1, iters=12, test_mode=True)
            _, F23_up = self.flownet(img2, img3, iters=12, test_mode=True)
            # Normalize only the two frames fed to the interpolation net.
            img1 = self.img_trans(img1.clone())
            img2 = self.img_trans(img2.clone())
            It_warp = self.internet(
                img1, img2, F10_up, F12_up, F21_up, F23_up, timestep=timestep)
            It_warp = self.add_mean(It_warp)

        return It_warp

View File

@@ -0,0 +1,98 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from copy import deepcopy
from typing import Any, Dict, Union
import torch.cuda
import torch.nn.functional as F
from torch.nn.parallel import DataParallel, DistributedDataParallel
from modelscope.metainfo import Models
from modelscope.models.base import Tensor
from modelscope.models.base.base_torch_model import TorchModel
from modelscope.models.builder import MODELS
# from modelscope.models.cv.video_super_resolution.common import charbonnier_loss
from modelscope.models.cv.video_frame_interpolation.VFINet_arch import VFINet
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
logger = get_logger()
__all__ = ['VFINetForVideoFrameInterpolation']
def convert(param):
    """Strip the 'module.' prefix (added by DataParallel wrapping) from
    state-dict keys; keys without the prefix are dropped."""
    converted = {}
    for key, value in param.items():
        if 'module.' in key:
            converted[key.replace('module.', '')] = value
    return converted
@MODELS.register_module(
    Tasks.video_frame_interpolation,
    module_name=Models.video_frame_interpolation)
class VFINetForVideoFrameInterpolation(TorchModel):
    """ModelScope wrapper around VFINet for the video frame-interpolation task."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the video frame-interpolation model from the `model_dir` path.

        Args:
            model_dir (str): the model path; must contain the configuration
                file plus 'raft-sintel.pt' (flow net) and 'interpnet.pt'
                (interpolation net) checkpoints.
        """
        super().__init__(model_dir, *args, **kwargs)

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        # NOTE(review): _device is only used as map_location below; the model
        # itself is never moved with .to(self._device) here — confirm the
        # pipeline handles device placement.
        self.model_dir = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))
        flownet_path = os.path.join(model_dir, 'raft-sintel.pt')
        internet_path = os.path.join(model_dir, 'interpnet.pt')
        self.model = VFINet(self.config.model.network, Ds_flag=True)
        self._load_pretrained(flownet_path, internet_path)

    def _load_pretrained(self, flownet_path, internet_path):
        # Checkpoints may come from DataParallel training; convert() strips
        # the 'module.' prefix. The same interpnet weights are loaded into
        # both internet and its Ds variant.
        state_dict_flownet = torch.load(
            flownet_path, map_location=self._device)
        state_dict_internet = torch.load(
            internet_path, map_location=self._device)
        self.model.flownet.load_state_dict(
            convert(state_dict_flownet), strict=True)
        self.model.internet.load_state_dict(
            convert(state_dict_internet), strict=True)
        self.model.internet_Ds.load_state_dict(
            convert(state_dict_internet), strict=True)
        logger.info('load model done.')

    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
        # Plain inference path: one interpolated frame under key 'output'.
        return {'output': self.model(input)}

    def _evaluate_postprocess(self, input: Tensor,
                              target: Tensor) -> Dict[str, list]:
        # Evaluation path: keep predictions paired with targets for metrics;
        # free the (potentially large) input batch eagerly.
        preds = self.model(input)
        del input
        torch.cuda.empty_cache()
        return {'pred': preds, 'target': target}

    def forward(self, inputs: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
        """return the result by the model

        Args:
            inputs (Tensor): the preprocessed data; the presence of a
                'target' key selects the evaluation path.

        Returns:
            Dict[str, Tensor]: results
        """
        if 'target' in inputs:
            return self._evaluate_postprocess(**inputs)
        else:
            return self._inference_forward(**inputs)

View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    # Real import only for static analyzers / IDEs; at runtime the module is
    # replaced by the lazy loader below.
    from .VFINet_arch import VFINet

else:
    # Map of submodule name -> public symbols, resolved on first attribute
    # access via LazyImportModule.
    _import_structure = {'VFINet_arch': ['VFINet']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,92 @@
# The implementation is adopted from RAFT,
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
import torch
import torch.nn.functional as F
from modelscope.models.cv.video_frame_interpolation.utils.utils import (
bilinear_sampler, coords_grid)
class CorrBlock:
    """All-pairs correlation volume with a pooled pyramid (RAFT).

    Precomputes the full 4-D correlation between two feature maps, then
    builds a pyramid by average-pooling the last two dims; __call__ samples
    a (2r+1)x(2r+1) window around given coordinates at every level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        """Sample correlation features at `coords` (N, 2, H, W) across levels.

        Returns a tensor of shape (N, num_levels*(2r+1)^2, H, W).
        """
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            # Relative offsets of the (2r+1)^2 lookup window.
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            # NOTE(review): torch.meshgrid without indexing= relies on the
            # legacy 'ij' default and warns on newer torch — confirm the
            # supported torch version.
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Coordinates shrink by 2 per pyramid level.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        """Dense dot-product correlation, scaled by 1/sqrt(dim)."""
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
class AlternateCorrBlock:
    """Memory-efficient correlation lookup backed by the optional
    `alt_cuda_corr` compiled CUDA extension from the RAFT repository.

    Fix: the original referenced `alt_cuda_corr` without ever importing it,
    so __call__ raised an opaque NameError. It is now imported lazily with a
    clear error message, and the pyramid construction works without it.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Pyramid of progressively 2x average-pooled feature-map pairs.
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        # Imported lazily so the rest of the module is usable without the
        # compiled extension installed.
        try:
            import alt_cuda_corr
        except ImportError as exc:
            raise ImportError(
                'AlternateCorrBlock requires the compiled "alt_cuda_corr" '
                'CUDA extension from the RAFT repository') from exc

        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # Query features always come from the full-resolution level.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())

View File

@@ -0,0 +1,288 @@
# The implementation is adopted from RAFT,
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
import torch
import torch.nn as nn
import torch.nn.functional as F
class ResidualBlock(nn.Module):
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(
in_planes, planes, kernel_size=3, padding=1, stride=stride)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
self.relu = nn.ReLU(inplace=True)
num_groups = planes // 8
if norm_fn == 'group':
self.norm1 = nn.GroupNorm(
num_groups=num_groups, num_channels=planes)
self.norm2 = nn.GroupNorm(
num_groups=num_groups, num_channels=planes)
if not stride == 1:
self.norm3 = nn.GroupNorm(
num_groups=num_groups, num_channels=planes)
elif norm_fn == 'batch':
self.norm1 = nn.BatchNorm2d(planes)
self.norm2 = nn.BatchNorm2d(planes)
if not stride == 1:
self.norm3 = nn.BatchNorm2d(planes)
elif norm_fn == 'instance':
self.norm1 = nn.InstanceNorm2d(planes)
self.norm2 = nn.InstanceNorm2d(planes)
if not stride == 1:
self.norm3 = nn.InstanceNorm2d(planes)
elif norm_fn == 'none':
self.norm1 = nn.Sequential()
self.norm2 = nn.Sequential()
if not stride == 1:
self.norm3 = nn.Sequential()
if stride == 1:
self.downsample = None
else:
self.downsample = nn.Sequential(
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
self.norm3)
def forward(self, x):
y = x
y = self.relu(self.norm1(self.conv1(y)))
y = self.relu(self.norm2(self.conv2(y)))
if self.downsample is not None:
x = self.downsample(x)
return self.relu(x + y)
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with selectable norm.

    The middle conv carries the stride; a 1x1 strided conv shortcut is
    added whenever stride != 1.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8
        norm_factories = {
            'group':
            lambda ch: nn.GroupNorm(num_groups=num_groups, num_channels=ch),
            'batch': nn.BatchNorm2d,
            'instance': nn.InstanceNorm2d,
            'none': lambda ch: nn.Sequential(),
        }
        if norm_fn in norm_factories:
            make_norm = norm_factories[norm_fn]
            self.norm1 = make_norm(planes // 4)
            self.norm2 = make_norm(planes // 4)
            self.norm3 = make_norm(planes)
            if not stride == 1:
                # Fourth norm lives inside the strided shortcut.
                self.norm4 = make_norm(planes)

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm4)

    def forward(self, x):
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.relu(self.norm2(self.conv2(out)))
        out = self.relu(self.norm3(self.conv3(out)))

        shortcut = x if self.downsample is None else self.downsample(x)
        return self.relu(shortcut + out)
class BasicEncoder(nn.Module):
    """Feature/context encoder producing 1/8-resolution features (RAFT).

    Accepts a single tensor or a list/tuple of two tensors; a pair is
    concatenated along the batch dim and split again on output.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        stem_norms = {
            'group': lambda: nn.GroupNorm(num_groups=8, num_channels=64),
            'batch': lambda: nn.BatchNorm2d(64),
            'instance': lambda: nn.InstanceNorm2d(64),
            'none': nn.Sequential,
        }
        if self.norm_fn in stem_norms:
            self.norm1 = stem_norms[self.norm_fn]()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self._init_weights()

    def _init_weights(self):
        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may downsample.
        blocks = (ResidualBlock(self.in_planes, dim, self.norm_fn,
                                stride=stride),
                  ResidualBlock(dim, dim, self.norm_fn, stride=1))
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A list/tuple input is fused along the batch dim and split on exit.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.conv2):
            x = stage(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
class SmallEncoder(nn.Module):
    """Compact encoder variant using bottleneck blocks (RAFT 'small' model).

    Accepts a single tensor or a list/tuple of two tensors; a pair is
    concatenated along the batch dim and split again on output.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        stem_norms = {
            'group': lambda: nn.GroupNorm(num_groups=8, num_channels=32),
            'batch': lambda: nn.BatchNorm2d(32),
            'instance': lambda: nn.InstanceNorm2d(32),
            'none': nn.Sequential,
        }
        if self.norm_fn in stem_norms:
            self.norm1 = stem_norms[self.norm_fn]()

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = nn.Dropout2d(p=dropout) if dropout > 0 else None

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        self._init_weights()

    def _init_weights(self):
        # Kaiming init for convs; unit weight / zero bias for norm layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may downsample.
        blocks = (BottleneckBlock(self.in_planes, dim, self.norm_fn,
                                  stride=stride),
                  BottleneckBlock(dim, dim, self.norm_fn, stride=1))
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        # A list/tuple input is fused along the batch dim and split on exit.
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.relu1(self.norm1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.conv2):
            x = stage(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x

View File

@@ -0,0 +1,157 @@
# The implementation is adopted from RAFT,
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.video_frame_interpolation.flow_model.corr import (
AlternateCorrBlock, CorrBlock)
from modelscope.models.cv.video_frame_interpolation.flow_model.extractor import (
BasicEncoder, SmallEncoder)
from modelscope.models.cv.video_frame_interpolation.flow_model.update import (
BasicUpdateBlock, SmallUpdateBlock)
from modelscope.models.cv.video_frame_interpolation.utils.utils import (
bilinear_sampler, coords_grid, upflow8)
# NOTE(review): torch.cuda.amp.autocast is the legacy entry point (newer
# torch prefers torch.amp.autocast) — confirm the supported torch version.
autocast = torch.cuda.amp.autocast


class RAFT(nn.Module):
    """RAFT optical-flow estimator (Teed & Deng, ECCV 2020).

    Builds feature/context encoders sized by `args.small` and iteratively
    refines a flow field with a recurrent update block over a correlation
    pyramid.
    """

    def __init__(self, args):
        super(RAFT, self).__init__()
        self.args = args

        if args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            self.args.corr_levels = 4
            self.args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            self.args.corr_levels = 4
            self.args.corr_radius = 4

        # Optional settings default off when missing from the config.
        if 'dropout' not in self.args:
            self.args.dropout = 0

        if 'alternate_corr' not in self.args:
            self.args.alternate_corr = False

        # feature network, context network, and update block
        if args.small:
            self.fnet = SmallEncoder(
                output_dim=128, norm_fn='instance', dropout=self.args.dropout)
            self.cnet = SmallEncoder(
                output_dim=hdim + cdim,
                norm_fn='none',
                dropout=self.args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(
                output_dim=256, norm_fn='instance', dropout=self.args.dropout)
            self.cnet = BasicEncoder(
                output_dim=hdim + cdim,
                norm_fn='batch',
                dropout=self.args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

    def freeze_bn(self):
        # Keep BatchNorm running statistics frozen (used when fine-tuning).
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
        coords1 = coords_grid(N, H // 8, W // 8, device=img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, _, H, W = flow.shape
        # mask holds 9 softmax weights per output pixel in each 8x8 cell.
        mask = mask.view(N, 1, 9, 8, 8, H, W)
        mask = torch.softmax(mask, dim=2)

        # Flow magnitudes scale by 8 along with the resolution.
        up_flow = F.unfold(8 * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, 8 * H, 8 * W)

    def forward(self,
                image1,
                image2,
                iters=12,
                flow_init=None,
                upsample=True,
                test_mode=False):
        """ Estimate optical flow between pair of frames """
        # Normalize 0-255 input to [-1, 1].
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.args.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()
        if self.args.alternate_corr:
            corr_fn = AlternateCorrBlock(
                fmap1, fmap2, radius=self.args.corr_radius)
        else:
            corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        # run the context network
        with autocast(enabled=self.args.mixed_precision):
            cnet = self.cnet(image1)
            # Split into GRU hidden state (tanh) and context input (relu).
            net, inp = torch.split(cnet, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        flow_predictions = []
        for itr in range(iters):
            # Truncate gradients between iterations.
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            with autocast(enabled=self.args.mixed_precision):
                net, up_mask, delta_flow = self.update_block(
                    net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            if up_mask is None:
                flow_up = upflow8(coords1 - coords0)
            else:
                flow_up = self.upsample_flow(coords1 - coords0, up_mask)

            flow_predictions.append(flow_up)

        if test_mode:
            # Return (low-res flow, full-res flow of the final iteration).
            return coords1 - coords0, flow_up

        return flow_predictions

View File

@@ -0,0 +1,160 @@
# The implementation is adopted from RAFT,
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowHead(nn.Module):
    """Two-layer conv head mapping hidden features to a 2-channel flow field."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
class ConvGRU(nn.Module):
    """Convolutional GRU cell with 3x3 gates."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        total_dim = hidden_dim + input_dim
        self.convz = nn.Conv2d(total_dim, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(total_dim, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(total_dim, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        combined = torch.cat([h, x], dim=1)

        update = torch.sigmoid(self.convz(combined))
        reset = torch.sigmoid(self.convr(combined))
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))

        return (1 - update) * h + update * candidate
class SepConvGRU(nn.Module):
    """GRU cell with separable 1x5 (horizontal) then 5x1 (vertical) gates."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        total_dim = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(total_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(total_dim, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(total_dim, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(total_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(total_dim, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(total_dim, hidden_dim, (5, 1), padding=(2, 0))

    def _gru_pass(self, h, x, convz, convr, convq):
        # One standard GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass, then vertical pass
        h = self._gru_pass(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_pass(h, x, self.convz2, self.convr2, self.convq2)
        return h
class SmallMotionEncoder(nn.Module):
    """Encode correlation features and current flow into motion features
    (small RAFT variant); the raw flow is appended to the output."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        # Channel count of the sampled correlation pyramid.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Keep direct access to the raw flow downstream.
        return torch.cat([fused, flow], dim=1)
class BasicMotionEncoder(nn.Module):
    """Encode correlation features and current flow into motion features
    (full RAFT variant); the raw flow is appended to the output."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        # Channel count of the sampled correlation pyramid.
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # 126 fused channels + 2 raw flow channels = 128 output channels.
        return torch.cat([fused, flow], dim=1)
class SmallUpdateBlock(nn.Module):
    """Small update cell: motion encoder + ConvGRU + flow head.

    Returns (new hidden state, None, delta flow); the None slot stands in
    for the upsampling mask the basic variant produces.
    """

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        features = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, features)
        return net, None, self.flow_head(net)
class BasicUpdateBlock(nn.Module):
    """Full update cell: motion encoder + SepConvGRU + flow head + mask head.

    Returns (new hidden state, upsampling mask, delta flow).
    """

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts 9 convex-combination weights per pixel of each 8x8 cell.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        features = torch.cat([inp, self.encoder(flow, corr)], dim=1)
        net = self.gru(net, features)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow

View File

@@ -0,0 +1,434 @@
# Part of the implementation is borrowed and modified from RIFE,
# publicly available at https://github.com/megvii-research/ECCV2022-RIFE
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_
from modelscope.models.cv.video_frame_interpolation.interp_model.transformer_layers import (
RTFL, PatchEmbed, PatchUnEmbed)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Cache of base sampling grids, keyed by (flow device, flow size) so grids
# are built once per resolution/device combination.
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` by the flow field `tenFlow` via grid_sample.

    Fix: the base grid is now created on `tenFlow.device`. The original
    built it on the module-level `device`, which breaks whenever the flow
    tensor lives on a different device (e.g. CPU tensors on a CUDA machine,
    or a non-default GPU) even though the cache key already encodes the
    flow's device.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        tenHorizontal = torch.linspace(
            -1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
                1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1,
                                                  tenFlow.shape[2], -1)
        tenVertical = torch.linspace(
            -1.0, 1.0, tenFlow.shape[2],
            device=tenFlow.device).view(1, 1, tenFlow.shape[2],
                                        1).expand(tenFlow.shape[0], -1, -1,
                                                  tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Rescale pixel-unit flow to the [-1, 1] grid_sample coordinate range.
    tmp1 = tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0)
    tmp2 = tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)
    tenFlow = torch.cat([tmp1, tmp2], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=torch.clamp(g, -1, 1),
        mode='bilinear',
        padding_mode='border',
        align_corners=True)
def conv_wo_act(in_planes,
                out_planes,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1):
    """2D convolution wrapped in an nn.Sequential, with no activation."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True)
    return nn.Sequential(layer)
def conv(in_planes,
         out_planes,
         kernel_size=3,
         stride=1,
         padding=1,
         dilation=1):
    """Conv2d followed by a PReLU activation, packed in an nn.Sequential."""
    conv_layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True)
    return nn.Sequential(conv_layer, nn.PReLU(out_planes))
def conv_bn(in_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1):
    """Conv2d (bias-free) -> BatchNorm2d -> PReLU, as an nn.Sequential."""
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
class TransModel(nn.Module):
    """Single-stage transformer backbone built from RTFL layers.

    Embeds an (N, embed_dim, H, W) feature map into per-position tokens
    (patch_size defaults to 1), runs one stage of RTFL blocks with optional
    cross-attention, then folds the tokens back into a feature map of the
    same shape.
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 embed_dim=64,
                 depths=[[3, 3]],
                 num_heads=[[2, 2]],
                 window_size=4,
                 mlp_ratio=2,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 resi_connection='1conv',
                 use_crossattn=[[[False, False, False, False],
                                 [True, True, True, True]]]):
        # NOTE(review): `img_size` is indexed as img_size[0]/img_size[1] when
        # building the RTFL layers below, so the int default 64 would raise;
        # in this file callers (IFBlock) always pass a tuple — confirm before
        # relying on the default.
        super(TransModel, self).__init__()
        self.embed_dim = embed_dim
        self.ape = ape  # whether to add an absolute position embedding
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio
        # split image into non-overlapping patches (tokens)
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution
        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        self.pos_drop = nn.Dropout(p=drop_rate)
        dpr0 = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[0]))
        ]  # stochastic depth decay rule
        self.layers0 = nn.ModuleList()
        num_layers = len(depths[0])
        for i_layer in range(num_layers):
            # one RTFL block per entry of depths[0], each taking its slice of
            # the stochastic-depth schedule
            layer = RTFL(
                dim=embed_dim,
                input_resolution=(patches_resolution[0],
                                  patches_resolution[1]),
                depth=depths[0][i_layer],
                num_heads=num_heads[0][i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr0[sum(depths[0][:i_layer]):sum(depths[0][:i_layer
                                                                     + 1])],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=(img_size[0], img_size[1]),
                patch_size=patch_size,
                resi_connection=resi_connection,
                use_crossattn=use_crossattn[0][i_layer])
            self.layers0.append(layer)
        self.norm = norm_layer(self.num_features)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for Linear layers; unit LayerNorm init."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # parameter names excluded from weight decay by the optimizer setup
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x, layers):
        """Embed x to tokens, run `layers` (ModuleList or single layer),
        normalize and fold back to an (N, C, H, W) feature map."""
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        if isinstance(layers, nn.ModuleList):
            for layer in layers:
                x = layer(x, x_size)
        else:
            x = layers(x, x_size)
        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)
        return x

    def forward(self, x):
        out = self.forward_features(x, self.layers0)
        return out
class IFBlock(nn.Module):
    """Intermediate-flow residual predictor with a transformer (RTFL) core.

    Works at 1/scale resolution internally and returns two residual flows
    at the input resolution.
    """

    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock, self).__init__()
        self.scale = scale
        # Two stride-2 convs reduce spatial size by 4x before the transformer.
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
            conv(c, c, 3, 1, 1),
        )
        self.trans = TransModel(
            img_size=(128 // scale, 128 // scale),
            patch_size=1,
            embed_dim=c,
            depths=[[3, 3]],
            num_heads=[[2, 2]])
        self.conv1 = nn.Sequential(
            conv(c, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )
        self.up = nn.ConvTranspose2d(c, 4, 4, 2, 1)
        self.conv2 = nn.Conv2d(4, 4, 3, 1, 1)

    def forward(self, x, flow0, flow1):
        """Return (residual flow0, residual flow1), each (N, 2, H, W)."""
        if self.scale != 1:
            # Work at reduced resolution; flow magnitudes scale with size.
            inv = 1. / self.scale
            x = F.interpolate(
                x, scale_factor=inv, mode='bilinear', align_corners=False)
            flow0 = F.interpolate(
                flow0, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
            flow1 = F.interpolate(
                flow1, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
        feat = self.conv0(torch.cat((x, flow0, flow1), 1))
        feat = self.trans(feat)
        feat = self.conv1(feat) + feat
        # Transposed conv (2x) then bilinear interpolation (2x) undo conv0's
        # 4x downsampling.
        feat = self.conv2(self.up(feat))
        flow = F.interpolate(
            feat, scale_factor=2.0, mode='bilinear', align_corners=False) * 2.0
        if self.scale != 1:
            flow = F.interpolate(
                flow,
                scale_factor=self.scale,
                mode='bilinear',
                align_corners=False) * self.scale
        return flow[:, :2, :, :], flow[:, 2:, :, :]
class IFBlock_wo_Swin(nn.Module):
    """IFBlock variant that replaces the transformer with plain conv stacks."""

    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock_wo_Swin, self).__init__()
        self.scale = scale
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )
        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c), conv(c, c))
        self.up = nn.ConvTranspose2d(c, 4, 4, 2, 1)
        self.conv2 = nn.Conv2d(4, 4, 3, 1, 1)

    def forward(self, x, flow0, flow1):
        """Return (residual flow0, residual flow1), each (N, 2, H, W)."""
        if self.scale != 1:
            # Work at reduced resolution; flow magnitudes scale with size.
            inv = 1. / self.scale
            x = F.interpolate(
                x, scale_factor=inv, mode='bilinear', align_corners=False)
            flow0 = F.interpolate(
                flow0, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
            flow1 = F.interpolate(
                flow1, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
        feat = self.conv0(torch.cat((x, flow0, flow1), 1))
        # Two residual conv stacks.
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat
        # Transposed conv (2x) then bilinear interpolation (2x) undo conv0's
        # 4x downsampling.
        feat = self.conv2(self.up(feat))
        flow = F.interpolate(
            feat, scale_factor=2.0, mode='bilinear', align_corners=False) * 2.0
        if self.scale != 1:
            flow = F.interpolate(
                flow,
                scale_factor=self.scale,
                mode='bilinear',
                align_corners=False) * self.scale
        return flow[:, :2, :, :], flow[:, 2:, :, :]
class IFNet(nn.Module):
    """Three-stage coarse-to-fine refinement of a pair of inter-frame flows.

    flow0 is the flow from img0 to img1; flow1 is the flow from img1 to img0.
    """

    def __init__(self):
        super(IFNet, self).__init__()
        self.block1 = IFBlock_wo_Swin(16, scale=4, c=128)
        self.block2 = IFBlock(16, scale=2, c=64)
        self.block3 = IFBlock(16, scale=1, c=32)

    def forward(self, img0, img1, flow0, flow1, sc_mode=2):
        # sc_mode selects the working resolution: 0 -> 1/4, 1 -> 1/2, else full.
        sc = {0: 0.25, 1: 0.5}.get(sc_mode, 1)
        if sc != 1:
            img0_sc = F.interpolate(
                img0, scale_factor=sc, mode='bilinear', align_corners=False)
            img1_sc = F.interpolate(
                img1, scale_factor=sc, mode='bilinear', align_corners=False)
            flow0_sc = F.interpolate(
                flow0, scale_factor=sc, mode='bilinear',
                align_corners=False) * sc
            flow1_sc = F.interpolate(
                flow1, scale_factor=sc, mode='bilinear',
                align_corners=False) * sc
        else:
            img0_sc, img1_sc = img0, img1
            flow0_sc, flow1_sc = flow0, flow1
        # Stage 1: coarse residual from warped images and initial flows.
        warped_img0 = warp(img1_sc, flow0_sc)  # -> img0
        warped_img1 = warp(img0_sc, flow1_sc)  # -> img1
        flow0_1, flow1_1 = self.block1(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), 1),
            flow0_sc, flow1_sc)
        F0_2 = flow0_sc + flow0_1
        F1_2 = flow1_sc + flow1_1
        # Stage 2: refine at half scale.
        warped_img0 = warp(img1_sc, F0_2)  # -> img0
        warped_img1 = warp(img0_sc, F1_2)  # -> img1
        flow0_2, flow1_2 = self.block2(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), 1), F0_2,
            F1_2)
        F0_3 = F0_2 + flow0_2
        F1_3 = F1_2 + flow1_2
        # Stage 3: refine at full (working) scale.
        warped_img0 = warp(img1_sc, F0_3)  # -> img0
        warped_img1 = warp(img0_sc, F1_3)  # -> img1
        flow0_3, flow1_3 = self.block3(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), dim=1),
            F0_3, F1_3)
        # Total residual is applied to the ORIGINAL-resolution input flows.
        flow_res_0 = flow0_1 + flow0_2 + flow0_3
        flow_res_1 = flow1_1 + flow1_2 + flow1_3
        if sc != 1:
            flow_res_0 = F.interpolate(
                flow_res_0,
                scale_factor=1 / sc,
                mode='bilinear',
                align_corners=False) / sc
            flow_res_1 = F.interpolate(
                flow_res_1,
                scale_factor=1 / sc,
                mode='bilinear',
                align_corners=False) / sc
        return flow0 + flow_res_0, flow1 + flow_res_1

View File

@@ -0,0 +1,127 @@
# Part of the implementation is borrowed and modified from QVI, publicly available at https://github.com/xuxy09/QVI
import torch
import torch.nn as nn
import torch.nn.functional as F
class down(nn.Module):
    """Encoder stage: 2x average pooling then two convs with LeakyReLU."""

    def __init__(self, inChannels, outChannels, filterSize):
        super(down, self).__init__()
        pad = (filterSize - 1) // 2  # "same" padding for odd filter sizes
        self.conv1 = nn.Conv2d(
            inChannels, outChannels, filterSize, stride=1, padding=pad)
        self.conv2 = nn.Conv2d(
            outChannels, outChannels, filterSize, stride=1, padding=pad)

    def forward(self, x):
        out = F.avg_pool2d(x, 2)
        out = F.leaky_relu(self.conv1(out), negative_slope=0.1)
        return F.leaky_relu(self.conv2(out), negative_slope=0.1)
class up(nn.Module):
    """Decoder stage: resize to the skip's size, conv, then concat-conv."""

    def __init__(self, inChannels, outChannels):
        super(up, self).__init__()
        self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(
            2 * outChannels, outChannels, 3, stride=1, padding=1)

    def forward(self, x, skpCn):
        target = [skpCn.size(2), skpCn.size(3)]
        x = F.interpolate(
            x, size=target, mode='bilinear', align_corners=False)
        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        fused = torch.cat((x, skpCn), 1)
        return F.leaky_relu(self.conv2(fused), negative_slope=0.1)
class Small_UNet(nn.Module):
    """Compact U-Net returning both the output map and the last feature map."""

    def __init__(self, inChannels, outChannels):
        super(Small_UNet, self).__init__()
        self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3)
        self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3)
        self.down1 = down(32, 64, 5)
        self.down2 = down(64, 128, 3)
        self.down3 = down(128, 128, 3)
        self.up1 = up(128, 128)
        self.up2 = up(128, 64)
        self.up3 = up(64, 32)
        self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)

    def forward(self, x):
        # Encoder with skip connections s1..s3.
        h = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        s1 = F.leaky_relu(self.conv2(h), negative_slope=0.1)
        s2 = self.down1(s1)
        s3 = self.down2(s2)
        bottom = self.down3(s3)
        # Decoder mirrors the encoder.
        h = self.up1(bottom, s3)
        h = self.up2(h, s2)
        feat = self.up3(h, s1)  # feature
        out = self.conv3(feat)  # flow
        return out, feat
class Small_UNet_Ds(nn.Module):
    """U-Net whose body runs at half resolution, upsampled back at the end."""

    def __init__(self, inChannels, outChannels):
        super(Small_UNet_Ds, self).__init__()
        self.conv1_1 = nn.Conv2d(inChannels, 32, 5, stride=1, padding=2)
        self.conv1_2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.down1 = down(32, 64, 5)
        self.down2 = down(64, 128, 3)
        self.down3 = down(128, 128, 3)
        self.up1 = up(128, 128)
        self.up2 = up(128, 64)
        self.up3 = up(64, 32)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)

    def forward(self, x):
        # Full-resolution stem.
        stem = F.leaky_relu(self.conv1_1(x), negative_slope=0.1)
        stem = F.leaky_relu(self.conv1_2(stem), negative_slope=0.1)
        # Drop to half resolution for the U-Net body.
        h = F.interpolate(
            stem,
            size=[stem.size(2) // 2, stem.size(3) // 2],
            mode='bilinear',
            align_corners=False)
        h = F.leaky_relu(self.conv2_1(h), negative_slope=0.1)
        s1 = F.leaky_relu(self.conv2_2(h), negative_slope=0.1)
        s2 = self.down1(s1)
        s3 = self.down2(s2)
        h = self.down3(s3)
        h = self.up1(h, s3)
        h = self.up2(h, s2)
        feat = self.up3(h, s1)
        # Back to full resolution before the output heads.
        feat = F.interpolate(
            feat,
            size=[stem.size(2), stem.size(3)],
            mode='bilinear',
            align_corners=False)
        feat = F.leaky_relu(self.conv3(feat), negative_slope=0.1)  # feature
        out = self.conv4(feat)  # flow
        return out, feat

View File

@@ -0,0 +1,115 @@
# The implementation is adopted from QVI,
# made publicly available at https://github.com/xuxy09/QVI
# class WarpLayer warps image x based on optical flow flo.
import torch
import torch.nn as nn
import torch.nn.functional as F
class FlowReversal(nn.Module):
    """Forward-warp ("splat") a tensor along a flow field.

    Each source pixel is scattered to the position its flow points at; the
    contribution is spread over the four surrounding integer positions with
    Gaussian weights. The accumulated weight map is returned alongside the
    warped image so callers can normalize.
    """

    def __init__(self, ):
        super(FlowReversal, self).__init__()

    def forward(self, img, flo):
        """
        -img: image (N, C, H, W)
        -flo: optical flow (N, 2, H, W)
        elements of flo is in [0, H] and [0, W] for dx, dy

        Returns (imgw, o): the splatted image and the accumulated weights.
        """
        N, C, _, _ = img.size()
        # translate start-point optical flow to end-point optical flow
        # (original slice `flo[:, 0:1:, :]` was an equivalent-but-odd typo)
        y = flo[:, 0:1, :, :]
        x = flo[:, 1:2, :, :]
        x = x.repeat(1, C, 1, 1)
        y = y.repeat(1, C, 1, 1)
        # the four integer corners surrounding each target position
        x1 = torch.floor(x)
        x2 = x1 + 1
        y1 = torch.floor(y)
        y2 = y1 + 1
        # firstly, get gaussian weights
        w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)
        # secondly, sample each weighted corner
        img11, o11 = self.sample_one(img, x1, y1, w11)
        img12, o12 = self.sample_one(img, x1, y2, w12)
        img21, o21 = self.sample_one(img, x2, y1, w21)
        img22, o22 = self.sample_one(img, x2, y2, w22)
        imgw = img11 + img12 + img21 + img22
        o = o11 + o12 + o21 + o22
        return imgw, o

    def get_gaussian_weights(self, x, y, x1, x2, y1, y2):
        """Gaussian weight of each corner given its distance to (x, y)."""
        w11 = torch.exp(-((x - x1)**2 + (y - y1)**2))
        w12 = torch.exp(-((x - x1)**2 + (y - y2)**2))
        w21 = torch.exp(-((x - x2)**2 + (y - y1)**2))
        w22 = torch.exp(-((x - x2)**2 + (y - y2)**2))
        return w11, w12, w21, w22

    def sample_one(self, img, shiftx, shifty, weight):
        """
        Input:
            -img (N, C, H, W)
            -shiftx, shifty (N, c, H, W)

        Scatter-add `img * weight` at integer offsets (shiftx, shifty);
        out-of-bounds targets are masked out.
        """
        N, C, H, W = img.size()
        # Follow the input's device instead of hard-coding .cuda(): the
        # original implementation crashed on CPU tensors / CUDA-less hosts.
        dev = img.device
        # flatten all (all restored as Tensors)
        flat_shiftx = shiftx.view(-1)
        flat_shifty = shifty.view(-1)
        flat_basex = torch.arange(
            0, H, requires_grad=False).view(-1, 1)[None, None].to(
                dev).long().repeat(N, C, 1, W).view(-1)
        flat_basey = torch.arange(
            0, W, requires_grad=False).view(1, -1)[None, None].to(
                dev).long().repeat(N, C, H, 1).view(-1)
        flat_weight = weight.view(-1)
        flat_img = img.view(-1)
        idxn = torch.arange(
            0, N, requires_grad=False).view(N, 1, 1,
                                            1).long().to(dev).repeat(
                                                1, C, H, W).view(-1)
        idxc = torch.arange(
            0, C, requires_grad=False).view(1, C, 1,
                                            1).long().to(dev).repeat(
                                                N, 1, H, W).view(-1)
        idxx = flat_shiftx.long() + flat_basex
        idxy = flat_shifty.long() + flat_basey
        # keep only targets that land inside the image
        mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)
        ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
        ids_mask = torch.masked_select(ids, mask).clone().to(dev)
        img_warp = torch.zeros([
            N * C * H * W,
        ], device=dev)
        img_warp.put_(
            ids_mask,
            torch.masked_select(flat_img * flat_weight, mask),
            accumulate=True)
        one_warp = torch.zeros([
            N * C * H * W,
        ], device=dev)
        one_warp.put_(
            ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)
        return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)

View File

@@ -0,0 +1,488 @@
# Part of the implementation is borrowed and modified from QVI, publicly available at https://github.com/xuxy09/QVI
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.video_frame_interpolation.interp_model.flow_reversal import \
FlowReversal
from modelscope.models.cv.video_frame_interpolation.interp_model.IFNet_swin import \
IFNet
from modelscope.models.cv.video_frame_interpolation.interp_model.UNet import \
Small_UNet_Ds
class AcFusionLayer(nn.Module):
    """Quadratic (constant-acceleration) fusion of four pairwise flows into
    the two flows toward the intermediate time t."""

    def __init__(self, ):
        super(AcFusionLayer, self).__init__()

    def forward(self, flo10, flo12, flo21, flo23, t=0.5):
        # Evaluate the quadratic motion model at time t from frame 1 and at
        # time (1 - t) from frame 2.
        f1t = 0.5 * ((t + t**2) * flo12 - (t - t**2) * flo10)
        s = 1 - t
        f2t = 0.5 * ((s + s**2) * flo21 - (s - s**2) * flo23)
        return f1t, f2t
class Get_gradient(nn.Module):
    """Per-channel gradient magnitude of a 3-channel image using fixed
    central-difference kernels."""

    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
        kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        # fixed (non-trainable) filters
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def forward(self, x):
        grads = []
        for ch in range(3):  # R, G, B processed independently
            plane = x[:, ch].unsqueeze(1)
            gv = F.conv2d(plane, self.weight_v, padding=1)
            gh = F.conv2d(plane, self.weight_h, padding=1)
            # epsilon keeps sqrt differentiable at zero gradient
            grads.append(torch.sqrt(torch.pow(gv, 2) + torch.pow(gh, 2) + 1e-6))
        return torch.cat(grads, dim=1)
class LowPassFilter(nn.Module):
    """7x7 box (mean) filter applied to each of the two input channels."""

    def __init__(self):
        super(LowPassFilter, self).__init__()
        kernel_lpf = torch.ones(1, 1, 7, 7) / 49
        # fixed (non-trainable) filter
        self.weight_lpf = nn.Parameter(data=kernel_lpf, requires_grad=False)

    def forward(self, x):
        filtered = [
            F.conv2d(x[:, ch].unsqueeze(1), self.weight_lpf, padding=3)
            for ch in range(2)
        ]
        return torch.cat(filtered, dim=1)
def backwarp(img, flow):
    """Backward-warp `img` with `flow` using bilinear grid sampling.

    Args:
        img: (N, C, H, W) tensor to sample from.
        flow: (N, 2, H, W) displacements in pixels; channel 0 is the
            horizontal component u, channel 1 the vertical component v.

    Returns:
        (N, C, H, W) warped tensor (border padding).
    """
    _, _, H, W = img.size()
    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]
    # Build the base coordinate grid on the input's device rather than
    # hard-coding .cuda(), so the function also works on CPU tensors.
    gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))
    gridX = torch.tensor(gridX, requires_grad=False, device=img.device)
    gridY = torch.tensor(gridY, requires_grad=False, device=img.device)
    x = gridX.unsqueeze(0).expand_as(u).float() + u
    y = gridY.unsqueeze(0).expand_as(v).float() + v
    # Normalize absolute coordinates to the [-1, 1] range grid_sample expects.
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)
    grid = torch.stack((x, y), dim=3)
    imgOut = torch.nn.functional.grid_sample(
        img, grid, padding_mode='border', align_corners=True)
    return imgOut
class SmallMaskNet(nn.Module):
    """A three-layer network for predicting mask"""

    def __init__(self, input, output):
        super(SmallMaskNet, self).__init__()
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3 = nn.Conv2d(16, output, 3, padding=1)

    def forward(self, x):
        h = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        h = F.leaky_relu(self.conv2(h), negative_slope=0.1)
        return self.conv3(h)  # raw logits; caller applies sigmoid
class StaticMaskNet(nn.Module):
    """Five 3x3 convs with LeakyReLU between them and a Sigmoid output,
    used to predict a static-region mask in [0, 1]."""

    def __init__(self, input, output):
        super(StaticMaskNet, self).__init__()
        # (in_channels, out_channels) of the five convolutions, in order.
        channel_plan = [(input, 32), (32, 32), (32, 16), (16, 16),
                        (16, output)]
        modules_body = []
        for idx, (cin, cout) in enumerate(channel_plan):
            modules_body.append(
                nn.Conv2d(
                    in_channels=cin,
                    out_channels=cout,
                    kernel_size=3,
                    stride=1,
                    padding=1))
            if idx < len(channel_plan) - 1:
                modules_body.append(
                    nn.LeakyReLU(inplace=False, negative_slope=0.1))
            else:
                modules_body.append(nn.Sigmoid())
        self.body = nn.Sequential(*modules_body)

    def forward(self, x):
        return self.body(x)
def tensor_erode(bin_img, ksize=5):
    """Sliding-window maximum over a ksize x ksize neighborhood (stride 1).

    NOTE(review): despite the name, taking the window *max* is morphological
    dilation for a binary mask, not erosion; callers in this file store the
    result in `M_static_dilate`, which matches that reading. Confirm before
    renaming.
    """
    B, C, H, W = bin_img.shape
    pad = (ksize - 1) // 2
    # zero-pad so the output keeps the input's spatial size
    padded = F.pad(bin_img, [pad, pad, pad, pad], mode='constant', value=0)
    windows = padded.unfold(2, ksize, 1).unfold(3, ksize, 1)
    out, _ = windows.reshape(B, C, H, W, -1).max(dim=-1)
    return out
class QVI_inter_Ds(nn.Module):
    """Given flow, implement Quadratic Video Interpolation

    Variant that runs the refinement/mask networks on downsampled inputs
    (the *_Ds arguments) while producing the final frame at full resolution.
    """

    def __init__(self, debug_en=False, is_training=False):
        super(QVI_inter_Ds, self).__init__()
        self.acc = AcFusionLayer()
        self.fwarp = FlowReversal()
        self.refinenet = Small_UNet_Ds(20, 8)
        self.masknet = SmallMaskNet(38, 1)
        self.staticnet = StaticMaskNet(56, 1)
        self.lpfilter = LowPassFilter()
        self.get_grad = Get_gradient()
        self.debug_en = debug_en  # extra debug outputs from forward()
        self.is_training = is_training  # return flows for training losses

    def fill_flow_hole(self, ft, norm, ft_fill):
        """Fill positions that received no splatting contribution (norm == 0)
        with `ft_fill`, then low-pass + down/up-sample so smoothed values
        propagate into those holes. Mutates and returns `ft`."""
        (N, C, H, W) = ft.shape
        ft[norm == 0] = ft_fill[norm == 0]
        ft_1 = self.lpfilter(ft.clone())
        ft_ds = torch.nn.functional.interpolate(
            input=ft_1,
            size=(H // 4, W // 4),
            mode='bilinear',
            align_corners=False)
        ft_up = torch.nn.functional.interpolate(
            input=ft_ds, size=(H, W), mode='bilinear', align_corners=False)
        ft[norm == 0] = ft_up[norm == 0]
        return ft

    def forward(self, F10_Ds, F12_Ds, F21_Ds, F23_Ds, I1_Ds, I2_Ds, I1, I2, t):
        """Synthesize the frame at time t in (0, 1) between I1 and I2.

        F10_Ds/F23_Ds may be None, in which case a linear (rather than
        quadratic) motion model is used. Returns the interpolated frame,
        plus flows/masks when is_training or debug_en is set.
        """
        if F12_Ds is None or F21_Ds is None:
            return I1
        if F10_Ds is not None and F23_Ds is not None:
            # quadratic (constant-acceleration) model over four frames
            F1t_Ds, F2t_Ds = self.acc(F10_Ds, F12_Ds, F21_Ds, F23_Ds, t)
        else:
            # linear fallback with only the middle pair of flows
            F1t_Ds = t * F12_Ds
            F2t_Ds = (1 - t) * F21_Ds
        # Flow Reversal (done at 1/3 resolution for speed)
        F1t_Ds2 = F.interpolate(
            F1t_Ds, scale_factor=1.0 / 3, mode='nearest') / 3
        F2t_Ds2 = F.interpolate(
            F2t_Ds, scale_factor=1.0 / 3, mode='nearest') / 3
        Ft1_Ds2, norm1_Ds2 = self.fwarp(F1t_Ds2, F1t_Ds2)
        Ft1_Ds2 = -Ft1_Ds2
        Ft2_Ds2, norm2_Ds2 = self.fwarp(F2t_Ds2, F2t_Ds2)
        Ft2_Ds2 = -Ft2_Ds2
        # normalize splatted flow by the accumulated weights
        Ft1_Ds2[norm1_Ds2 > 0] \
            = Ft1_Ds2[norm1_Ds2 > 0] / norm1_Ds2[norm1_Ds2 > 0].clone()
        Ft2_Ds2[norm2_Ds2 > 0] \
            = Ft2_Ds2[norm2_Ds2 > 0] / norm2_Ds2[norm2_Ds2 > 0].clone()
        if 1:
            # fill splatting holes with the negated input flow
            Ft1_Ds2_fill = -F1t_Ds2
            Ft2_Ds2_fill = -F2t_Ds2
            Ft1_Ds2 = self.fill_flow_hole(Ft1_Ds2, norm1_Ds2, Ft1_Ds2_fill)
            Ft2_Ds2 = self.fill_flow_hole(Ft2_Ds2, norm2_Ds2, Ft2_Ds2_fill)
        # back to the *_Ds resolution (flow magnitudes scale with size)
        Ft1_Ds = F.interpolate(
            Ft1_Ds2, size=[F1t_Ds.size(2), F1t_Ds.size(3)], mode='nearest') * 3
        Ft2_Ds = F.interpolate(
            Ft2_Ds2, size=[F2t_Ds.size(2), F2t_Ds.size(3)], mode='nearest') * 3
        I1t_Ds = backwarp(I1_Ds, Ft1_Ds)
        I2t_Ds = backwarp(I2_Ds, Ft2_Ds)
        output_Ds, feature_Ds = self.refinenet(
            torch.cat(
                [I1_Ds, I2_Ds, I1t_Ds, I2t_Ds, F12_Ds, F21_Ds, Ft1_Ds, Ft2_Ds],
                dim=1))
        # Adaptive filtering
        Ft1r_Ds = backwarp(
            Ft1_Ds, 10 * torch.tanh(output_Ds[:, 4:6])) + output_Ds[:, :2]
        Ft2r_Ds = backwarp(
            Ft2_Ds, 10 * torch.tanh(output_Ds[:, 6:8])) + output_Ds[:, 2:4]
        # warping and fusing
        I1tf_Ds = backwarp(I1_Ds, Ft1r_Ds)
        I2tf_Ds = backwarp(I2_Ds, Ft2r_Ds)
        G1_Ds = self.get_grad(I1_Ds)
        G2_Ds = self.get_grad(I2_Ds)
        G1tf_Ds = backwarp(G1_Ds, Ft1r_Ds)
        G2tf_Ds = backwarp(G2_Ds, Ft2r_Ds)
        M_Ds = torch.sigmoid(
            self.masknet(torch.cat([I1tf_Ds, I2tf_Ds, feature_Ds],
                                   dim=1))).repeat(1, 3, 1, 1)
        # lift refined flows and mask to full resolution (2x)
        Ft1r = F.interpolate(
            Ft1r_Ds * 2, scale_factor=2, mode='bilinear', align_corners=False)
        Ft2r = F.interpolate(
            Ft2r_Ds * 2, scale_factor=2, mode='bilinear', align_corners=False)
        I1tf = backwarp(I1, Ft1r)
        I2tf = backwarp(I2, Ft2r)
        M = F.interpolate(
            M_Ds, scale_factor=2, mode='bilinear', align_corners=False)
        # fuse (time-weighted blend, normalized)
        It_warp = ((1 - t) * M * I1tf + t * (1 - M) * I2tf) \
            / ((1 - t) * M + t * (1 - M)).clone()
        # static blending
        It_static = (1 - t) * I1 + t * I2
        tmp = torch.cat((I1tf_Ds, I2tf_Ds, G1tf_Ds, G2tf_Ds, I1_Ds, I2_Ds,
                         G1_Ds, G2_Ds, feature_Ds),
                        dim=1)
        M_static_Ds = self.staticnet(tmp)
        # NOTE: tensor_erode takes a window max, i.e. dilation for a
        # binary-like mask — hence the *_dilate naming.
        M_static_dilate = tensor_erode(M_static_Ds)
        M_static_dilate = tensor_erode(M_static_dilate)
        M_static = F.interpolate(
            M_static_dilate,
            scale_factor=2,
            mode='bilinear',
            align_corners=False)
        It_warp = (1 - M_static) * It_warp + M_static * It_static
        if self.is_training:
            return It_warp, Ft1r, Ft2r
        else:
            if self.debug_en:
                return It_warp, M, M_static, I1tf, I2tf, Ft1r, Ft2r
            else:
                return It_warp
class QVI_inter(nn.Module):
    """Given flow, implement Quadratic Video Interpolation"""

    def __init__(self, debug_en=False, is_training=False):
        super(QVI_inter, self).__init__()
        self.acc = AcFusionLayer()
        self.fwarp = FlowReversal()
        self.refinenet = Small_UNet_Ds(20, 8)
        self.masknet = SmallMaskNet(38, 1)
        self.staticnet = StaticMaskNet(56, 1)
        self.lpfilter = LowPassFilter()
        self.get_grad = Get_gradient()
        self.debug_en = debug_en  # extra debug outputs from forward()
        self.is_training = is_training  # return flows for training losses

    def fill_flow_hole(self, ft, norm, ft_fill):
        """Fill positions that received no splatting contribution (norm == 0)
        with `ft_fill`, then low-pass + down/up-sample so smoothed values
        propagate into those holes. Mutates and returns `ft`."""
        (N, C, H, W) = ft.shape
        ft[norm == 0] = ft_fill[norm == 0]
        ft_1 = self.lpfilter(ft.clone())
        ft_ds = torch.nn.functional.interpolate(
            input=ft_1,
            size=(H // 4, W // 4),
            mode='bilinear',
            align_corners=False)
        ft_up = torch.nn.functional.interpolate(
            input=ft_ds, size=(H, W), mode='bilinear', align_corners=False)
        ft[norm == 0] = ft_up[norm == 0]
        return ft

    def forward(self, F10, F12, F21, F23, I1, I2, t):
        """Synthesize the frame at time t in (0, 1) between I1 and I2.

        F10/F23 may be None, in which case a linear (rather than quadratic)
        motion model is used. Returns the interpolated frame, plus
        flows/masks when is_training or debug_en is set.
        """
        if F12 is None or F21 is None:
            return I1
        if F10 is not None and F23 is not None:
            # quadratic (constant-acceleration) model over four frames
            F1t, F2t = self.acc(F10, F12, F21, F23, t)
        else:
            # linear fallback with only the middle pair of flows
            F1t = t * F12
            F2t = (1 - t) * F21
        # Flow Reversal (done at 1/3 resolution for speed)
        F1t_Ds = F.interpolate(F1t, scale_factor=1.0 / 3, mode='nearest') / 3
        F2t_Ds = F.interpolate(F2t, scale_factor=1.0 / 3, mode='nearest') / 3
        Ft1_Ds, norm1_Ds = self.fwarp(F1t_Ds, F1t_Ds)
        Ft1_Ds = -Ft1_Ds
        Ft2_Ds, norm2_Ds = self.fwarp(F2t_Ds, F2t_Ds)
        Ft2_Ds = -Ft2_Ds
        # normalize splatted flow by the accumulated weights
        Ft1_Ds[norm1_Ds > 0] \
            = Ft1_Ds[norm1_Ds > 0] / norm1_Ds[norm1_Ds > 0].clone()
        Ft2_Ds[norm2_Ds > 0] \
            = Ft2_Ds[norm2_Ds > 0] / norm2_Ds[norm2_Ds > 0].clone()
        if 1:
            # fill splatting holes with the negated input flow
            Ft1_fill = -F1t_Ds
            Ft2_fill = -F2t_Ds
            Ft1_Ds = self.fill_flow_hole(Ft1_Ds, norm1_Ds, Ft1_fill)
            Ft2_Ds = self.fill_flow_hole(Ft2_Ds, norm2_Ds, Ft2_fill)
        # back to full resolution (flow magnitudes scale with size)
        Ft1 = F.interpolate(
            Ft1_Ds, size=[F1t.size(2), F1t.size(3)], mode='nearest') * 3
        Ft2 = F.interpolate(
            Ft2_Ds, size=[F2t.size(2), F2t.size(3)], mode='nearest') * 3
        I1t = backwarp(I1, Ft1)
        I2t = backwarp(I2, Ft2)
        output, feature = self.refinenet(
            torch.cat([I1, I2, I1t, I2t, F12, F21, Ft1, Ft2], dim=1))
        # Adaptive filtering
        Ft1r = backwarp(Ft1, 10 * torch.tanh(output[:, 4:6])) + output[:, :2]
        Ft2r = backwarp(Ft2, 10 * torch.tanh(output[:, 6:8])) + output[:, 2:4]
        # warping and fusing
        I1tf = backwarp(I1, Ft1r)
        I2tf = backwarp(I2, Ft2r)
        M = torch.sigmoid(
            self.masknet(torch.cat([I1tf, I2tf, feature],
                                   dim=1))).repeat(1, 3, 1, 1)
        # time-weighted blend, normalized
        It_warp = ((1 - t) * M * I1tf + t * (1 - M) * I2tf) \
            / ((1 - t) * M + t * (1 - M)).clone()
        G1 = self.get_grad(I1)
        G2 = self.get_grad(I2)
        G1tf = backwarp(G1, Ft1r)
        G2tf = backwarp(G2, Ft2r)
        # static blending
        It_static = (1 - t) * I1 + t * I2
        M_static = self.staticnet(
            torch.cat([I1tf, I2tf, G1tf, G2tf, I1, I2, G1, G2, feature],
                      dim=1))
        # NOTE: tensor_erode takes a window max, i.e. dilation for a
        # binary-like mask — hence the *_dilate naming.
        M_static_dilate = tensor_erode(M_static)
        M_static_dilate = tensor_erode(M_static_dilate)
        It_warp = (1 - M_static_dilate) * It_warp + M_static_dilate * It_static
        if self.is_training:
            return It_warp, Ft1r, Ft2r
        else:
            if self.debug_en:
                return It_warp, M, M_static, I1tf, I2tf, Ft1r, Ft2r
            else:
                return It_warp
class InterpNetDs(nn.Module):
    """IFNet flow refinement followed by the downsampled QVI interpolator.

    NOTE(review): QVI_inter_Ds.forward takes nine arguments
    (F10_Ds, F12_Ds, F21_Ds, F23_Ds, I1_Ds, I2_Ds, I1, I2, t), but only
    seven are passed below, so calling this forward raises a TypeError as
    written — confirm whether downsampled images were meant to be supplied.
    """

    def __init__(self, debug_en=False, is_training=False):
        super(InterpNetDs, self).__init__()
        self.ifnet = IFNet()
        self.internet = QVI_inter_Ds(
            debug_en=debug_en, is_training=is_training)

    def forward(self,
                img1,
                img2,
                F10_up,
                F12_up,
                F21_up,
                F23_up,
                UHD=2,
                timestep=0.5):
        # Refine the pairwise flows at the working scale selected by UHD,
        # then synthesize the intermediate frame at `timestep`.
        F12, F21 = self.ifnet(img1, img2, F12_up, F21_up, UHD)
        It_warp = self.internet(F10_up, F12, F21, F23_up, img1, img2, timestep)
        return It_warp
class InterpNet(nn.Module):
    """IFNet flow refinement followed by quadratic frame interpolation."""

    def __init__(self, debug_en=False, is_training=False):
        super(InterpNet, self).__init__()
        self.ifnet = IFNet()
        self.internet = QVI_inter(debug_en=debug_en, is_training=is_training)

    def forward(self,
                img1,
                img2,
                F10_up,
                F12_up,
                F21_up,
                F23_up,
                UHD=2,
                timestep=0.5):
        # Refine the pairwise flows at the working scale selected by UHD,
        # then synthesize the intermediate frame at `timestep`.
        refined_F12, refined_F21 = self.ifnet(img1, img2, F12_up, F21_up, UHD)
        return self.internet(F10_up, refined_F12, refined_F21, F23_up, img1,
                             img2, timestep)

View File

@@ -0,0 +1,989 @@
# The implementation is adopted from VFIformer,
# made publicly available at https://github.com/dvlab-research/VFIformer
# -----------------------------------------------------------------------------------
# modified from:
# SwinIR: Image Restoration Using Swin Transformer, https://github.com/JingyunLiang/SwinIR
# -----------------------------------------------------------------------------------
import functools
import math
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class Mlp(nn.Module):
    """Two-layer token-wise MLP: fc1 -> activation -> dropout -> fc2 -> dropout."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        # Single dropout module applied after both linear layers.
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    n_h, n_w = H // window_size, W // window_size
    tiles = x.view(B, n_h, window_size, n_w, window_size, C)
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    n_h, n_w = H // window_size, W // window_size
    x = windows.view(B, n_h, n_w, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return x.view(B, H, W, -1)
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5
        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        # NOTE(review): torch.meshgrid without `indexing=` warns on newer
        # torch versions; the code relies on the default 'ij' behavior.
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:,
                                                                None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        # registered as a buffer so it moves with the module but isn't trained
        self.register_buffer('relative_position_index',
                             relative_position_index)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # project to q, k, v and split heads: 3 x (B_, nH, N, C/nH)
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1).contiguous())
        # add the learned relative position bias, gathered per token pair
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            # apply the (0/-inf) shifted-window mask per window
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class WindowCrossAttention(nn.Module):
    r""" Window based multi-head cross attention module with relative position bias.

    Runs self-attention of ``x`` within each window and, reusing the same
    queries, cross-attention against a second feature map ``y`` (in the caller
    this is a 2x-downscaled window); the two branch outputs are fused by a
    small MLP with a residual link. It supports both of shifted and
    non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias,
        # one table per branch: x->x (self) and x->y (cross)
        self.relative_position_bias_table_x = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH
        self.relative_position_bias_table_y = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:,
                                                                None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # extra key/value projection for the cross branch (applied to y)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # fusion MLP combining the self- and cross-attention outputs
        self.merge1 = nn.Linear(dim * 2, dim)
        self.merge2 = nn.Linear(dim, dim)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table_x, std=.02)
        trunc_normal_(self.relative_position_bias_table_y, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y, mask_x=None, mask_y=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            y: cross features with shape of (num_windows*B, N, C)
            mask_x: (0/-inf) mask for the self branch, shape (num_windows, Wh*Ww, Wh*Ww) or None
            mask_y: (0/-inf) mask for the cross branch, shape (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale

        # ---- self-attention branch: x attends to x ----
        attn = (q @ k.transpose(-2, -1).contiguous())
        relative_position_bias = self.relative_position_bias_table_x[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask_x is not None:
            nW = mask_x.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask_x.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()

        # ---- cross-attention branch: the same queries attend to y ----
        B_, N, C = y.shape
        kv = self.kv(y).reshape(B_, N, 2, self.num_heads,
                                C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[
            1]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1).contiguous())
        relative_position_bias = self.relative_position_bias_table_y[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask_y is not None:
            nW = mask_y.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask_y.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)
        y = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()

        # fuse both branches (residual around the fusion MLP), then project
        x = self.merge2(self.act(self.merge1(torch.cat([x, y], dim=-1)))) + x
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        # NOTE(review): only the self-attention branch is counted; the kv
        # projection, cross branch and fusion MLP are not included — confirm
        # whether this is intentional.
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
class TFL(nn.Module):
    """Swin-Transformer-style block: windowed (optionally shifted) multi-head
    attention followed by an MLP, each with a residual connection.

    When ``use_crossattn`` is True, attention is computed with
    ``WindowCrossAttention`` against a 2x-downscaled copy of the feature map
    in addition to the regular windowed self-attention.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Resolution used to precompute the
            shifted-window attention masks.
        num_heads (int): Number of attention heads.
        window_size (int): Window size. Default: 7.
        shift_size (int): Shift of the window grid (0 = no shift). Default: 0.
        mlp_ratio (float): Ratio of MLP hidden dim to embedding dim. Default: 4.
        qkv_bias (bool): If True, add a learnable bias to q, k, v. Default: True.
        qk_scale (float | None): Override the default qk scale if set.
        drop (float): Dropout rate. Default: 0.
        attn_drop (float): Attention dropout rate. Default: 0.
        drop_path (float): Stochastic depth rate. Default: 0.
        act_layer (nn.Module): Activation layer of the MLP. Default: nn.GELU.
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        use_crossattn (bool): Use WindowCrossAttention instead of
            WindowAttention. Default: False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_crossattn=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_crossattn = use_crossattn
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'

        self.norm1 = norm_layer(dim)
        if not use_crossattn:
            self.attn = WindowAttention(
                dim,
                window_size=to_2tuple(self.window_size),
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop)
        else:
            self.attn = WindowCrossAttention(
                dim,
                window_size=to_2tuple(self.window_size),
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop)

        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        # precompute the shifted-window masks for the configured resolution;
        # forward() recomputes them on the fly if x_size differs
        if self.shift_size > 0:
            if not use_crossattn:
                attn_mask = self.calculate_mask(self.input_resolution)
                self.register_buffer('attn_mask', attn_mask)
            else:
                attn_mask_x = self.calculate_mask(self.input_resolution)
                attn_mask_y = self.calculate_mask2(self.input_resolution)
                self.register_buffer('attn_mask_x', attn_mask_x)
                self.register_buffer('attn_mask_y', attn_mask_y)
        else:
            if not use_crossattn:
                attn_mask = None
                self.register_buffer('attn_mask', attn_mask)
            else:
                attn_mask_x = None
                attn_mask_y = None
                self.register_buffer('attn_mask_x', attn_mask_x)
                self.register_buffer('attn_mask_y', attn_mask_y)

    def calculate_mask(self, x_size):
        # calculate attention mask for SW-MSA: tokens coming from different
        # pre-shift regions must not attend to each other (-100 ~ -inf)
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1,
                                         self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-100.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        return attn_mask

    def calculate_mask2(self, x_size):
        # calculate attention mask for SW-MSA between the full-resolution
        # windows and the 2x-downscaled windows used by the cross branch
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1,
                                         self.window_size * self.window_size)

        # downscale the region map and extract the same windows via unfold,
        # mirroring how forward() builds x_windows_down
        img_mask_down = F.interpolate(
            img_mask.permute(0, 3, 1, 2).contiguous(),
            scale_factor=0.5,
            mode='bilinear',
            align_corners=False)
        img_mask_down = F.pad(
            img_mask_down, (self.window_size // 4, self.window_size // 4,
                            self.window_size // 4, self.window_size // 4),
            mode='reflect')
        mask_windows_down = F.unfold(
            img_mask_down,
            kernel_size=self.window_size,
            dilation=1,
            padding=0,
            stride=self.window_size // 2)
        mask_windows_down = mask_windows_down.view(
            self.window_size * self.window_size,
            -1).permute(1, 0).contiguous()  # nW, window_size*window_size

        attn_mask = mask_windows_down.unsqueeze(1) - mask_windows.unsqueeze(
            2)  # nW, window_size*window_size, window_size*window_size
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-100.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x, x_size):
        """Apply the block to tokens ``x`` of shape (B, H*W, C); ``x_size`` is (H, W)."""
        H, W = x_size
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size,
                                   C)  # nW*B, window_size*window_size, C

        if not self.use_crossattn:
            # precomputed masks are only valid at self.input_resolution;
            # otherwise the mask is rebuilt for the actual size
            if self.input_resolution == x_size:
                attn_windows = self.attn(
                    x_windows,
                    mask=self.attn_mask)  # nW*B, window_size*window_size, C
            else:
                attn_windows = self.attn(
                    x_windows, mask=self.calculate_mask(x_size).to(x.device))
        else:
            # build the downscaled window tokens for the cross branch
            shifted_x_down = F.interpolate(
                shifted_x.permute(0, 3, 1, 2).contiguous(),
                scale_factor=0.5,
                mode='bilinear',
                align_corners=False)
            shifted_x_down = F.pad(
                shifted_x_down, (self.window_size // 4, self.window_size // 4,
                                 self.window_size // 4, self.window_size // 4),
                mode='reflect')
            x_windows_down = F.unfold(
                shifted_x_down,
                kernel_size=self.window_size,
                dilation=1,
                padding=0,
                stride=self.window_size // 2)
            x_windows_down = x_windows_down.view(
                B, C, self.window_size * self.window_size, -1)
            x_windows_down = x_windows_down.permute(
                0, 3, 2,
                1).contiguous().view(-1, self.window_size * self.window_size,
                                     C)  # nW*B, window_size*window_size, C
            if self.input_resolution == x_size:
                attn_windows = self.attn(
                    x_windows,
                    x_windows_down,
                    mask_x=self.attn_mask_x,
                    mask_y=self.attn_mask_y
                )  # nW*B, window_size*window_size, C
            else:
                attn_windows = self.attn(
                    x_windows,
                    x_windows_down,
                    mask_x=self.calculate_mask(x_size).to(x.device),
                    mask_y=self.calculate_mask2(x_size).to(x.device))

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size,
                                         self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H,
                                   W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x,
                shifts=(self.shift_size, self.shift_size),
                dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
            f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}'

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.input_resolution = input_resolution
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x):
"""
x: B, H*W, C
"""
H, W = self.input_resolution
B, L, C = x.shape
assert L == H * W, 'input feature has wrong size'
assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.'
x = x.view(B, H, W, C)
x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)
return x
def extra_repr(self) -> str:
return f'input_resolution={self.input_resolution}, dim={self.dim}'
def flops(self):
H, W = self.input_resolution
flops = H * W * self.dim
flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
return flops
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Stacks ``depth`` TFL blocks (alternating non-shifted / shifted windows)
    and optionally applies a downsampling layer at the end.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        use_crossattn (list[bool] | None): Per-block flag enabling cross
            attention; defaults to all-False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 use_crossattn=None):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        if use_crossattn is None:
            use_crossattn = [False] * depth

        def _rate(i):
            # per-block stochastic-depth rate (scalar or schedule list)
            return drop_path[i] if isinstance(drop_path, list) else drop_path

        # build blocks; odd blocks use a shifted window grid
        self.blocks = nn.ModuleList([
            TFL(dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=window_size // 2 if i % 2 else 0,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=_rate(i),
                norm_layer=norm_layer,
                use_crossattn=use_crossattn[i]) for i in range(depth)
        ])

        # patch merging layer
        if downsample is None:
            self.downsample = None
        else:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer)

    def forward(self, x, x_size):
        for block in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(block, x, x_size)
            else:
                x = block(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        return 'dim={}, input_resolution={}, depth={}'.format(
            self.dim, self.input_resolution, self.depth)

    def flops(self):
        total = sum(block.flops() for block in self.blocks)
        if self.downsample is not None:
            total += self.downsample.flops()
        return total
class RTFL(nn.Module):
    """Residual TFL group: a BasicLayer, a convolution on the un-embedded
    feature map, and a residual connection around the whole group."""

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 img_size=224,
                 patch_size=4,
                 resi_connection='1conv',
                 use_crossattn=None):
        super(RTFL, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.use_crossattn = use_crossattn

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint,
            use_crossattn=use_crossattn)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # bottleneck variant to save parameters and memory
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1))

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None)
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        # tokens -> image -> conv -> tokens, plus the residual input
        spatial = self.patch_unembed(self.residual_group(x, x_size), x_size)
        return self.patch_embed(self.conv(spatial)) + x

    def flops(self):
        H, W = self.input_resolution
        total = self.residual_group.flops()
        total += H * W * self.dim * self.dim * 9  # 3x3 conv
        total += self.patch_embed.flops()
        total += self.patch_unembed.flops()
        return total
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Flattens a (B, C, Ph, Pw) feature map into (B, Ph*Pw, C) tokens and
    optionally normalises them.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [s // p for s, p in zip(img_size, patch_size)]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        tokens = x.flatten(2).transpose(1, 2).contiguous()  # B Ph*Pw C
        return tokens if self.norm is None else self.norm(tokens)

    def flops(self):
        if self.norm is None:
            return 0
        H, W = self.img_size
        return H * W * self.embed_dim
class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Reshapes (B, H*W, C) tokens back into a (B, C, H, W) feature map.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [s // p for s, p in zip(img_size, patch_size)]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        B, HW, C = x.shape
        # tokens back to spatial layout: B C H W
        return x.transpose(1, 2).contiguous().view(B, self.embed_dim,
                                                   x_size[0], x_size[1])

    def flops(self):
        # pure reshape: no multiply-accumulate operations
        return 0
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a pixel-shuffle upsampling stack: one (conv, PixelShuffle(2)) pair
    per factor of two, or a single (conv, PixelShuffle(3)) pair for scale 3.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        layers = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                layers += [
                    nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1),
                    nn.PixelShuffle(2),
                ]
        elif scale == 3:
            layers += [
                nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1),
                nn.PixelShuffle(3),
            ]
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*layers)
class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)

    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
        num_out_ch (int): Number of output channels.
        input_resolution (tuple[int] | None): Spatial resolution, used only
            by ``flops``. Default: None.
    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        modules = [
            nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1),
            nn.PixelShuffle(scale),
        ]
        super(UpsampleOneStep, self).__init__(*modules)

    def flops(self):
        H, W = self.input_resolution
        # single 3x3 conv producing scale^2 * num_out_ch channels (counted as 3 here)
        return H * W * self.num_feat * 3 * 9

View File

@@ -0,0 +1,97 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
def calc_hist(img_tensor):
    """Return the normalized 64-bin histogram of ``img_tensor`` over [0, 255]."""
    num_values = img_tensor.numel()
    return torch.histc(img_tensor, bins=64, min=0, max=255) / num_values
def do_scene_detect(F01_tensor, F10_tensor, img0_tensor, img1_tensor):
    """Detect whether a scene change (shot cut) occurs between two frames.

    Stage 1 compares the 64-bin grayscale histograms of both frames; a large
    difference immediately reports a scene change. Stage 2 samples a sparse
    grid of interior points and checks forward/backward optical-flow
    consistency together with the pixel difference after flow compensation;
    too many mismatching points also reports a scene change.

    Args:
        F01_tensor: optical flow frame0 -> frame1, shape (1, 2, H, W).
        F10_tensor: optical flow frame1 -> frame0, shape (1, 2, H, W).
        img0_tensor: first frame, shape (1, 3, H, W), value range [0, 255].
        img1_tensor: second frame, shape (1, 3, H, W), value range [0, 255].

    Returns:
        bool: True if a scene change is detected.

    NOTE(review): stage 2 requires H > 128 and W > 128 for the sampling grid
    to be non-empty — confirm against callers.
    """
    device = img0_tensor.device
    scene_change = False
    img0_tensor = img0_tensor.clone()
    img1_tensor = img1_tensor.clone()

    # BT.601 luma conversion (assumes channel order R, G, B — TODO confirm)
    img0_gray = 0.299 * img0_tensor[:, 0:1] + 0.587 * img0_tensor[
        :, 1:2] + 0.114 * img0_tensor[:, 2:3]
    img1_gray = 0.299 * img1_tensor[:, 0:1] + 0.587 * img1_tensor[
        :, 1:2] + 0.114 * img1_tensor[:, 2:3]
    # quantize to uint8 range before histogramming
    img0_gray = torch.clamp(img0_gray, 0, 255).byte().float().cpu()
    img1_gray = torch.clamp(img1_gray, 0, 255).byte().float().cpu()

    # first stage: global histogram difference
    hist0 = calc_hist(img0_gray)
    hist1 = calc_hist(img1_gray)
    diff = torch.abs(hist0 - hist1)
    diff[diff < 0.01] = 0  # ignore small per-bin noise
    if torch.sum(diff) > 0.8 or diff.max() > 0.4:
        return True

    img0_gray = img0_gray.to(device)
    img1_gray = img1_gray.to(device)

    # second stage: detect mv and pix mismatch
    (n, c, h, w) = F01_tensor.size()
    # thresholds below are tuned for 1920x1080; rescale motion accordingly
    scale_x = w / 1920
    scale_y = h / 1080

    # compare mv on a sparse grid of interior points
    (y, x) = torch.meshgrid(torch.arange(h), torch.arange(w))
    (y_grid, x_grid) = torch.meshgrid(
        torch.arange(64, h - 64, 8), torch.arange(64, w - 64, 8))
    x = x.to(device)
    y = y.to(device)
    y_grid = y_grid.to(device)
    x_grid = x_grid.to(device)

    fx = F01_tensor[0, 0]
    fy = F01_tensor[0, 1]
    # forward-warped integer coordinates (rounded, clamped into the image)
    x_ = x.float() + fx
    y_ = y.float() + fy
    x_ = torch.clamp(x_ + 0.5, 0, w - 1).long()
    y_ = torch.clamp(y_ + 0.5, 0, h - 1).long()

    grid_fx = fx[y_grid, x_grid]
    grid_fy = fy[y_grid, x_grid]
    x_grid_ = x_[y_grid, x_grid]
    y_grid_ = y_[y_grid, x_grid]
    grid_fx_ = F10_tensor[0, 0, y_grid_, x_grid_]
    grid_fy_ = F10_tensor[0, 1, y_grid_, x_grid_]

    # forward + backward flow should cancel out when motion is consistent
    sum_x = grid_fx + grid_fx_
    sum_y = grid_fy + grid_fy_
    distance = torch.sqrt(sum_x**2 + sum_y**2)

    # motion-magnitude-dependent threshold in [5, 14]
    fx_len = torch.abs(grid_fx) * scale_x
    fy_len = torch.abs(grid_fy) * scale_y
    ori_len = torch.where(fx_len > fy_len, fx_len, fy_len)
    thres = torch.clamp(0.1 * ori_len + 4, 5, 14)

    # compare pix diff between frame0 and flow-compensated frame1
    ori_img = img0_gray
    ref_img = img1_gray[:, :, y_, x_]
    img_diff = ori_img.float() - ref_img.float()
    img_diff = torch.abs(img_diff)
    # 8x8 box filter; `np.float` was removed in NumPy 1.24 — use np.float64
    kernel = np.ones([8, 8], np.float64) / 64
    kernel = torch.FloatTensor(kernel).to(device).unsqueeze(0).unsqueeze(0)
    diff = F.conv2d(img_diff, kernel, padding=4)
    diff = diff[0, 0, y_grid, x_grid]

    # a grid point mismatches when flow is inconsistent AND pixels differ
    index = (distance > thres) * (diff > 5)
    if index.sum().float() / distance.numel() > 0.5:
        scene_change = True
    return scene_change

View File

@@ -0,0 +1,96 @@
# The implementation is adopted from RAFT,
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
import numpy as np
import torch
import torch.nn.functional as F
from scipy import interpolate
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        left, top = pad_wd // 2, pad_ht // 2
        if mode == 'sintel':
            # split padding evenly between both sides
            self._pad = [left, pad_wd - left, top, pad_ht - top]
        else:
            # pad only at the bottom vertically
            self._pad = [left, pad_wd - left, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every input tensor to the divisible size."""
        return [F.pad(t, self._pad, mode='replicate') for t in inputs]

    def unpad(self, x):
        """Crop a padded tensor back to the original spatial size."""
        ht, wd = x.shape[-2:]
        rows = slice(self._pad[2], ht - self._pad[3])
        cols = slice(self._pad[0], wd - self._pad[1])
        return x[..., rows, cols]
def forward_interpolate(flow):
    """Forward-warp a (2, H, W) flow field to the next frame.

    Scatters each flow vector to its forward-warped position and fills the
    resulting irregular samples back onto the regular grid with
    nearest-neighbour interpolation.
    """
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # target positions of every pixel after applying its flow vector
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only vectors landing strictly inside the image
    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[valid], y1[valid]
    dx, dy = dx[valid], dy[valid]

    flow_x = interpolate.griddata((x1, y1),
                                  dx, (x0, y0),
                                  method='nearest',
                                  fill_value=0)
    flow_y = interpolate.griddata((x1, y1),
                                  dy, (x0, y0),
                                  method='nearest',
                                  fill_value=0)

    return torch.from_numpy(np.stack([flow_x, flow_y], axis=0)).float()
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # map pixel coordinates into grid_sample's normalized [-1, 1] range
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    sampled = F.grid_sample(
        img, torch.cat([xgrid, ygrid], dim=-1), align_corners=True)

    if not mask:
        return sampled
    # validity mask: strictly inside the image after normalization
    inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
    return sampled, inside.float()
def coords_grid(batch, ht, wd, device):
    """Build a (batch, 2, ht, wd) grid of (x, y) pixel coordinates."""
    ys, xs = torch.meshgrid(
        torch.arange(ht, device=device), torch.arange(wd, device=device))
    # channel 0 holds x, channel 1 holds y
    grid = torch.stack((xs, ys), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially, scaling the flow vectors to match."""
    ht, wd = flow.shape[2], flow.shape[3]
    upsampled = F.interpolate(
        flow, size=(8 * ht, 8 * wd), mode=mode, align_corners=True)
    return 8 * upsampled

View File

@@ -0,0 +1,23 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # real import only for static type checkers / IDEs
    from .video_frame_interpolation_dataset import VideoFrameInterpolationDataset

else:
    # map submodule name -> exported symbols; the actual import is deferred
    # until the attribute is first accessed
    _import_structure = {
        'video_frame_interpolation_dataset':
        ['VideoFrameInterpolationDataset'],
    }

    import sys

    # replace this module with a lazy proxy so heavy deps load on demand
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,41 @@
# ------------------------------------------------------------------------
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
# Copyright 2018-2020 BasicSR Authors
# ------------------------------------------------------------------------
import cv2
import torch
import torch.nn.functional as F
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.

    Args:
        imgs (list[ndarray] | ndarray): Input images, HWC layout.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.

    Returns:
        list[tensor] | tensor: Tensor images in CHW layout. A single input
        array yields a single tensor.
    """

    def _convert(img):
        if img.shape[2] == 3 and bgr2rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img.transpose(2, 0, 1))
        return tensor.float() if float32 else tensor

    if isinstance(imgs, list):
        return [_convert(img) for img in imgs]
    return _convert(imgs)
def img_padding(img_tensor, height, width, pad_num=32):
    """Zero-pad right/bottom so H and W become multiples of ``pad_num``."""
    target_h = ((height - 1) // pad_num + 1) * pad_num
    target_w = ((width - 1) // pad_num + 1) * pad_num
    return F.pad(img_tensor, (0, target_w - width, 0, target_h - height))

View File

@@ -0,0 +1,54 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from collections import defaultdict
import cv2
import numpy as np
import torch
from modelscope.metainfo import Models
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
TorchTaskDataset
from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import (
img2tensor, img_padding)
from modelscope.utils.constant import Tasks
def default_loader(path):
    """Load the image at ``path`` as stored (IMREAD_UNCHANGED) and cast to float32."""
    return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
@TASK_DATASETS.register_module(
    Tasks.video_frame_interpolation,
    module_name=Models.video_frame_interpolation)
class VideoFrameInterpolationDataset(TorchTaskDataset):
    """Dataset for video frame-interpolation.

    Each item yields four input frames concatenated along the channel axis
    (zero-padded so H and W are multiples of 32) and the ground-truth
    intermediate frame normalized to [0, 1].
    """

    def __init__(self, dataset, opt):
        # dataset: indexable source whose items map '<name>:FILE' keys to image paths
        # opt: configuration object; stored but not used here — NOTE(review)
        self.dataset = dataset
        self.opt = opt

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        # Load frames. Dimension order: HWC; channel order: BGR;
        # float32, range as stored (typically [0, 255] for 8-bit images)
        item_dict = self.dataset[index]
        img0 = default_loader(item_dict['Input1:FILE'])
        img1 = default_loader(item_dict['Input2:FILE'])
        img2 = default_loader(item_dict['Input3:FILE'])
        img3 = default_loader(item_dict['Input4:FILE'])
        gt = default_loader(item_dict['Output:FILE'])

        # to CHW tensors without channel reordering
        img0, img1, img2, img3, gt = img2tensor([img0, img1, img2, img3, gt],
                                                bgr2rgb=False,
                                                float32=True)
        # stack the four frames along the channel dimension: (4*C, H, W)
        imgs = torch.cat((img0, img1, img2, img3), dim=0)
        height, width = imgs.size(1), imgs.size(2)
        imgs = img_padding(imgs, height, width, pad_num=32)
        # target normalized to [0, 1]; inputs keep the loader's range — NOTE(review)
        return {'input': imgs, 'target': gt / 255.0}

View File

@@ -305,6 +305,7 @@ TASK_OUTPUTS = {
# video editing task result for a single video
# {"output_video": "path_to_rendered_video"}
Tasks.video_frame_interpolation: [OutputKeys.OUTPUT_VIDEO],
Tasks.video_super_resolution: [OutputKeys.OUTPUT_VIDEO],
# live category recognition result for single video

View File

@@ -226,6 +226,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
'damo/cv_video-inpainting'),
Tasks.video_human_matting: (Pipelines.video_human_matting,
'damo/cv_effnetv2_video-human-matting'),
Tasks.video_frame_interpolation:
(Pipelines.video_frame_interpolation,
'damo/cv_raft_video-frame-interpolation'),
Tasks.human_wholebody_keypoint:
(Pipelines.human_wholebody_keypoint,
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
@@ -247,9 +250,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
Tasks.translation_evaluation:
(Pipelines.translation_evaluation,
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
Tasks.video_object_segmentation:
(Pipelines.video_object_segmentation,
'damo/cv_rdevos_video-object-segmentation'),
Tasks.video_object_segmentation: (
Pipelines.video_object_segmentation,
'damo/cv_rdevos_video-object-segmentation'),
Tasks.video_multi_object_tracking: (
Pipelines.video_multi_object_tracking,
'damo/cv_yolov5_video-multi-object-tracking_fairmot'),

View File

@@ -68,6 +68,7 @@ if TYPE_CHECKING:
from .hand_static_pipeline import HandStaticPipeline
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline
from .image_skychange_pipeline import ImageSkychangePipeline
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
from .video_stabilization_pipeline import VideoStabilizationPipeline
@@ -164,6 +165,9 @@ else:
'language_guided_video_summarization_pipeline': [
'LanguageGuidedVideoSummarizationPipeline'
],
'video_frame_interpolation_pipeline': [
'VideoFrameInterpolationPipeline'
],
'image_skychange_pipeline': ['ImageSkychangePipeline'],
'video_object_segmentation_pipeline': [
'VideoObjectSegmentationPipeline'

View File

@@ -0,0 +1,613 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import math
import os
import os.path as osp
import subprocess
import tempfile
from typing import Any, Dict, Optional, Union
import cv2
import numpy as np
import torch
import torch.nn.functional as F
from torchvision.utils import make_grid
from modelscope.metainfo import Pipelines
from modelscope.models.cv.video_frame_interpolation.utils.scene_change_detection import \
do_scene_detect
from modelscope.models.cv.video_frame_interpolation.VFINet_for_video_frame_interpolation import \
VFINetForVideoFrameInterpolation
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Input, Pipeline
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import LoadImage
from modelscope.preprocessors.cv import VideoReader
from modelscope.utils.config import Config
from modelscope.utils.constant import ModelFile, Tasks
from modelscope.utils.logger import get_logger
VIDEO_EXTENSIONS = ('.mp4', '.mov')
logger = get_logger()
def img_trans(img_tensor):  # in format of RGB
    """Normalize an RGB image tensor: scale to [0, 1] and subtract the mean.

    Args:
        img_tensor: Image tensor with channel dim 1 of size 3 (RGB order),
            values in [0, 255].

    Returns:
        Mean-subtracted tensor of the same shape; the input is not modified.
    """
    rgb_mean = torch.Tensor([0.429, 0.431, 0.397])
    rgb_mean = rgb_mean.view(1, 3, 1, 1).type_as(img_tensor)
    return img_tensor / 255.0 - rgb_mean
def add_mean(x):
    """Undo the normalization of ``img_trans`` by re-adding the RGB mean."""
    rgb_mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, 1)
    return x + rgb_mean.type_as(x)
def img_padding(img_tensor, height, width, pad_num=32):
    """Zero-pad an image on the right and bottom to multiples of ``pad_num``.

    Args:
        img_tensor: Tensor of shape (..., height, width).
        height: Current image height.
        width: Current image width.
        pad_num: Alignment; output H and W are rounded up to this multiple.

    Returns:
        Padded tensor whose last two dims are multiples of ``pad_num``.
    """
    target_h = math.ceil(height / pad_num) * pad_num
    target_w = math.ceil(width / pad_num) * pad_num
    # F.pad order for 2-D spatial padding: (left, right, top, bottom)
    return F.pad(img_tensor, (0, target_w - width, 0, target_h - height))
def do_inference_lowers(flow_10,
                        flow_12,
                        flow_21,
                        flow_23,
                        img1,
                        img2,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=False):
    """Synthesize every output frame falling between ``img1`` and ``img2``.

    ``read_count`` points one (``start_end_flag=True``, clip boundary) or two
    (interior pair) frames past ``img2``; it is shifted back so interpolation
    time ``t`` runs over the img1 -> img2 interval. Times within ``delta / 2``
    of an endpoint reuse that endpoint frame verbatim; otherwise
    ``inter_model`` is invoked with the four surrounding optical flows.

    Returns:
        Tuple ``(outputs, inter_count)``: the list with the de-normalized
        [0, 255] frames appended and the advanced output-time cursor.
    """
    base = read_count - 1 if start_end_flag else read_count - 2
    while inter_count <= base:
        t = round(inter_count + 1 - base, 2)
        if t < delta / 2:
            frame = img1
        elif (1 - t) < delta / 2:
            frame = img2
        else:
            frame = inter_model(flow_10, flow_12, flow_21, flow_23, img1,
                                img2, t)
        # undo the mean subtraction of img_trans and rescale to [0, 255]
        outputs.append(255 * add_mean(frame))
        inter_count += delta
    return outputs, inter_count
def do_inference_highers(flow_10,
                         flow_12,
                         flow_21,
                         flow_23,
                         img1,
                         img2,
                         img1_up,
                         img2_up,
                         inter_model,
                         read_count,
                         inter_count,
                         delta,
                         outputs,
                         start_end_flag=False):
    """Synthesize output frames between ``img1`` and ``img2`` (2K-and-above).

    High-resolution variant of ``do_inference_lowers``: flows and endpoint
    frames are at half resolution while ``img1_up`` / ``img2_up`` carry the
    full-resolution frames that ``inter_model`` (and the endpoint shortcut)
    operate on. ``read_count`` points one (``start_end_flag=True``) or two
    frames past ``img2`` and is shifted back so interpolation time ``t``
    covers the img1 -> img2 interval.

    Returns:
        Tuple ``(outputs, inter_count)``: the list with the de-normalized
        [0, 255] frames appended and the advanced output-time cursor.
    """
    base = read_count - 1 if start_end_flag else read_count - 2
    while inter_count <= base:
        t = round(inter_count + 1 - base, 2)
        if t < delta / 2:
            frame = img1_up
        elif (1 - t) < delta / 2:
            frame = img2_up
        else:
            frame = inter_model(flow_10, flow_12, flow_21, flow_23, img1,
                                img2, img1_up, img2_up, t)
        # undo the mean subtraction of img_trans and rescale to [0, 255]
        outputs.append(255 * add_mean(frame))
        inter_count += delta
    return outputs, inter_count
def inference_lowers(flow_model, refine_model, inter_model, video_len,
                     read_count, inter_count, delta, scene_change_flag,
                     img_tensor_list, img_ori_list, inputs, outputs):
    """Interpolate a whole clip whose resolution is below 2K.

    Frames are streamed from ``inputs`` into a sliding window of up to four
    consecutive frames. For each neighbouring pair, bidirectional optical
    flow is estimated with ``flow_model`` and refined with ``refine_model``;
    a scene-cut test decides whether ``inter_model`` may blend across the
    pair or should fall back to duplicating endpoint frames.

    Args:
        flow_model: Flow network, called as
            ``flow_model(imgA, imgB, iters=12, test_mode=True)``.
        refine_model: Flow refinement network returning a refined flow pair.
        inter_model: Frame synthesis network used by ``do_inference_lowers``.
        video_len: Total number of frames in ``inputs``.
        read_count: Index of the next frame to read (callers pass 0).
        inter_count: Output-time cursor advanced by ``delta`` per output frame.
        delta: Input-frame step per output frame (``fps / out_fps``).
        scene_change_flag: Three-slot list of scene-cut flags for the three
            pairs inside the current four-frame window.
        img_tensor_list: Sliding window of mean-subtracted padded frames.
        img_ori_list: Sliding window of padded (un-normalized) frames.
        inputs: List of 1x3xHxW frame tensors.
        outputs: List the synthesized frames are appended to.

    Returns:
        The ``outputs`` list with all interpolated frames appended.

    NOTE(review): the tail handling below reads ``img1``/``img2`` and
    ``flow_12``/``flow_21`` that are only bound once the window has reached
    four frames — confirm callers guarantee ``video_len`` >= 4.
    """
    height, width = inputs[read_count].size(2), inputs[read_count].size(3)
    # We use four consecutive frames to do frame interpolation. flow_10 is the
    # flow from frame1 to frame0 (note the computation order below:
    # flow_model(img1, img0)); the same naming goes for flow_12, flow_21 and
    # flow_23.
    flow_10 = None
    flow_12 = None
    flow_21 = None
    flow_23 = None
    with torch.no_grad():
        while (read_count < video_len):
            # Pad every incoming frame once and keep both the padded original
            # and its mean-subtracted copy in the sliding windows.
            img = inputs[read_count]
            img = img_padding(img, height, width)
            img_ori_list.append(img)
            img_tensor_list.append(img_trans(img))
            read_count += 1
            if len(img_tensor_list) == 2:
                # First pair of the clip: only flow_01/flow_10 exist, so they
                # are passed in the middle (flow_12/flow_21) argument slots.
                img0 = img_tensor_list[0]
                img1 = img_tensor_list[1]
                img0_ori = img_ori_list[0]
                img1_ori = img_ori_list[1]
                _, flow_01_up = flow_model(
                    img0_ori, img1_ori, iters=12, test_mode=True)
                _, flow_10_up = flow_model(
                    img1_ori, img0_ori, iters=12, test_mode=True)
                flow_01, flow_10 = refine_model(img0, img1, flow_01_up,
                                                flow_10_up, 2)
                # Scene detection runs on the un-padded crop only.
                scene_change_flag[0] = do_scene_detect(
                    flow_01[:, :, 0:height, 0:width], flow_10[:, :, 0:height,
                                                              0:width],
                    img_ori_list[0][:, :, 0:height, 0:width],
                    img_ori_list[1][:, :, 0:height, 0:width])
                if scene_change_flag[0]:
                    # Scene cut: duplicate endpoints instead of blending.
                    outputs, inter_count = do_inference_lowers(
                        None,
                        None,
                        None,
                        None,
                        img0,
                        img1,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)
                else:
                    outputs, inter_count = do_inference_lowers(
                        None,
                        flow_01,
                        flow_10,
                        None,
                        img0,
                        img1,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)
            if len(img_tensor_list) == 4:
                # Interior pair (frames 1-2 of the window), with frames 0 and
                # 3 providing the outer flows.
                if flow_12 is None or flow_21 is None:
                    # Only on the first 4-frame window; afterwards these are
                    # rolled over from the previous iteration.
                    img2 = img_tensor_list[2]
                    img2_ori = img_ori_list[2]
                    _, flow_12_up = flow_model(
                        img1_ori, img2_ori, iters=12, test_mode=True)
                    _, flow_21_up = flow_model(
                        img2_ori, img1_ori, iters=12, test_mode=True)
                    flow_12, flow_21 = refine_model(img1, img2, flow_12_up,
                                                    flow_21_up, 2)
                    scene_change_flag[1] = do_scene_detect(
                        flow_12[:, :, 0:height,
                                0:width], flow_21[:, :, 0:height, 0:width],
                        img_ori_list[1][:, :, 0:height, 0:width],
                        img_ori_list[2][:, :, 0:height, 0:width])
                img3 = img_tensor_list[3]
                img3_ori = img_ori_list[3]
                _, flow_23_up = flow_model(
                    img2_ori, img3_ori, iters=12, test_mode=True)
                _, flow_32_up = flow_model(
                    img3_ori, img2_ori, iters=12, test_mode=True)
                flow_23, flow_32 = refine_model(img2, img3, flow_23_up,
                                                flow_32_up, 2)
                scene_change_flag[2] = do_scene_detect(
                    flow_23[:, :, 0:height, 0:width], flow_32[:, :, 0:height,
                                                              0:width],
                    img_ori_list[2][:, :, 0:height, 0:width],
                    img_ori_list[3][:, :, 0:height, 0:width])
                if scene_change_flag[1]:
                    # Cut inside the pair itself: no blending at all.
                    outputs, inter_count = do_inference_lowers(
                        None, None, None, None, img1, img2, inter_model,
                        read_count, inter_count, delta, outputs)
                elif scene_change_flag[0] or scene_change_flag[2]:
                    # Cut in a neighbouring pair: drop the outer flows.
                    outputs, inter_count = do_inference_lowers(
                        None, flow_12, flow_21, None, img1, img2, inter_model,
                        read_count, inter_count, delta, outputs)
                else:
                    outputs, inter_count = do_inference_lowers(
                        flow_10_up, flow_12, flow_21, flow_23_up, img1, img2,
                        inter_model, read_count, inter_count, delta, outputs)
                img_tensor_list.pop(0)
                img_ori_list.pop(0)
                # for next group: shift every frame/flow one slot left so the
                # expensive flow computations are reused.
                img1 = img2
                img2 = img3
                img1_ori = img2_ori
                img2_ori = img3_ori
                flow_10 = flow_21
                flow_12 = flow_23
                flow_21 = flow_32
                flow_10_up = flow_21_up
                flow_12_up = flow_23_up
                flow_21_up = flow_32_up
                # save scene change flag for next group
                scene_change_flag[0] = scene_change_flag[1]
                scene_change_flag[1] = scene_change_flag[2]
                scene_change_flag[2] = False
        if read_count > 0:  # the last remaining 3 images
            # Flush the final pair (already shifted into img1/img2 above).
            img_ori_list.pop(0)
            img_tensor_list.pop(0)
            assert (len(img_tensor_list) == 2)
            if scene_change_flag[1]:
                outputs, inter_count = do_inference_lowers(
                    None,
                    None,
                    None,
                    None,
                    img1,
                    img2,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)
            else:
                outputs, inter_count = do_inference_lowers(
                    None,
                    flow_12,
                    flow_21,
                    None,
                    img1,
                    img2,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)
    return outputs
def inference_highers(flow_model, refine_model, inter_model, video_len,
                      read_count, inter_count, delta, scene_change_flag,
                      img_tensor_list, img_ori_list, inputs, outputs):
    """Interpolate a whole clip with a resolution of 2K or above.

    High-resolution variant of ``inference_lowers``: every frame is padded
    to a multiple of 64 and downscaled by 2x; flow estimation, refinement
    and scene detection run at half resolution, while the full-resolution
    (``*_up``) frames are handed to ``inter_model`` for synthesis via
    ``do_inference_highers``.

    Args:
        flow_model: Flow network, called as
            ``flow_model(imgA, imgB, iters=12, test_mode=True)``.
        refine_model: Flow refinement network returning a refined flow pair.
        inter_model: Frame synthesis network used by ``do_inference_highers``.
        video_len: Total number of frames in ``inputs``.
        read_count: Index of the next frame to read (callers pass 0).
        inter_count: Output-time cursor advanced by ``delta`` per output frame.
        delta: Input-frame step per output frame (``fps / out_fps``).
        scene_change_flag: Three-slot list of scene-cut flags for the three
            pairs inside the current four-frame window.
        img_tensor_list: Sliding window of half-res mean-subtracted frames.
        img_ori_list: Sliding window of half-res (un-normalized) frames.
        inputs: List of 1x3xHxW frame tensors; H and W must be even.
        outputs: List the synthesized frames are appended to.

    Returns:
        The ``outputs`` list with all interpolated frames appended.

    Raises:
        RuntimeError: If the frame height or width is odd.

    NOTE(review): as in ``inference_lowers``, the tail handling assumes the
    window reached four frames at least once — confirm ``video_len`` >= 4.
    """
    if inputs[read_count].size(2) % 2 != 0 or inputs[read_count].size(
            3) % 2 != 0:
        raise RuntimeError('Video width and height must be even')
    # height/width are the HALF-resolution dims used for flow and detection.
    height, width = inputs[read_count].size(2) // 2, inputs[read_count].size(
        3) // 2
    # We use four consecutive frames to do frame interpolation. flow_10 is the
    # flow from frame1 to frame0 (flow_model(img1, img0)); the same naming
    # goes for flow_12, flow_21 and flow_23.
    flow_10 = None
    flow_12 = None
    flow_21 = None
    flow_23 = None
    img_up_list = []
    with torch.no_grad():
        while (read_count < video_len):
            # Keep three parallel windows: full-res normalized (img_up_list),
            # half-res original (img_ori_list) and half-res normalized
            # (img_tensor_list).
            img_up = inputs[read_count]
            img_up = img_padding(img_up, height * 2, width * 2, pad_num=64)
            img = F.interpolate(
                img_up, scale_factor=0.5, mode='bilinear', align_corners=False)
            img_up_list.append(img_trans(img_up))
            img_ori_list.append(img)
            img_tensor_list.append(img_trans(img))
            read_count += 1
            if len(img_tensor_list) == 2:
                # First pair of the clip: only flow_01/flow_10 exist, passed
                # in the middle (flow_12/flow_21) argument slots.
                img0 = img_tensor_list[0]
                img1 = img_tensor_list[1]
                img0_ori = img_ori_list[0]
                img1_ori = img_ori_list[1]
                img0_up = img_up_list[0]
                img1_up = img_up_list[1]
                _, flow_01_up = flow_model(
                    img0_ori, img1_ori, iters=12, test_mode=True)
                _, flow_10_up = flow_model(
                    img1_ori, img0_ori, iters=12, test_mode=True)
                flow_01, flow_10 = refine_model(img0, img1, flow_01_up,
                                                flow_10_up, 2)
                # Scene detection runs on the un-padded half-res crop.
                scene_change_flag[0] = do_scene_detect(
                    flow_01[:, :, 0:height, 0:width], flow_10[:, :, 0:height,
                                                              0:width],
                    img_ori_list[0][:, :, 0:height, 0:width],
                    img_ori_list[1][:, :, 0:height, 0:width])
                if scene_change_flag[0]:
                    # Scene cut: duplicate endpoints instead of blending.
                    outputs, inter_count = do_inference_highers(
                        None,
                        None,
                        None,
                        None,
                        img0,
                        img1,
                        img0_up,
                        img1_up,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)
                else:
                    outputs, inter_count = do_inference_highers(
                        None,
                        flow_01,
                        flow_10,
                        None,
                        img0,
                        img1,
                        img0_up,
                        img1_up,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)
            if len(img_tensor_list) == 4:
                # Interior pair (frames 1-2 of the window), with frames 0 and
                # 3 providing the outer flows.
                if flow_12 is None or flow_21 is None:
                    # Only on the first 4-frame window; afterwards these are
                    # rolled over from the previous iteration.
                    img2 = img_tensor_list[2]
                    img2_ori = img_ori_list[2]
                    img2_up = img_up_list[2]
                    _, flow_12_up = flow_model(
                        img1_ori, img2_ori, iters=12, test_mode=True)
                    _, flow_21_up = flow_model(
                        img2_ori, img1_ori, iters=12, test_mode=True)
                    flow_12, flow_21 = refine_model(img1, img2, flow_12_up,
                                                    flow_21_up, 2)
                    scene_change_flag[1] = do_scene_detect(
                        flow_12[:, :, 0:height,
                                0:width], flow_21[:, :, 0:height, 0:width],
                        img_ori_list[1][:, :, 0:height, 0:width],
                        img_ori_list[2][:, :, 0:height, 0:width])
                img3 = img_tensor_list[3]
                img3_ori = img_ori_list[3]
                img3_up = img_up_list[3]
                _, flow_23_up = flow_model(
                    img2_ori, img3_ori, iters=12, test_mode=True)
                _, flow_32_up = flow_model(
                    img3_ori, img2_ori, iters=12, test_mode=True)
                flow_23, flow_32 = refine_model(img2, img3, flow_23_up,
                                                flow_32_up, 2)
                scene_change_flag[2] = do_scene_detect(
                    flow_23[:, :, 0:height, 0:width], flow_32[:, :, 0:height,
                                                              0:width],
                    img_ori_list[2][:, :, 0:height, 0:width],
                    img_ori_list[3][:, :, 0:height, 0:width])
                if scene_change_flag[1]:
                    # Cut inside the pair itself: no blending at all.
                    outputs, inter_count = do_inference_highers(
                        None, None, None, None, img1, img2, img1_up, img2_up,
                        inter_model, read_count, inter_count, delta, outputs)
                elif scene_change_flag[0] or scene_change_flag[2]:
                    # Cut in a neighbouring pair: drop the outer flows.
                    outputs, inter_count = do_inference_highers(
                        None, flow_12, flow_21, None, img1, img2, img1_up,
                        img2_up, inter_model, read_count, inter_count, delta,
                        outputs)
                else:
                    outputs, inter_count = do_inference_highers(
                        flow_10_up, flow_12, flow_21, flow_23_up, img1, img2,
                        img1_up, img2_up, inter_model, read_count, inter_count,
                        delta, outputs)
                img_up_list.pop(0)
                img_tensor_list.pop(0)
                img_ori_list.pop(0)
                # for next group: shift every frame/flow one slot left so the
                # expensive flow computations are reused.
                img1 = img2
                img2 = img3
                img1_ori = img2_ori
                img2_ori = img3_ori
                img1_up = img2_up
                img2_up = img3_up
                flow_10 = flow_21
                flow_12 = flow_23
                flow_21 = flow_32
                flow_10_up = flow_21_up
                flow_12_up = flow_23_up
                flow_21_up = flow_32_up
                # save scene change flag for next group
                scene_change_flag[0] = scene_change_flag[1]
                scene_change_flag[1] = scene_change_flag[2]
                scene_change_flag[2] = False
        if read_count > 0:  # the last remaining 3 images
            # Flush the final pair (already shifted into img1/img2 above).
            img_ori_list.pop(0)
            img_tensor_list.pop(0)
            assert (len(img_tensor_list) == 2)
            if scene_change_flag[1]:
                outputs, inter_count = do_inference_highers(
                    None,
                    None,
                    None,
                    None,
                    img1,
                    img2,
                    img1_up,
                    img2_up,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)
            else:
                outputs, inter_count = do_inference_highers(
                    None,
                    flow_12,
                    flow_21,
                    None,
                    img1,
                    img2,
                    img1_up,
                    img2_up,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)
    return outputs
def convert(param):
    """Strip the DataParallel ``'module.'`` prefix from state-dict keys.

    Args:
        param: State-dict-like mapping of parameter names to tensors.

    Returns:
        A new dict containing only the entries whose key contained
        ``'module.'``, with every occurrence of that substring removed.
    """
    stripped = {}
    for key, value in param.items():
        if 'module.' in key:
            stripped[key.replace('module.', '')] = value
    return stripped
__all__ = ['VideoFrameInterpolationPipeline']


@PIPELINES.register_module(
    Tasks.video_frame_interpolation,
    module_name=Pipelines.video_frame_interpolation)
class VideoFrameInterpolationPipeline(Pipeline):
    """ Video Frame Interpolation Pipeline.

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks
    >>> from modelscope.outputs import OutputKeys
    >>> video = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_frame_interpolation_test.mp4'
    >>> video_frame_interpolation_pipeline = pipeline(Tasks.video_frame_interpolation,
            'damo/cv_raft_video-frame-interpolation')
    >>> result = video_frame_interpolation_pipeline(video)[OutputKeys.OUTPUT_VIDEO]
    >>> print('pipeline: the output video path is {}'.format(result))
    ```
    """

    def __init__(self,
                 model: Union[VFINetForVideoFrameInterpolation, str],
                 preprocessor=None,
                 **kwargs):
        """Create the pipeline and move the interpolation network to GPU
        when available, CPU otherwise.

        Args:
            model: A loaded ``VFINetForVideoFrameInterpolation`` instance or
                a model id / local path to load it from.
            preprocessor: Unused; frames are prepared in ``preprocess``.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        self.net = self.model.model
        self.net.to(self._device)
        self.net.eval()
        logger.info('load video frame-interpolation done')

    def preprocess(self, input: Input, out_fps: float = 0) -> Dict[str, Any]:
        """Load frames from a video file or an image directory.

        Args:
            input: Path to a ``.mp4``/``.mov`` video, or a directory of
                images (no file extension).
            out_fps: Target output frame rate; 0 means double the input fps.

        Returns:
            Dict with ``'video'`` (list of 1x3xHxW float tensors), ``'fps'``
            and ``'out_fps'``.

        Raises:
            ValueError: If ``input`` is neither a video nor a directory.
        """
        # read images
        file_extension = os.path.splitext(input)[1]
        if file_extension in VIDEO_EXTENSIONS:  # input is a video file
            video_reader = VideoReader(input)
            inputs = []
            for frame in video_reader:
                inputs.append(frame)
            fps = video_reader.fps
        elif file_extension == '':  # input is a directory
            inputs = []
            input_paths = sorted(glob.glob(f'{input}/*'))
            for input_path in input_paths:
                # NOTE(review): LoadImage is used here as a callable on the
                # path — confirm it returns an HWC numpy array (the loop
                # below calls .copy() and permute on each element).
                img = LoadImage(input_path, mode='rgb')
                inputs.append(img)
            fps = 25  # default fps
        else:
            raise ValueError('"input" can only be a video or a directory.')
        # HWC uint8 -> 1x3xHxW float tensor
        for i, img in enumerate(inputs):
            img = torch.from_numpy(img.copy()).permute(2, 0, 1).float()
            inputs[i] = img.unsqueeze(0)
        if out_fps == 0:
            out_fps = 2 * fps
        return {'video': inputs, 'fps': fps, 'out_fps': out_fps}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Interpolate the clip to the requested output frame rate.

        Dispatches to the high-resolution path (half-res flow + upscaled
        synthesis) for frames >= 1440x2560, and the standard path otherwise;
        outputs are cropped back to the original frame size.

        Args:
            input: Dict produced by ``preprocess``.

        Returns:
            Dict with ``'output'`` (list of interpolated frame tensors) and
            ``'fps'`` (the output frame rate).
        """
        inputs = input['video']
        fps = input['fps']
        out_fps = input['out_fps']
        video_len = len(inputs)
        flow_model = self.net.flownet
        refine_model = self.net.internet.ifnet
        read_count = 0
        inter_count = 0
        # input-frame step per output frame; < 1 when increasing the fps
        delta = fps / out_fps
        scene_change_flag = [False, False, False]
        img_tensor_list = []
        img_ori_list = []
        outputs = []
        height, width = inputs[read_count].size(2), inputs[read_count].size(3)
        if height >= 1440 or width >= 2560:
            inter_model = self.net.internet_Ds.internet
            outputs = inference_highers(flow_model, refine_model, inter_model,
                                        video_len, read_count, inter_count,
                                        delta, scene_change_flag,
                                        img_tensor_list, img_ori_list, inputs,
                                        outputs)
        else:
            inter_model = self.net.internet.internet
            outputs = inference_lowers(flow_model, refine_model, inter_model,
                                       video_len, read_count, inter_count,
                                       delta, scene_change_flag,
                                       img_tensor_list, img_ori_list, inputs,
                                       outputs)
        # crop away the padding added by img_padding
        for i in range(len(outputs)):
            outputs[i] = outputs[i][:, :, 0:height, 0:width]
        return {'output': outputs, 'fps': out_fps}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Encode the interpolated frames into an mp4 file.

        Args:
            inputs: Dict produced by ``forward``.
            **kwargs: ``output_video`` (target path; a temp file when None)
                and ``demo_service`` (re-encode to web-friendly h264 via
                ffmpeg, default True).

        Returns:
            Dict with ``OutputKeys.OUTPUT_VIDEO`` pointing at the written
            video file.
        """
        output_video_path = kwargs.get('output_video', None)
        demo_service = kwargs.get('demo_service', True)
        if output_video_path is None:
            # NOTE(review): the NamedTemporaryFile object is discarded, so
            # the file may be deleted before VideoWriter opens the path —
            # consider delete=False or mkstemp; verify on target platforms.
            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
        h, w = inputs['output'][0].shape[-2:]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(output_video_path, fourcc,
                                       inputs['fps'], (w, h))
        for i in range(len(inputs['output'])):
            img = inputs['output'][i]
            img = img[0].permute(1, 2, 0).byte().cpu().numpy()
            video_writer.write(img.astype(np.uint8))
        video_writer.release()
        if demo_service:
            # NOTE(review): this assert is stripped under `python -O`, and
            # the shell command below interpolates the output path unquoted
            # — fine for temp names, but confirm for user-supplied paths.
            assert os.system(
                'ffmpeg -version') == 0, 'ffmpeg is not installed correctly!'
            output_video_path_for_web = output_video_path[:-4] + '_web.mp4'
            convert_cmd = f'ffmpeg -i {output_video_path} -vcodec h264 -crf 5 {output_video_path_for_web}'
            subprocess.call(convert_cmd, shell=True)
            return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}

View File

@@ -98,6 +98,7 @@ class CVTasks(object):
# video editing
video_inpainting = 'video-inpainting'
video_frame_interpolation = 'video-frame-interpolation'
video_stabilization = 'video-stabilization'
video_super_resolution = 'video-super-resolution'

View File

@@ -0,0 +1,52 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import sys
import unittest
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.pipelines.cv import VideoFrameInterpolationPipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.demo_utils import DemoCompatibilityCheck
from modelscope.utils.test_utils import test_level
class VideoFrameInterpolationTest(unittest.TestCase, DemoCompatibilityCheck):
    """Smoke tests for the video frame interpolation pipeline."""

    def setUp(self) -> None:
        # Task name, default model id and fixture clip shared by all tests.
        self.task = Tasks.video_frame_interpolation
        self.model_id = 'damo/cv_raft_video-frame-interpolation'
        self.test_video = 'data/test/videos/video_frame_interpolation_test.mp4'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        # Construct the pipeline directly from a local model snapshot.
        model_dir = snapshot_download(self.model_id)
        vfi_pipeline = VideoFrameInterpolationPipeline(model_dir)
        vfi_pipeline.group_key = self.task
        result = vfi_pipeline(input=self.test_video)
        out_video_path = result[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        # Resolve the model through the hub by explicit model id.
        vfi_pipeline = pipeline(
            task=Tasks.video_frame_interpolation, model=self.model_id)
        result = vfi_pipeline(input=self.test_video)
        out_video_path = result[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        # Let the task's registered default model be used.
        vfi_pipeline = pipeline(task=Tasks.video_frame_interpolation)
        result = vfi_pipeline(input=self.test_video)
        out_video_path = result[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()


if __name__ == '__main__':
    unittest.main()