mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 04:01:10 +01:00
submit video frame interpolation model
增加视频插帧模型
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11188339
This commit is contained in:
3
data/test/videos/video_frame_interpolation_test.mp4
Normal file
3
data/test/videos/video_frame_interpolation_test.mp4
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:e97ff88d0af12f7dd3ef04ce50b87b51ffbb9a57dce81d2d518df4abd2fdb826
|
||||
size 3231793
|
||||
@@ -63,6 +63,7 @@ class Models(object):
|
||||
image_body_reshaping = 'image-body-reshaping'
|
||||
image_skychange = 'image-skychange'
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
video_stabilization = 'video-stabilization'
|
||||
real_basicvsr = 'real-basicvsr'
|
||||
@@ -270,6 +271,7 @@ class Pipelines(object):
|
||||
referring_video_object_segmentation = 'referring-video-object-segmentation'
|
||||
image_skychange = 'image-skychange'
|
||||
video_human_matting = 'video-human-matting'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_object_segmentation = 'video-object-segmentation'
|
||||
video_stabilization = 'video-stabilization'
|
||||
video_super_resolution = 'realbasicvsr-video-super-resolution'
|
||||
@@ -507,6 +509,8 @@ class Metrics(object):
|
||||
|
||||
# metrics for image denoise task
|
||||
image_denoise_metric = 'image-denoise-metric'
|
||||
# metrics for video frame-interpolation task
|
||||
video_frame_interpolation_metric = 'video-frame-interpolation-metric'
|
||||
# metrics for real-world video super-resolution task
|
||||
video_super_resolution_metric = 'video-super-resolution-metric'
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
|
||||
from .bleu_metric import BleuMetric
|
||||
from .image_inpainting_metric import ImageInpaintingMetric
|
||||
from .referring_video_object_segmentation_metric import ReferringVideoObjectSegmentationMetric
|
||||
from .video_frame_interpolation_metric import VideoFrameInterpolationMetric
|
||||
from .video_stabilization_metric import VideoStabilizationMetric
|
||||
from .video_super_resolution_metric.video_super_resolution_metric import VideoSuperResolutionMetric
|
||||
from .ppl_metric import PplMetric
|
||||
@@ -46,6 +47,7 @@ else:
|
||||
'bleu_metric': ['BleuMetric'],
|
||||
'referring_video_object_segmentation_metric':
|
||||
['ReferringVideoObjectSegmentationMetric'],
|
||||
'video_frame_interpolation_metric': ['VideoFrameInterpolationMetric'],
|
||||
'video_stabilization_metric': ['VideoStabilizationMetric'],
|
||||
'ppl_metric': ['PplMetric'],
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ class MetricKeys(object):
|
||||
RECALL = 'recall'
|
||||
PSNR = 'psnr'
|
||||
SSIM = 'ssim'
|
||||
LPIPS = 'lpips'
|
||||
NIQE = 'niqe'
|
||||
AVERAGE_LOSS = 'avg_loss'
|
||||
FScore = 'fscore'
|
||||
@@ -53,6 +54,8 @@ task_default_metrics = {
|
||||
Tasks.image_inpainting: [Metrics.image_inpainting_metric],
|
||||
Tasks.referring_video_object_segmentation:
|
||||
[Metrics.referring_video_object_segmentation_metric],
|
||||
Tasks.video_frame_interpolation:
|
||||
[Metrics.video_frame_interpolation_metric],
|
||||
Tasks.video_stabilization: [Metrics.video_stabilization_metric],
|
||||
}
|
||||
|
||||
|
||||
172
modelscope/metrics/video_frame_interpolation_metric.py
Normal file
172
modelscope/metrics/video_frame_interpolation_metric.py
Normal file
@@ -0,0 +1,172 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# ------------------------------------------------------------------------
|
||||
import math
|
||||
from math import exp
|
||||
from typing import Dict
|
||||
|
||||
import lpips
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.metainfo import Metrics
|
||||
from modelscope.metrics.base import Metric
|
||||
from modelscope.metrics.builder import METRICS, MetricKeys
|
||||
from modelscope.utils.registry import default_group
|
||||
|
||||
|
||||
@METRICS.register_module(
    group_key=default_group,
    module_name=Metrics.video_frame_interpolation_metric)
class VideoFrameInterpolationMetric(Metric):
    """The metric computation class for video frame interpolation,
    which will return PSNR, SSIM and LPIPS.
    """
    # Keys under which predictions / ground truth appear in ``outputs``.
    pred_name = 'pred'
    label_name = 'target'

    def __init__(self):
        super(VideoFrameInterpolationMetric, self).__init__()
        self.preds = []
        self.labels = []
        # Fix: only move the LPIPS network to GPU when CUDA is available.
        # The previous unconditional ``.cuda()`` raised on CPU-only hosts.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.loss_fn_alex = lpips.LPIPS(net='alex').to(device)

    def add(self, outputs: Dict, inputs: Dict):
        """Accumulate one batch of predictions and ground-truth frames."""
        ground_truths = outputs[VideoFrameInterpolationMetric.label_name]
        eval_results = outputs[VideoFrameInterpolationMetric.pred_name]
        self.preds.append(eval_results)
        self.labels.append(ground_truths)

    def evaluate(self):
        """Compute mean PSNR / SSIM / LPIPS over all accumulated pairs."""
        psnr_list, ssim_list, lpips_list = [], [], []
        with torch.no_grad():
            for (pred, label) in zip(self.preds, self.labels):
                # norm to 0-1
                # Crop the prediction to the label size (the network may
                # pad its output to a multiple of the model stride).
                height, width = label.size(2), label.size(3)
                pred = pred[:, :, 0:height, 0:width]

                psnr_list.append(calculate_psnr(label, pred))
                ssim_list.append(calculate_ssim(label, pred))
                lpips_list.append(
                    calculate_lpips(label, pred, self.loss_fn_alex))

        return {
            MetricKeys.PSNR: np.mean(psnr_list),
            MetricKeys.SSIM: np.mean(ssim_list),
            MetricKeys.LPIPS: np.mean(lpips_list)
        }
|
||||
|
||||
|
||||
def gaussian(window_size, sigma):
    """Return a 1-D Gaussian kernel of length ``window_size`` normalized to sum to 1."""
    center = window_size // 2
    denom = float(2 * sigma**2)
    weights = [exp(-(i - center)**2 / denom) for i in range(window_size)]
    kernel = torch.Tensor(weights)
    return kernel / kernel.sum()


def create_window_3d(window_size, channel=1, device=None):
    """Build a separable 3-D Gaussian window of shape (1, channel, ws, ws, ws).

    The window is the outer product of three identical 1-D Gaussians
    (sigma=1.5), so its entries sum to 1.
    """
    line = gaussian(window_size, 1.5).unsqueeze(1)
    plane = line.mm(line.t())
    cube = plane.unsqueeze(2) @ (line.t())
    window = cube.expand(1, channel, window_size, window_size,
                         window_size).contiguous().to(device)
    return window
|
||||
|
||||
|
||||
def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio between the first samples of two batches.

    Operates on ``img1[0]`` / ``img2[0]`` only; assumes values are already
    normalized so that the peak signal is 1 (PSNR = -10*log10(MSE)).
    """
    diff = img1[0] - img2[0]
    mse = torch.mean(diff * diff).cpu().data
    return -10 * math.log10(mse)
|
||||
|
||||
|
||||
def calculate_ssim(img1,
                   img2,
                   window_size=11,
                   window=None,
                   size_average=True,
                   full=False,
                   val_range=None):
    """Compute SSIM between two image batches using a 3-D Gaussian window.

    Color images are treated as single-channel volumetric data, so the
    Gaussian smoothing runs over (C, H, W) via conv3d.

    Args:
        img1: (N, C, H, W) tensor.
        img2: (N, C, H, W) tensor, same shape as img1.
        window_size: Gaussian window size used when ``window`` is None.
        window: optional precomputed window (see create_window_3d).
        size_average: if True return a scalar mean, else per-sample means.
        full: if True also return the contrast-sensitivity term.
        val_range: dynamic range L; inferred from img1 values when None.
    """
    # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
    if val_range is None:
        if torch.max(img1) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(img1) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val
    else:
        L = val_range

    padd = 0
    (_, _, height, width) = img1.size()
    if window is None:
        real_size = min(window_size, height, width)
        window = create_window_3d(
            real_size, channel=1, device=img1.device).to(img1.device)
        # Channel is set to 1 since we consider color images as volumetric images

    # Add a singleton "volume channel" dim: (N, 1, C, H, W) for conv3d.
    img1 = img1.unsqueeze(1)
    img2 = img2.unsqueeze(1)

    # NOTE(review): replicate padding is hard-coded to 5 on every side,
    # which matches window_size=11 only — confirm if other window sizes
    # are ever passed.
    mu1 = F.conv3d(
        F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'),
        window,
        padding=padd,
        groups=1)
    mu2 = F.conv3d(
        F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'),
        window,
        padding=padd,
        groups=1)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances via E[x^2] - E[x]^2 under the Gaussian window.
    sigma1_sq = F.conv3d(
        F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'),
        window,
        padding=padd,
        groups=1) - mu1_sq
    sigma2_sq = F.conv3d(
        F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'),
        window,
        padding=padd,
        groups=1) - mu2_sq
    sigma12 = F.conv3d(
        F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'),
        window,
        padding=padd,
        groups=1) - mu1_mu2

    # Standard SSIM stabilization constants scaled by the dynamic range.
    C1 = (0.01 * L)**2
    C2 = (0.03 * L)**2

    v1 = 2.0 * sigma12 + C2
    v2 = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(v1 / v2)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

    if size_average:
        ret = ssim_map.mean()
    else:
        ret = ssim_map.mean(1).mean(1).mean(1)

    if full:
        return ret, cs
    return ret.cpu()
|
||||
|
||||
|
||||
def calculate_lpips(img1, img2, loss_fn_alex):
    """Perceptual (LPIPS) distance between two images.

    Inputs are expected in [0, 1]; they are mapped to the [-1, 1] range
    the LPIPS network expects before calling ``loss_fn_alex``.
    """
    scaled1 = img1 * 2 - 1
    scaled2 = img2 * 2 - 1

    distance = loss_fn_alex(scaled1, scaled2)
    return distance.cpu().item()
|
||||
@@ -15,8 +15,8 @@ from . import (action_recognition, animal_recognition, body_2d_keypoints,
|
||||
pointcloud_sceneflow_estimation, product_retrieval_embedding,
|
||||
realtime_object_detection, referring_video_object_segmentation,
|
||||
salient_detection, shop_segmentation, super_resolution,
|
||||
video_object_segmentation, video_single_object_tracking,
|
||||
video_stabilization, video_summarization,
|
||||
video_super_resolution, virual_tryon)
|
||||
video_frame_interpolation, video_object_segmentation,
|
||||
video_single_object_tracking, video_stabilization,
|
||||
video_summarization, video_super_resolution, virual_tryon)
|
||||
|
||||
# yapf: enable
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.cv.video_frame_interpolation.flow_model.raft import RAFT
|
||||
from modelscope.models.cv.video_frame_interpolation.interp_model.IFNet_swin import \
|
||||
IFNet
|
||||
from modelscope.models.cv.video_frame_interpolation.interp_model.refinenet_arch import (
|
||||
InterpNet, InterpNetDs)
|
||||
|
||||
|
||||
class VFINet(nn.Module):
    """Video frame interpolation network: a RAFT optical-flow backbone
    followed by a flow-guided interpolation network.

    Args:
        args: configuration object forwarded to the RAFT flow network.
        Ds_flag (bool): when True, additionally build the downscaled
            interpolation branch ``internet_Ds``.
    """

    def __init__(self, args, Ds_flag=False):
        super(VFINet, self).__init__()
        self.flownet = RAFT(args)
        self.internet = InterpNet()
        if Ds_flag:
            # NOTE(review): internet_Ds is built (and loaded by the model
            # wrapper) but not referenced in forward() below — confirm
            # whether the downscaled branch is exercised elsewhere.
            self.internet_Ds = InterpNetDs()

    def img_trans(self, img_tensor):  # in format of RGB
        # Scale 0-255 RGB to [0, 1] and subtract the per-channel mean.
        img_tensor = img_tensor / 255.0
        mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1,
                                                        1).type_as(img_tensor)
        img_tensor -= mean
        return img_tensor

    def add_mean(self, x):
        # Inverse of the mean subtraction in img_trans (no 255 rescale).
        mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, 1).type_as(x)
        return x + mean

    def forward(self, imgs, timestep=0.5):
        """Interpolate a frame between the two middle input frames.

        Args:
            imgs: tensor with four RGB frames concatenated along the
                channel axis (12 channels total), values in 0-255.
            timestep (float): temporal position of the new frame between
                frame 1 and frame 2.

        Returns:
            The interpolated frame with the channel mean re-added.
        """
        # Inference-only path: force eval mode and disable gradients.
        self.flownet.eval()
        self.internet.eval()
        with torch.no_grad():
            img0 = imgs[:, :3]
            img1 = imgs[:, 3:6]
            img2 = imgs[:, 6:9]
            img3 = imgs[:, 9:12]

            # Bidirectional flows around the two middle frames.
            _, F10_up = self.flownet(img1, img0, iters=12, test_mode=True)
            _, F12_up = self.flownet(img1, img2, iters=12, test_mode=True)
            _, F21_up = self.flownet(img2, img1, iters=12, test_mode=True)
            _, F23_up = self.flownet(img2, img3, iters=12, test_mode=True)

            # Normalize only the frames fed to the interpolation network;
            # RAFT rescales its own inputs internally.
            img1 = self.img_trans(img1.clone())
            img2 = self.img_trans(img2.clone())

            It_warp = self.internet(
                img1, img2, F10_up, F12_up, F21_up, F23_up, timestep=timestep)
            It_warp = self.add_mean(It_warp)

        return It_warp
|
||||
@@ -0,0 +1,98 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
import torch.cuda
|
||||
import torch.nn.functional as F
|
||||
from torch.nn.parallel import DataParallel, DistributedDataParallel
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Tensor
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
# from modelscope.models.cv.video_super_resolution.common import charbonnier_loss
|
||||
from modelscope.models.cv.video_frame_interpolation.VFINet_arch import VFINet
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
__all__ = ['VFINetForVideoFrameInterpolation']
|
||||
|
||||
|
||||
def convert(param):
    """Strip the 'module.' prefix (added by DataParallel wrappers) from
    state-dict keys, dropping entries that never had the prefix."""
    stripped = {}
    for key, value in param.items():
        if 'module.' in key:
            stripped[key.replace('module.', '')] = value
    return stripped
|
||||
|
||||
|
||||
@MODELS.register_module(
    Tasks.video_frame_interpolation,
    module_name=Models.video_frame_interpolation)
class VFINetForVideoFrameInterpolation(TorchModel):
    """ModelScope wrapper around VFINet for the video frame-interpolation
    task: builds the network from a packaged model directory and routes
    forward calls to inference or evaluation post-processing."""

    def __init__(self, model_dir: str, *args, **kwargs):
        """initialize the video frame-interpolation model from the `model_dir` path.

        Args:
            model_dir (str): the model path.

        """
        super().__init__(model_dir, *args, **kwargs)

        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')

        self.model_dir = model_dir
        self.config = Config.from_file(
            os.path.join(self.model_dir, ModelFile.CONFIGURATION))

        # Checkpoint filenames are fixed by the packaged model repository.
        flownet_path = os.path.join(model_dir, 'raft-sintel.pt')
        internet_path = os.path.join(model_dir, 'interpnet.pt')

        # Ds_flag=True so the downscaled interpolation branch exists and
        # can be loaded below.
        self.model = VFINet(self.config.model.network, Ds_flag=True)
        self._load_pretrained(flownet_path, internet_path)
        # NOTE(review): the model is not moved to self._device here;
        # presumably the pipeline/trainer does that — confirm.

    def _load_pretrained(self, flownet_path, internet_path):
        # Load flow and interpolation weights, stripping any DataParallel
        # 'module.' prefixes via convert().
        state_dict_flownet = torch.load(
            flownet_path, map_location=self._device)
        state_dict_internet = torch.load(
            internet_path, map_location=self._device)

        self.model.flownet.load_state_dict(
            convert(state_dict_flownet), strict=True)
        self.model.internet.load_state_dict(
            convert(state_dict_internet), strict=True)
        # The downscaled branch shares the same interpolation checkpoint.
        self.model.internet_Ds.load_state_dict(
            convert(state_dict_internet), strict=True)
        logger.info('load model done.')

    def _inference_forward(self, input: Tensor) -> Dict[str, Tensor]:
        # Plain inference: run the network and wrap the result.
        return {'output': self.model(input)}

    def _evaluate_postprocess(self, input: Tensor,
                              target: Tensor) -> Dict[str, list]:
        # Evaluation path: keep predictions alongside the ground truth so
        # the metric can consume them.
        preds = self.model(input)
        # Drop the (large) input frames before the next batch to reduce
        # peak GPU memory.
        del input
        torch.cuda.empty_cache()
        return {'pred': preds, 'target': target}

    def forward(self, inputs: Dict[str,
                                   Tensor]) -> Dict[str, Union[list, Tensor]]:
        """return the result by the model

        Args:
            inputs (Tensor): the preprocessed data

        Returns:
            Dict[str, Tensor]: results
        """
        # The presence of a ground-truth 'target' selects evaluation mode.
        if 'target' in inputs:
            return self._evaluate_postprocess(**inputs)
        else:
            return self._inference_forward(**inputs)
|
||||
20
modelscope/models/cv/video_frame_interpolation/__init__.py
Normal file
20
modelscope/models/cv/video_frame_interpolation/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
    # Static type checkers see the real import.
    from .VFINet_arch import VFINet

else:
    # At runtime, submodules are imported lazily on first attribute access
    # by replacing this module object with a LazyImportModule proxy.
    _import_structure = {'VFINet_arch': ['VFINet']}

    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -0,0 +1,92 @@
|
||||
# The implementation is adopted from RAFT,
|
||||
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.video_frame_interpolation.utils.utils import (
|
||||
bilinear_sampler, coords_grid)
|
||||
|
||||
|
||||
class CorrBlock:
    """All-pairs correlation volume with an average-pooled pyramid.

    Built once per image pair; calling the instance samples a local
    (2r+1)x(2r+1) correlation window around the given coordinates at
    every pyramid level.
    """

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        # Flatten source positions into the batch dim so pooling only
        # shrinks the target dimensions.
        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            # Each level halves the target resolution.
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        # coords: (batch, 2, h1, w1) lookup centers; returns
        # (batch, num_levels * (2r+1)^2, h1, w1) sampled correlations.
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            dy = torch.linspace(-r, r, 2 * r + 1, device=coords.device)
            # NOTE(review): torch.meshgrid without an explicit ``indexing``
            # argument relies on the legacy 'ij' default and warns on newer
            # torch — confirm the supported torch version.
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1)

            # Scale lookup centers down to the current pyramid level and
            # add the window offsets.
            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        # Dense dot-product correlation between every pair of spatial
        # positions, scaled by 1/sqrt(feature_dim).
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
|
||||
|
||||
class AlternateCorrBlock:
    """Memory-efficient correlation backed by the ``alt_cuda_corr`` CUDA
    extension: keeps a feature pyramid and computes local correlations on
    demand instead of materializing the all-pairs volume."""

    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius

        # Level 0 is full resolution; each further level halves H and W.
        self.pyramid = [(fmap1, fmap2)]
        for i in range(self.num_levels):
            fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
            self.pyramid.append((fmap1, fmap2))

    def __call__(self, coords):
        coords = coords.permute(0, 2, 3, 1)
        B, H, W, _ = coords.shape
        dim = self.pyramid[0][0].shape[1]

        corr_list = []
        for i in range(self.num_levels):
            r = self.radius
            # Source features stay at full resolution; target features come
            # from pyramid level i.
            fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
            fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()

            coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
            # NOTE(review): ``alt_cuda_corr`` is not imported in this module,
            # so this raises NameError unless the compiled extension is
            # injected elsewhere — confirm before enabling
            # args.alternate_corr.
            corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
            corr_list.append(corr.squeeze(1))

        corr = torch.stack(corr_list, dim=1)
        corr = corr.reshape(B, -1, H, W)
        return corr / torch.sqrt(torch.tensor(dim).float())
|
||||
@@ -0,0 +1,288 @@
|
||||
# The implementation is adopted from RAFT,
|
||||
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ResidualBlock(nn.Module):
    """Two 3x3 convs with a configurable norm and an identity/projection skip.

    Args:
        in_planes: input channel count.
        planes: output channel count.
        norm_fn: one of 'group', 'batch', 'instance', 'none'.
        stride: stride of the first conv; when != 1, a 1x1 projection
            downsample path (with its own norm) is added.

    Raises:
        ValueError: if ``norm_fn`` is not one of the supported names.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def _make_norm():
            # One fresh normalization layer per call site.
            if norm_fn == 'group':
                return nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(planes)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(planes)
            if norm_fn == 'none':
                return nn.Sequential()
            # Fail fast: previously an unknown norm_fn only surfaced later
            # as an AttributeError in forward().
            raise ValueError('unsupported norm_fn: {}'.format(norm_fn))

        self.norm1 = _make_norm()
        self.norm2 = _make_norm()
        if not stride == 1:
            self.norm3 = _make_norm()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||
|
||||
|
||||
class BottleneckBlock(nn.Module):
    """1x1 -> 3x3 -> 1x1 bottleneck residual block with configurable norm.

    The inner convs run at ``planes // 4`` channels; when ``stride != 1``
    a 1x1 projection downsample path (with its own norm) is added.

    Raises:
        ValueError: if ``norm_fn`` is not one of
            'group', 'batch', 'instance', 'none'.
    """

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(
            in_planes, planes // 4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(
            planes // 4, planes // 4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes // 4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        def _make_norm(channels):
            # One fresh normalization layer over ``channels`` per call site.
            if norm_fn == 'group':
                return nn.GroupNorm(
                    num_groups=num_groups, num_channels=channels)
            if norm_fn == 'batch':
                return nn.BatchNorm2d(channels)
            if norm_fn == 'instance':
                return nn.InstanceNorm2d(channels)
            if norm_fn == 'none':
                return nn.Sequential()
            # Fail fast: previously an unknown norm_fn only surfaced later
            # as an AttributeError in forward().
            raise ValueError('unsupported norm_fn: {}'.format(norm_fn))

        self.norm1 = _make_norm(planes // 4)
        self.norm2 = _make_norm(planes // 4)
        self.norm3 = _make_norm(planes)
        if not stride == 1:
            self.norm4 = _make_norm(planes)

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
                self.norm4)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
    """Residual feature encoder used by the full RAFT model.

    Downsamples the input by a factor of 8 (stride-2 stem conv plus two
    stride-2 residual stages) and projects to ``output_dim`` channels.
    Accepts either a single tensor or a list/tuple of two tensors, which
    are concatenated along the batch dimension and split back on return.

    Raises:
        ValueError: if ``norm_fn`` is not one of
            'group', 'batch', 'instance', 'none'.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        else:
            # Fail fast: previously an unknown norm_fn left self.norm1
            # undefined, causing an AttributeError in forward().
            raise ValueError('unsupported norm_fn: {}'.format(norm_fn))

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        # Kaiming init for convs; unit/zero affine params for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may downsample.
        layer1 = ResidualBlock(
            self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
|
||||
|
||||
class SmallEncoder(nn.Module):
    """Bottleneck feature encoder used by the small RAFT model variant.

    Downsamples the input by a factor of 8 (stride-2 stem conv plus two
    stride-2 bottleneck stages) and projects to ``output_dim`` channels.
    Accepts either a single tensor or a list/tuple of two tensors, which
    are concatenated along the batch dimension and split back on return.

    Raises:
        ValueError: if ``norm_fn`` is not one of
            'group', 'batch', 'instance', 'none'.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        super(SmallEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(32)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(32)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        else:
            # Fail fast: previously an unknown norm_fn left self.norm1
            # undefined, causing an AttributeError in forward().
            raise ValueError('unsupported norm_fn: {}'.format(norm_fn))

        self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = 32
        self.layer1 = self._make_layer(32, stride=1)
        self.layer2 = self._make_layer(64, stride=2)
        self.layer3 = self._make_layer(96, stride=2)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)

        # Kaiming init for convs; unit/zero affine params for norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m,
                            (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        # Two bottleneck blocks; only the first may downsample.
        layer1 = BottleneckBlock(
            self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):

        # if input is list, combine batch dimension
        is_list = isinstance(x, tuple) or isinstance(x, list)
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)

        if self.training and self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            x = torch.split(x, [batch_dim, batch_dim], dim=0)

        return x
|
||||
@@ -0,0 +1,157 @@
|
||||
# The implementation is adopted from RAFT,
|
||||
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.video_frame_interpolation.flow_model.corr import (
|
||||
AlternateCorrBlock, CorrBlock)
|
||||
from modelscope.models.cv.video_frame_interpolation.flow_model.extractor import (
|
||||
BasicEncoder, SmallEncoder)
|
||||
from modelscope.models.cv.video_frame_interpolation.flow_model.update import (
|
||||
BasicUpdateBlock, SmallUpdateBlock)
|
||||
from modelscope.models.cv.video_frame_interpolation.utils.utils import (
|
||||
bilinear_sampler, coords_grid, upflow8)
|
||||
|
||||
autocast = torch.cuda.amp.autocast
|
||||
|
||||
|
||||
class RAFT(nn.Module):
|
||||
|
||||
def __init__(self, args):
|
||||
super(RAFT, self).__init__()
|
||||
self.args = args
|
||||
|
||||
if args.small:
|
||||
self.hidden_dim = hdim = 96
|
||||
self.context_dim = cdim = 64
|
||||
self.args.corr_levels = 4
|
||||
self.args.corr_radius = 3
|
||||
|
||||
else:
|
||||
self.hidden_dim = hdim = 128
|
||||
self.context_dim = cdim = 128
|
||||
self.args.corr_levels = 4
|
||||
self.args.corr_radius = 4
|
||||
|
||||
if 'dropout' not in self.args:
|
||||
self.args.dropout = 0
|
||||
|
||||
if 'alternate_corr' not in self.args:
|
||||
self.args.alternate_corr = False
|
||||
|
||||
# feature network, context network, and update block
|
||||
if args.small:
|
||||
self.fnet = SmallEncoder(
|
||||
output_dim=128, norm_fn='instance', dropout=self.args.dropout)
|
||||
self.cnet = SmallEncoder(
|
||||
output_dim=hdim + cdim,
|
||||
norm_fn='none',
|
||||
dropout=self.args.dropout)
|
||||
self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
|
||||
|
||||
else:
|
||||
self.fnet = BasicEncoder(
|
||||
output_dim=256, norm_fn='instance', dropout=self.args.dropout)
|
||||
self.cnet = BasicEncoder(
|
||||
output_dim=hdim + cdim,
|
||||
norm_fn='batch',
|
||||
dropout=self.args.dropout)
|
||||
self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
|
||||
|
||||
def freeze_bn(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
m.eval()
|
||||
|
||||
def initialize_flow(self, img):
|
||||
""" Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
|
||||
N, C, H, W = img.shape
|
||||
coords0 = coords_grid(N, H // 8, W // 8, device=img.device)
|
||||
coords1 = coords_grid(N, H // 8, W // 8, device=img.device)
|
||||
|
||||
# optical flow computed as difference: flow = coords1 - coords0
|
||||
return coords0, coords1
|
||||
|
||||
def upsample_flow(self, flow, mask):
|
||||
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
|
||||
N, _, H, W = flow.shape
|
||||
mask = mask.view(N, 1, 9, 8, 8, H, W)
|
||||
mask = torch.softmax(mask, dim=2)
|
||||
|
||||
up_flow = F.unfold(8 * flow, [3, 3], padding=1)
|
||||
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
||||
|
||||
up_flow = torch.sum(mask * up_flow, dim=2)
|
||||
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
||||
return up_flow.reshape(N, 2, 8 * H, 8 * W)
|
||||
|
||||
def forward(self,
            image1,
            image2,
            iters=12,
            flow_init=None,
            upsample=True,
            test_mode=False):
    """ Estimate optical flow between pair of frames """
    # NOTE(review): `upsample` is unused in this body; kept for API compat.

    # Map 8-bit RGB inputs from [0, 255] to [-1, 1].
    image1 = 2 * (image1 / 255.0) - 1.0
    image2 = 2 * (image2 / 255.0) - 1.0

    image1 = image1.contiguous()
    image2 = image2.contiguous()

    hdim = self.hidden_dim
    cdim = self.context_dim

    # run the feature network
    with autocast(enabled=self.args.mixed_precision):
        fmap1, fmap2 = self.fnet([image1, image2])

    # Correlation is computed in fp32 regardless of mixed precision.
    fmap1 = fmap1.float()
    fmap2 = fmap2.float()
    if self.args.alternate_corr:
        # Memory-efficient variant that indexes correlation on the fly.
        corr_fn = AlternateCorrBlock(
            fmap1, fmap2, radius=self.args.corr_radius)
    else:
        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

    # run the context network
    with autocast(enabled=self.args.mixed_precision):
        cnet = self.cnet(image1)
        # Split context features into GRU hidden state (tanh) and input (relu).
        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)

    coords0, coords1 = self.initialize_flow(image1)

    # Optionally warm-start from a previous flow estimate.
    if flow_init is not None:
        coords1 = coords1 + flow_init

    flow_predictions = []
    for itr in range(iters):
        # Stop gradients from flowing through the previous iterate.
        coords1 = coords1.detach()
        corr = corr_fn(coords1)  # index correlation volume

        flow = coords1 - coords0
        with autocast(enabled=self.args.mixed_precision):
            net, up_mask, delta_flow = self.update_block(
                net, inp, corr, flow)

        # F(t+1) = F(t) + \Delta(t)
        coords1 = coords1 + delta_flow

        # upsample predictions
        if up_mask is None:
            # Small model: plain 8x bilinear upsampling.
            flow_up = upflow8(coords1 - coords0)
        else:
            # Learned convex-combination upsampling.
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)

        flow_predictions.append(flow_up)

    if test_mode:
        # Inference: return (low-res flow, final upsampled flow).
        return coords1 - coords0, flow_up

    # Training: all per-iteration predictions, for the sequence loss.
    return flow_predictions
|
||||
@@ -0,0 +1,160 @@
|
||||
# The implementation is adopted from RAFT,
|
||||
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowHead(nn.Module):
    """Two-layer convolutional head mapping a hidden state to a 2-channel flow delta."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        hidden = self.relu(self.conv1(x))
        return self.conv2(hidden)
|
||||
|
||||
|
||||
class ConvGRU(nn.Module):
    """Convolutional GRU cell operating on 2-D feature maps."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(ConvGRU, self).__init__()
        gate_channels = hidden_dim + input_dim
        self.convz = nn.Conv2d(gate_channels, hidden_dim, 3, padding=1)
        self.convr = nn.Conv2d(gate_channels, hidden_dim, 3, padding=1)
        self.convq = nn.Conv2d(gate_channels, hidden_dim, 3, padding=1)

    def forward(self, h, x):
        stacked = torch.cat([h, x], dim=1)
        update = torch.sigmoid(self.convz(stacked))  # z (update) gate
        reset = torch.sigmoid(self.convr(stacked))   # r (reset) gate
        candidate = torch.tanh(self.convq(torch.cat([reset * h, x], dim=1)))
        # Blend previous state with the candidate state.
        return (1 - update) * h + update * candidate
|
||||
|
||||
|
||||
class SepConvGRU(nn.Module):
    """Separable ConvGRU: a horizontal (1x5) GRU pass followed by a vertical (5x1) pass."""

    def __init__(self, hidden_dim=128, input_dim=192 + 128):
        super(SepConvGRU, self).__init__()
        gate_channels = hidden_dim + input_dim
        self.convz1 = nn.Conv2d(gate_channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(gate_channels, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(gate_channels, hidden_dim, (1, 5), padding=(0, 2))

        self.convz2 = nn.Conv2d(gate_channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(gate_channels, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(gate_channels, hidden_dim, (5, 1), padding=(2, 0))

    def _gru_step(self, h, x, convz, convr, convq):
        # One GRU update using the supplied gate convolutions.
        hx = torch.cat([h, x], dim=1)
        z = torch.sigmoid(convz(hx))
        r = torch.sigmoid(convr(hx))
        q = torch.tanh(convq(torch.cat([r * h, x], dim=1)))
        return (1 - z) * h + z * q

    def forward(self, h, x):
        # horizontal pass, then vertical pass
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
|
||||
|
||||
|
||||
class SmallMotionEncoder(nn.Module):
    """Fuses correlation features and the current flow into 82-channel motion features."""

    def __init__(self, args):
        super(SmallMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
        self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
        self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
        self.conv = nn.Conv2d(128, 80, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc1(corr))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))
        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers keep direct access to it.
        return torch.cat([fused, flow], dim=1)
|
||||
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
    """Fuses correlation features and the current flow into 128-channel motion features."""

    def __init__(self, args):
        super(BasicMotionEncoder, self).__init__()
        cor_planes = args.corr_levels * (2 * args.corr_radius + 1)**2
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # 126 learned channels + 2 raw flow channels = 128 output channels.
        self.conv = nn.Conv2d(64 + 192, 128 - 2, 3, padding=1)

    def forward(self, flow, corr):
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = F.relu(self.conv(torch.cat([corr_feat, flow_feat], dim=1)))
        # Append the raw flow so downstream layers keep direct access to it.
        return torch.cat([fused, flow], dim=1)
|
||||
|
||||
|
||||
class SmallUpdateBlock(nn.Module):
    """One recurrent update step: motion encoding -> ConvGRU -> flow head."""

    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82 + 64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, gru_input)
        # The small variant has no learned upsampling mask, hence None.
        return net, None, self.flow_head(net)
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
    """Recurrent update step with a SepConvGRU, flow head and learned upsampling mask."""

    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(
            hidden_dim=hidden_dim, input_dim=128 + hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # Predicts the 9 convex-combination weights per 8x8 output cell.
        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(inplace=True),
            nn.Conv2d(256, 64 * 9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        gru_input = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, gru_input)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
|
||||
@@ -0,0 +1,434 @@
|
||||
# Part of the implementation is borrowed and modified from RIFE,
|
||||
# publicly available at https://github.com/megvii-research/ECCV2022-RIFE
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from timm.models.layers import trunc_normal_
|
||||
|
||||
from modelscope.models.cv.video_frame_interpolation.interp_model.transformer_layers import (
|
||||
RTFL, PatchEmbed, PatchUnEmbed)
|
||||
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
|
||||
# Cache of normalised base sampling grids, keyed by (device, flow size).
backwarp_tenGrid = {}


def warp(tenInput, tenFlow):
    """Backward-warp `tenInput` by the per-pixel offsets `tenFlow` (in pixels).

    Args:
        tenInput: (N, C, H, W) source tensor to sample from.
        tenFlow: (N, 2, H, W) displacements; channel 0 is horizontal,
            channel 1 is vertical.

    Returns:
        The bilinearly warped tensor, same shape as `tenInput`.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        # Build the grid on the flow's own device (previously it was built on
        # a module-level default device, which broke CPU tensors on
        # CUDA-capable machines).
        tenHorizontal = torch.linspace(
            -1.0, 1.0, tenFlow.shape[3], device=tenFlow.device).view(
                1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1,
                                                  tenFlow.shape[2], -1)
        tenVertical = torch.linspace(
            -1.0, 1.0, tenFlow.shape[2],
            device=tenFlow.device).view(1, 1, tenFlow.shape[2],
                                        1).expand(tenFlow.shape[0], -1, -1,
                                                  tenFlow.shape[3])
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # Convert pixel offsets into the [-1, 1] range expected by grid_sample.
    tmp1 = tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0)
    tmp2 = tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)
    tenFlow = torch.cat([tmp1, tmp2], 1)

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=torch.clamp(g, -1, 1),
        mode='bilinear',
        padding_mode='border',
        align_corners=True)
|
||||
|
||||
|
||||
def conv_wo_act(in_planes,
                out_planes,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1):
    """A single biased Conv2d wrapped in nn.Sequential, without activation."""
    layer = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True)
    return nn.Sequential(layer)
|
||||
|
||||
|
||||
def conv(in_planes,
         out_planes,
         kernel_size=3,
         stride=1,
         padding=1,
         dilation=1):
    """A biased Conv2d followed by a per-channel PReLU activation."""
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=True),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
def conv_bn(in_planes,
            out_planes,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1):
    """Conv2d (no bias) -> BatchNorm2d -> PReLU; bias omitted because BN absorbs it."""
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class TransModel(nn.Module):
    """Swin-style transformer refiner used inside IFBlock.

    Embeds a feature map into patch tokens, runs a stack of RTFL layers
    (window attention, optionally cross attention), and un-embeds the
    tokens back into a feature map of the same spatial size and channels.
    """

    def __init__(self,
                 img_size=64,
                 patch_size=1,
                 embed_dim=64,
                 depths=[[3, 3]],
                 num_heads=[[2, 2]],
                 window_size=4,
                 mlp_ratio=2,
                 qkv_bias=True,
                 qk_scale=None,
                 drop_rate=0.,
                 attn_drop_rate=0.,
                 drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm,
                 ape=False,
                 patch_norm=True,
                 use_checkpoint=False,
                 resi_connection='1conv',
                 use_crossattn=[[[False, False, False, False],
                                 [True, True, True, True]]]):
        # NOTE(review): list defaults are mutable and shared across calls;
        # they look read-only here — confirm before mutating anywhere.
        super(TransModel, self).__init__()
        self.embed_dim = embed_dim
        self.ape = ape  # add an absolute position embedding when True
        self.patch_norm = patch_norm
        self.num_features = embed_dim
        self.mlp_ratio = mlp_ratio

        # split the feature map into (patch_size x patch_size) tokens
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # merge non-overlapping patches into image
        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=embed_dim,
            embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr0 = [
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[0]))
        ]  # stochastic depth decay rule

        self.layers0 = nn.ModuleList()
        num_layers = len(depths[0])
        for i_layer in range(num_layers):
            # NOTE(review): RTFL receives img_size[0]/img_size[1], so
            # img_size must be an (H, W) pair despite the scalar default.
            layer = RTFL(
                dim=embed_dim,
                input_resolution=(patches_resolution[0],
                                  patches_resolution[1]),
                depth=depths[0][i_layer],
                num_heads=num_heads[0][i_layer],
                window_size=window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop_rate,
                attn_drop=attn_drop_rate,
                drop_path=dpr0[sum(depths[0][:i_layer]):sum(depths[0][:i_layer
                                                                     + 1])],
                norm_layer=norm_layer,
                downsample=None,
                use_checkpoint=use_checkpoint,
                img_size=(img_size[0], img_size[1]),
                patch_size=patch_size,
                resi_connection=resi_connection,
                use_crossattn=use_crossattn[0][i_layer])
            self.layers0.append(layer)

        self.norm = norm_layer(self.num_features)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal init for Linear, unit weight / zero bias for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        # Parameter names excluded from weight decay by the optimizer setup.
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward_features(self, x, layers):
        """Token pipeline: embed -> (pos embed / dropout) -> layers -> norm -> un-embed."""
        x_size = (x.shape[2], x.shape[3])
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        # Accept either a ModuleList of layers or a single layer.
        if isinstance(layers, nn.ModuleList):
            for layer in layers:
                x = layer(x, x_size)
        else:
            x = layers(x, x_size)

        x = self.norm(x)  # B L C
        x = self.patch_unembed(x, x_size)

        return x

    def forward(self, x):
        out = self.forward_features(x, self.layers0)
        return out
|
||||
|
||||
|
||||
class IFBlock(nn.Module):
    """Flow-refinement block with a transformer core.

    Works at a fixed pyramid scale and predicts residual updates for the
    bidirectional flows (2 channels each, 4 output channels total).
    """

    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock, self).__init__()
        self.scale = scale
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
            conv(c, c, 3, 1, 1),
        )

        self.trans = TransModel(
            img_size=(128 // scale, 128 // scale),
            patch_size=1,
            embed_dim=c,
            depths=[[3, 3]],
            num_heads=[[2, 2]])

        self.conv1 = nn.Sequential(
            conv(c, c, 3, 1, 1),
            conv(c, c, 3, 1, 1),
        )

        self.up = nn.ConvTranspose2d(c, 4, 4, 2, 1)

        self.conv2 = nn.Conv2d(4, 4, 3, 1, 1)

    def forward(self, x, flow0, flow1):
        if self.scale != 1:
            # Move inputs to this block's pyramid scale; flow magnitudes
            # must be rescaled by the same factor.
            inv = 1. / self.scale
            x = F.interpolate(
                x, scale_factor=inv, mode='bilinear', align_corners=False)
            flow0 = F.interpolate(
                flow0, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
            flow1 = F.interpolate(
                flow1, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv

        feat = self.conv0(torch.cat((x, flow0, flow1), 1))
        feat = self.trans(feat)
        feat = self.conv1(feat) + feat

        # conv0 downsampled by 4x in total: a learned 2x upsampling
        # followed by a 2x bilinear upsampling restores the resolution.
        feat = self.conv2(self.up(feat))
        flow = F.interpolate(
            feat, scale_factor=2.0, mode='bilinear',
            align_corners=False) * 2.0

        if self.scale != 1:
            flow = F.interpolate(
                flow,
                scale_factor=self.scale,
                mode='bilinear',
                align_corners=False) * self.scale

        # First two channels: residual for flow0; last two: for flow1.
        return flow[:, :2, :, :], flow[:, 2:, :, :]
|
||||
|
||||
|
||||
class IFBlock_wo_Swin(nn.Module):
    """Flow-refinement block built from plain residual conv stacks (no transformer)."""

    def __init__(self, in_planes, scale=1, c=64):
        super(IFBlock_wo_Swin, self).__init__()
        self.scale = scale
        self.conv0 = nn.Sequential(
            conv(in_planes, c // 2, 3, 2, 1),
            conv(c // 2, c, 3, 2, 1),
        )

        self.convblock1 = nn.Sequential(conv(c, c), conv(c, c), conv(c, c))
        self.convblock2 = nn.Sequential(conv(c, c), conv(c, c), conv(c, c))

        self.up = nn.ConvTranspose2d(c, 4, 4, 2, 1)

        self.conv2 = nn.Conv2d(4, 4, 3, 1, 1)

    def forward(self, x, flow0, flow1):
        if self.scale != 1:
            # Move inputs to this block's pyramid scale; flow magnitudes
            # must be rescaled by the same factor.
            inv = 1. / self.scale
            x = F.interpolate(
                x, scale_factor=inv, mode='bilinear', align_corners=False)
            flow0 = F.interpolate(
                flow0, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv
            flow1 = F.interpolate(
                flow1, scale_factor=inv, mode='bilinear',
                align_corners=False) * inv

        feat = self.conv0(torch.cat((x, flow0, flow1), 1))
        feat = self.convblock1(feat) + feat
        feat = self.convblock2(feat) + feat

        # conv0 downsampled by 4x in total: learned 2x upsampling followed
        # by a 2x bilinear upsampling restores the resolution.
        feat = self.conv2(self.up(feat))
        flow = F.interpolate(
            feat, scale_factor=2.0, mode='bilinear',
            align_corners=False) * 2.0

        if self.scale != 1:
            flow = F.interpolate(
                flow,
                scale_factor=self.scale,
                mode='bilinear',
                align_corners=False) * self.scale

        # First two channels: residual for flow0; last two: for flow1.
        return flow[:, :2, :, :], flow[:, 2:, :, :]
|
||||
|
||||
|
||||
class IFNet(nn.Module):
    """Three-stage coarse-to-fine refinement of a pair of bidirectional flows.

    block1 runs at 1/4 of the working resolution without a transformer;
    block2 (1/2) and block3 (full) use transformer-based IFBlocks. Each
    stage warps the frames with the current flows and predicts a residual.
    """

    def __init__(self):
        super(IFNet, self).__init__()
        self.block1 = IFBlock_wo_Swin(16, scale=4, c=128)
        self.block2 = IFBlock(16, scale=2, c=64)
        self.block3 = IFBlock(16, scale=1, c=32)

    # flow0: flow from img0 to img1
    # flow1: flow from img1 to img0
    def forward(self, img0, img1, flow0, flow1, sc_mode=2):
        """Refine (flow0, flow1) given both frames.

        sc_mode selects the internal working resolution: 0 -> 1/4,
        1 -> 1/2, anything else -> full resolution. Returns the refined
        flow pair at the input resolution.
        """
        if sc_mode == 0:
            sc = 0.25
        elif sc_mode == 1:
            sc = 0.5
        else:
            sc = 1

        # Downscale frames and flows to the working resolution; flow
        # magnitudes scale with the spatial factor.
        if sc != 1:
            img0_sc = F.interpolate(
                img0, scale_factor=sc, mode='bilinear', align_corners=False)
            img1_sc = F.interpolate(
                img1, scale_factor=sc, mode='bilinear', align_corners=False)
            flow0_sc = F.interpolate(
                flow0, scale_factor=sc, mode='bilinear',
                align_corners=False) * sc
            flow1_sc = F.interpolate(
                flow1, scale_factor=sc, mode='bilinear',
                align_corners=False) * sc
        else:
            img0_sc = img0
            img1_sc = img1
            flow0_sc = flow0
            flow1_sc = flow1

        # Stage 1: warp each frame toward the other, predict a residual.
        warped_img0 = warp(img1_sc, flow0_sc)  # -> img0
        warped_img1 = warp(img0_sc, flow1_sc)  # -> img1
        flow0_1, flow1_1 = self.block1(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), 1),
            flow0_sc, flow1_sc)
        F0_2 = (flow0_sc + flow0_1)
        F1_2 = (flow1_sc + flow1_1)

        # Stage 2: repeat with the updated flows.
        warped_img0 = warp(img1_sc, F0_2)  # -> img0
        warped_img1 = warp(img0_sc, F1_2)  # -> img1
        flow0_2, flow1_2 = self.block2(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), 1), F0_2,
            F1_2)
        F0_3 = (F0_2 + flow0_2)
        F1_3 = (F1_2 + flow1_2)

        # Stage 3: final refinement at the working resolution.
        warped_img0 = warp(img1_sc, F0_3)  # -> img0
        warped_img1 = warp(img0_sc, F1_3)  # -> img1
        flow0_3, flow1_3 = self.block3(
            torch.cat((img0_sc, img1_sc, warped_img0, warped_img1), dim=1),
            F0_3, F1_3)
        # Sum of all three per-stage residuals.
        flow_res_0 = flow0_1 + flow0_2 + flow0_3
        flow_res_1 = flow1_1 + flow1_2 + flow1_3

        # Bring the accumulated residual back to the input resolution.
        if sc != 1:
            flow_res_0 = F.interpolate(
                flow_res_0,
                scale_factor=1 / sc,
                mode='bilinear',
                align_corners=False) / sc
            flow_res_1 = F.interpolate(
                flow_res_1,
                scale_factor=1 / sc,
                mode='bilinear',
                align_corners=False) / sc

        F0_4 = flow0 + flow_res_0
        F1_4 = flow1 + flow_res_1

        return F0_4, F1_4
|
||||
@@ -0,0 +1,127 @@
|
||||
# Part of the implementation is borrowed and modified from QVI, publicly available at https://github.com/xuxy09/QVI
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class down(nn.Module):
    """U-Net encoder stage: 2x average-pool then two conv + leaky-ReLU layers."""

    def __init__(self, inChannels, outChannels, filterSize):
        super(down, self).__init__()
        pad = int((filterSize - 1) / 2)  # "same" padding for odd kernels
        self.conv1 = nn.Conv2d(
            inChannels, outChannels, filterSize, stride=1, padding=pad)
        self.conv2 = nn.Conv2d(
            outChannels, outChannels, filterSize, stride=1, padding=pad)

    def forward(self, x):
        pooled = F.avg_pool2d(x, 2)
        feat = F.leaky_relu(self.conv1(pooled), negative_slope=0.1)
        return F.leaky_relu(self.conv2(feat), negative_slope=0.1)
|
||||
|
||||
|
||||
class up(nn.Module):
    """U-Net decoder stage: upsample to the skip's size, conv, concat skip, conv."""

    def __init__(self, inChannels, outChannels):
        super(up, self).__init__()
        self.conv1 = nn.Conv2d(inChannels, outChannels, 3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(
            2 * outChannels, outChannels, 3, stride=1, padding=1)

    def forward(self, x, skpCn):
        target_size = [skpCn.size(2), skpCn.size(3)]
        x = F.interpolate(
            x, size=target_size, mode='bilinear', align_corners=False)
        x = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        merged = torch.cat((x, skpCn), 1)
        return F.leaky_relu(self.conv2(merged), negative_slope=0.1)
|
||||
|
||||
|
||||
class Small_UNet(nn.Module):
    """Small 3-level U-Net returning a flow map and the last decoder features."""

    def __init__(self, inChannels, outChannels):
        super(Small_UNet, self).__init__()
        self.conv1 = nn.Conv2d(inChannels, 32, 7, stride=1, padding=3)
        self.conv2 = nn.Conv2d(32, 32, 7, stride=1, padding=3)
        self.down1 = down(32, 64, 5)
        self.down2 = down(64, 128, 3)
        self.down3 = down(128, 128, 3)
        self.up1 = up(128, 128)
        self.up2 = up(128, 64)
        self.up3 = up(64, 32)
        self.conv3 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)

    def forward(self, x):
        head = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        skip1 = F.leaky_relu(self.conv2(head), negative_slope=0.1)
        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        bottleneck = self.down3(skip3)
        decoded = self.up1(bottleneck, skip3)
        decoded = self.up2(decoded, skip2)
        features = self.up3(decoded, skip1)  # feature
        flow = self.conv3(features)  # flow
        return flow, features
|
||||
|
||||
|
||||
class Small_UNet_Ds(nn.Module):
    """Small_UNet variant whose trunk runs at half resolution; the decoder
    features are upsampled back to the input size before the output heads."""

    def __init__(self, inChannels, outChannels):
        super(Small_UNet_Ds, self).__init__()
        self.conv1_1 = nn.Conv2d(inChannels, 32, 5, stride=1, padding=2)
        self.conv1_2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv2_1 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv2_2 = nn.Conv2d(32, 32, 3, stride=1, padding=1)

        self.down1 = down(32, 64, 5)
        self.down2 = down(64, 128, 3)
        self.down3 = down(128, 128, 3)
        self.up1 = up(128, 128)
        self.up2 = up(128, 64)
        self.up3 = up(64, 32)
        self.conv3 = nn.Conv2d(32, 32, 3, stride=1, padding=1)
        self.conv4 = nn.Conv2d(32, outChannels, 3, stride=1, padding=1)

    def forward(self, x):
        slope = 0.1
        head = F.leaky_relu(self.conv1_1(x), negative_slope=slope)
        head = F.leaky_relu(self.conv1_2(head), negative_slope=slope)

        # Run the trunk at half resolution for speed.
        half = F.interpolate(
            head,
            size=[head.size(2) // 2, head.size(3) // 2],
            mode='bilinear',
            align_corners=False)

        half = F.leaky_relu(self.conv2_1(half), negative_slope=slope)
        skip1 = F.leaky_relu(self.conv2_2(half), negative_slope=slope)

        skip2 = self.down1(skip1)
        skip3 = self.down2(skip2)
        bottleneck = self.down3(skip3)

        decoded = self.up1(bottleneck, skip3)
        decoded = self.up2(decoded, skip2)
        features = self.up3(decoded, skip1)

        # Back to the input resolution before the output heads.
        features = F.interpolate(
            features,
            size=[head.size(2), head.size(3)],
            mode='bilinear',
            align_corners=False)

        features = F.leaky_relu(self.conv3(features), negative_slope=slope)  # feature
        flow = self.conv4(features)  # flow
        return flow, features
|
||||
@@ -0,0 +1,115 @@
|
||||
# The implementation is adopted from QVI,
|
||||
# made publicly available at https://github.com/xuxy09/QVI
|
||||
|
||||
# class WarpLayer warps image x based on optical flow flo.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class FlowReversal(nn.Module):
    """Forward-warp ("splat") an image along an optical flow field.

    Each source pixel is pushed to its flow target and distributed over
    the four surrounding integer positions with Gaussian weights. The
    second return value accumulates the splat weights so callers can
    normalise the warped image.
    """

    def __init__(self, ):
        super(FlowReversal, self).__init__()

    def forward(self, img, flo):
        """
        -img: image (N, C, H, W)
        -flo: optical flow (N, 2, H, W)
        elements of flo is in [0, H] and [0, W] for dx, dy

        Returns (warped image, accumulated weights), both (N, C, H, W).
        """
        N, C, _, _ = img.size()

        # translate start-point optical flow to end-point optical flow
        # (was flo[:, 0:1:, :], which only worked via implicit trailing dims)
        y = flo[:, 0:1, :, :]
        x = flo[:, 1:2, :, :]

        x = x.repeat(1, C, 1, 1)
        y = y.repeat(1, C, 1, 1)

        # The four integer corners around each real-valued target position.
        x1 = torch.floor(x)
        x2 = x1 + 1
        y1 = torch.floor(y)
        y2 = y1 + 1

        # firstly, get gaussian weights
        w11, w12, w21, w22 = self.get_gaussian_weights(x, y, x1, x2, y1, y2)

        # secondly, sample each weighted corner
        img11, o11 = self.sample_one(img, x1, y1, w11)
        img12, o12 = self.sample_one(img, x1, y2, w12)
        img21, o21 = self.sample_one(img, x2, y1, w21)
        img22, o22 = self.sample_one(img, x2, y2, w22)

        imgw = img11 + img12 + img21 + img22
        o = o11 + o12 + o21 + o22

        return imgw, o

    def get_gaussian_weights(self, x, y, x1, x2, y1, y2):
        """Unnormalised Gaussian weight of (x, y) w.r.t. each of the 4 corners."""
        w11 = torch.exp(-((x - x1)**2 + (y - y1)**2))
        w12 = torch.exp(-((x - x1)**2 + (y - y2)**2))
        w21 = torch.exp(-((x - x2)**2 + (y - y1)**2))
        w22 = torch.exp(-((x - x2)**2 + (y - y2)**2))

        return w11, w12, w21, w22

    def sample_one(self, img, shiftx, shifty, weight):
        """
        Input:
            -img (N, C, H, W)
            -shiftx, shifty (N, C, H, W)

        Scatter-adds `img * weight` at the shifted integer positions,
        dropping targets that fall outside the image.
        """
        N, C, H, W = img.size()
        # Work on img's device; the original hard-coded .cuda(), which
        # crashed on CPU-only machines.
        device = img.device

        # flatten all (all restored as Tensors)
        flat_shiftx = shiftx.view(-1)
        flat_shifty = shifty.view(-1)
        flat_basex = torch.arange(
            0, H, device=device).long().view(1, 1, H,
                                             1).repeat(N, C, 1, W).view(-1)
        flat_basey = torch.arange(
            0, W, device=device).long().view(1, 1, 1,
                                             W).repeat(N, C, H, 1).view(-1)
        flat_weight = weight.view(-1)
        flat_img = img.reshape(-1)

        # Linear index of every (n, c, x, y) target element.
        idxn = torch.arange(
            0, N, device=device).long().view(N, 1, 1, 1).repeat(1, C, H,
                                                                W).view(-1)
        idxc = torch.arange(
            0, C, device=device).long().view(1, C, 1, 1).repeat(N, 1, H,
                                                                W).view(-1)
        idxx = flat_shiftx.long() + flat_basex
        idxy = flat_shifty.long() + flat_basey

        # Keep only targets that land inside the image.
        mask = idxx.ge(0) & idxx.lt(H) & idxy.ge(0) & idxy.lt(W)

        ids = (idxn * C * H * W + idxc * H * W + idxx * W + idxy)
        ids_mask = torch.masked_select(ids, mask)

        # Scatter-add the weighted pixels and the weights themselves.
        img_warp = torch.zeros([N * C * H * W], device=device)
        img_warp.put_(
            ids_mask,
            torch.masked_select(flat_img * flat_weight, mask),
            accumulate=True)

        one_warp = torch.zeros([N * C * H * W], device=device)
        one_warp.put_(
            ids_mask, torch.masked_select(flat_weight, mask), accumulate=True)

        return img_warp.view(N, C, H, W), one_warp.view(N, C, H, W)
|
||||
@@ -0,0 +1,488 @@
|
||||
# Part of the implementation is borrowed and modified from QVI, publicly available at https://github.com/xuxy09/QVI
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.video_frame_interpolation.interp_model.flow_reversal import \
|
||||
FlowReversal
|
||||
from modelscope.models.cv.video_frame_interpolation.interp_model.IFNet_swin import \
|
||||
IFNet
|
||||
from modelscope.models.cv.video_frame_interpolation.interp_model.UNet import \
|
||||
Small_UNet_Ds
|
||||
|
||||
|
||||
class AcFusionLayer(nn.Module):
    """Quadratic flow fusion: combine forward/backward flow pairs into
    flows toward the intermediate time t."""

    def __init__(self, ):
        super(AcFusionLayer, self).__init__()

    def forward(self, flo10, flo12, flo21, flo23, t=0.5):
        # Quadratic motion model evaluated at t (seen from frame 1) and at
        # 1 - t (seen from frame 2).
        f1t = 0.5 * ((t + t**2) * flo12 - (t - t**2) * flo10)
        s = 1 - t
        f2t = 0.5 * ((s + s**2) * flo21 - (s - s**2) * flo23)
        return f1t, f2t
        # return 0.375 * flo12 - 0.125 * flo10, 0.375 * flo21 - 0.125 * flo23
|
||||
|
||||
|
||||
class Get_gradient(nn.Module):
    """Per-channel gradient magnitude of a 3-channel image using simple
    central-difference kernels (with a small epsilon inside the sqrt)."""

    def __init__(self):
        super(Get_gradient, self).__init__()
        kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
        kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
        kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
        kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
        self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False)
        self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False)

    def _magnitude(self, channel):
        # sqrt(gv^2 + gh^2 + eps) for one single-channel map.
        plane = channel.unsqueeze(1)
        grad_v = F.conv2d(plane, self.weight_v, padding=1)
        grad_h = F.conv2d(plane, self.weight_h, padding=1)
        return torch.sqrt(
            torch.pow(grad_v, 2) + torch.pow(grad_h, 2) + 1e-6)

    def forward(self, x):
        # Process R, G, B independently and re-stack.
        magnitudes = [self._magnitude(x[:, c]) for c in range(3)]
        return torch.cat(magnitudes, dim=1)
|
||||
|
||||
|
||||
class LowPassFilter(nn.Module):
    """7x7 box blur applied independently to the two channels of a flow map."""

    def __init__(self):
        super(LowPassFilter, self).__init__()
        # Normalised 7x7 averaging kernel.
        kernel_lpf = torch.ones(1, 1, 7, 7) / 49
        self.weight_lpf = nn.Parameter(data=kernel_lpf, requires_grad=False)

    def forward(self, x):
        blurred = [
            F.conv2d(x[:, c].unsqueeze(1), self.weight_lpf, padding=3)
            for c in range(2)
        ]
        return torch.cat(blurred, dim=1)
|
||||
|
||||
|
||||
def backwarp(img, flow):
    """Backward-warp `img` by `flow` (pixel offsets) with bilinear sampling.

    Args:
        img: (N, C, H, W) source image.
        flow: (N, 2, H, W) per-pixel (u, v) displacements in pixels.

    Returns:
        The warped image, same shape as `img`.
    """
    _, _, H, W = img.size()

    u = flow[:, 0, :, :]
    v = flow[:, 1, :, :]

    # Base sampling grid on img's device (the original hard-coded .cuda(),
    # which crashed on CPU-only runs).
    gridX, gridY = np.meshgrid(np.arange(W), np.arange(H))

    gridX = torch.tensor(gridX, device=img.device)
    gridY = torch.tensor(gridY, device=img.device)
    x = gridX.unsqueeze(0).expand_as(u).float() + u
    y = gridY.unsqueeze(0).expand_as(v).float() + v

    # Normalise absolute coordinates into grid_sample's [-1, 1] range.
    x = 2 * (x / (W - 1) - 0.5)
    y = 2 * (y / (H - 1) - 0.5)

    grid = torch.stack((x, y), dim=3)

    imgOut = torch.nn.functional.grid_sample(
        img, grid, padding_mode='border', align_corners=True)

    return imgOut
|
||||
|
||||
|
||||
class SmallMaskNet(nn.Module):
    """A three-layer network for predicting mask"""

    def __init__(self, input, output):
        super(SmallMaskNet, self).__init__()
        # 5x5 head followed by two 3x3 convs, all "same" padding so the
        # spatial size is preserved end to end.
        self.conv1 = nn.Conv2d(input, 32, 5, padding=2)
        self.conv2 = nn.Conv2d(32, 16, 3, padding=1)
        self.conv3 = nn.Conv2d(16, output, 3, padding=1)

    def forward(self, x):
        """Return raw (un-activated) mask logits for input x (N, C, H, W)."""
        hidden = F.leaky_relu(self.conv1(x), negative_slope=0.1)
        hidden = F.leaky_relu(self.conv2(hidden), negative_slope=0.1)
        return self.conv3(hidden)
|
||||
|
||||
|
||||
class StaticMaskNet(nn.Module):
    """Five-conv head predicting a sigmoid "static region" mask in [0, 1]."""

    def __init__(self, input, output):
        super(StaticMaskNet, self).__init__()

        def conv3x3(cin, cout):
            # 3x3, stride-1, padding-1 convolution used throughout the head.
            return nn.Conv2d(
                in_channels=cin,
                out_channels=cout,
                kernel_size=3,
                stride=1,
                padding=1)

        # Same layer order as a hand-appended list:
        # conv/LReLU x4, final conv, sigmoid.
        layers = [
            conv3x3(input, 32),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            conv3x3(32, 32),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            conv3x3(32, 16),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            conv3x3(16, 16),
            nn.LeakyReLU(inplace=False, negative_slope=0.1),
            conv3x3(16, output),
            nn.Sigmoid(),
        ]
        self.body = nn.Sequential(*layers)

    def forward(self, x):
        """Map features x (N, input, H, W) to a mask (N, output, H, W)."""
        return self.body(x)
|
||||
|
||||
|
||||
def tensor_erode(bin_img, ksize=5):
    """Morphological max-filter over a ksize x ksize window, same-size output.

    NOTE(review): despite the name, taking the window *maximum* is
    morphological dilation (erosion would use the minimum). Callers bind the
    result to ``M_static_dilate``, so dilation looks intentional — confirm
    before "fixing" the name or the operator.
    """
    B, C, H, W = bin_img.shape
    margin = (ksize - 1) // 2
    # Zero-pad so every output location has a full neighbourhood.
    padded = F.pad(bin_img, [margin, margin, margin, margin],
                   mode='constant', value=0)

    # Gather each ksize x ksize neighbourhood into a trailing dimension.
    windows = padded.unfold(dimension=2, size=ksize, step=1)
    windows = windows.unfold(dimension=3, size=ksize, step=1)

    out, _ = windows.reshape(B, C, H, W, -1).max(dim=-1)
    return out
|
||||
|
||||
|
||||
class QVI_inter_Ds(nn.Module):
    """Given flow, implement Quadratic Video Interpolation.

    Variant that synthesizes the target frame at full resolution while
    running the refine/mask networks on 2x-downscaled (``_Ds``) inputs.
    Relies on AcFusionLayer / FlowReversal / Small_UNet_Ds / Get_gradient
    defined elsewhere in this file.
    """

    def __init__(self, debug_en=False, is_training=False):
        super(QVI_inter_Ds, self).__init__()
        self.acc = AcFusionLayer()  # quadratic flow fusion from 4 flows
        self.fwarp = FlowReversal()  # forward-warp / flow reversal layer
        self.refinenet = Small_UNet_Ds(20, 8)  # flow residual + filter offsets
        self.masknet = SmallMaskNet(38, 1)  # occlusion/fusion mask logits

        self.staticnet = StaticMaskNet(56, 1)  # static-region mask in [0, 1]
        self.lpfilter = LowPassFilter()  # 7x7 box blur used for hole filling

        self.get_grad = Get_gradient()  # Sobel-style gradient features
        self.debug_en = debug_en  # return intermediates when True (eval only)
        self.is_training = is_training  # also return refined flows when True

    def fill_flow_hole(self, ft, norm, ft_fill):
        """Fill ``ft`` where the reversal weight ``norm`` is 0 (holes).

        Holes are first filled with the fallback ``ft_fill``, then replaced
        by a blurred-and-downscaled/upscaled version of the result so the
        filled values are spatially smooth. NOTE: mutates ``ft`` in place and
        also returns it.
        """
        (N, C, H, W) = ft.shape
        ft[norm == 0] = ft_fill[norm == 0]

        # Smooth, then round-trip through 1/4 resolution to diffuse
        # neighbouring valid flow into the hole regions.
        ft_1 = self.lpfilter(ft.clone())
        ft_ds = torch.nn.functional.interpolate(
            input=ft_1,
            size=(H // 4, W // 4),
            mode='bilinear',
            align_corners=False)
        ft_up = torch.nn.functional.interpolate(
            input=ft_ds, size=(H, W), mode='bilinear', align_corners=False)

        ft[norm == 0] = ft_up[norm == 0]

        return ft

    def forward(self, F10_Ds, F12_Ds, F21_Ds, F23_Ds, I1_Ds, I2_Ds, I1, I2, t):
        """Interpolate the frame at time ``t`` in (0, 1) between I1 and I2.

        Args:
            F10_Ds/F12_Ds/F21_Ds/F23_Ds: downscaled flows between the four
                input frames (F10/F23 may be None -> linear flow fallback).
            I1_Ds/I2_Ds: downscaled key frames; I1/I2: full-resolution frames.
            t: interpolation position.
        """
        # Without both inner flows there is nothing to interpolate with.
        if F12_Ds is None or F21_Ds is None:
            return I1

        if F10_Ds is not None and F23_Ds is not None:
            # Quadratic (acceleration-aware) flow estimation from 4 flows.
            F1t_Ds, F2t_Ds = self.acc(F10_Ds, F12_Ds, F21_Ds, F23_Ds, t)

        else:
            # Linear fallback when the outer frames are unavailable.
            F1t_Ds = t * F12_Ds
            F2t_Ds = (1 - t) * F21_Ds

        # Flow Reversal
        # Work at 1/3 resolution (flows scaled accordingly) for speed.
        F1t_Ds2 = F.interpolate(
            F1t_Ds, scale_factor=1.0 / 3, mode='nearest') / 3
        F2t_Ds2 = F.interpolate(
            F2t_Ds, scale_factor=1.0 / 3, mode='nearest') / 3
        Ft1_Ds2, norm1_Ds2 = self.fwarp(F1t_Ds2, F1t_Ds2)
        Ft1_Ds2 = -Ft1_Ds2
        Ft2_Ds2, norm2_Ds2 = self.fwarp(F2t_Ds2, F2t_Ds2)
        Ft2_Ds2 = -Ft2_Ds2

        # Normalize by the splatting weights where any flow landed.
        Ft1_Ds2[norm1_Ds2 > 0] \
            = Ft1_Ds2[norm1_Ds2 > 0] / norm1_Ds2[norm1_Ds2 > 0].clone()
        Ft2_Ds2[norm2_Ds2 > 0] \
            = Ft2_Ds2[norm2_Ds2 > 0] / norm2_Ds2[norm2_Ds2 > 0].clone()
        if 1:  # hole filling for pixels no flow vector reached
            Ft1_Ds2_fill = -F1t_Ds2
            Ft2_Ds2_fill = -F2t_Ds2
            Ft1_Ds2 = self.fill_flow_hole(Ft1_Ds2, norm1_Ds2, Ft1_Ds2_fill)
            Ft2_Ds2 = self.fill_flow_hole(Ft2_Ds2, norm2_Ds2, Ft2_Ds2_fill)

        # Back to the _Ds resolution (flow magnitudes rescaled by 3).
        Ft1_Ds = F.interpolate(
            Ft1_Ds2, size=[F1t_Ds.size(2), F1t_Ds.size(3)], mode='nearest') * 3
        Ft2_Ds = F.interpolate(
            Ft2_Ds2, size=[F2t_Ds.size(2), F2t_Ds.size(3)], mode='nearest') * 3

        I1t_Ds = backwarp(I1_Ds, Ft1_Ds)
        I2t_Ds = backwarp(I2_Ds, Ft2_Ds)

        # U-Net predicts flow residuals (ch 0-3) + filter offsets (ch 4-7).
        output_Ds, feature_Ds = self.refinenet(
            torch.cat(
                [I1_Ds, I2_Ds, I1t_Ds, I2t_Ds, F12_Ds, F21_Ds, Ft1_Ds, Ft2_Ds],
                dim=1))

        # Adaptive filtering
        # Sample the flow at a learned (bounded, +-10 px) offset, then add
        # a learned residual.
        Ft1r_Ds = backwarp(
            Ft1_Ds, 10 * torch.tanh(output_Ds[:, 4:6])) + output_Ds[:, :2]
        Ft2r_Ds = backwarp(
            Ft2_Ds, 10 * torch.tanh(output_Ds[:, 6:8])) + output_Ds[:, 2:4]

        # warping and fusing
        I1tf_Ds = backwarp(I1_Ds, Ft1r_Ds)
        I2tf_Ds = backwarp(I2_Ds, Ft2r_Ds)

        # Image-gradient features for the static-region classifier.
        G1_Ds = self.get_grad(I1_Ds)
        G2_Ds = self.get_grad(I2_Ds)
        G1tf_Ds = backwarp(G1_Ds, Ft1r_Ds)
        G2tf_Ds = backwarp(G2_Ds, Ft2r_Ds)

        # Per-pixel fusion weight, broadcast to the 3 colour channels.
        M_Ds = torch.sigmoid(
            self.masknet(torch.cat([I1tf_Ds, I2tf_Ds, feature_Ds],
                                   dim=1))).repeat(1, 3, 1, 1)

        # Upscale refined flows to full resolution (magnitudes x2 as well).
        Ft1r = F.interpolate(
            Ft1r_Ds * 2, scale_factor=2, mode='bilinear', align_corners=False)
        Ft2r = F.interpolate(
            Ft2r_Ds * 2, scale_factor=2, mode='bilinear', align_corners=False)

        I1tf = backwarp(I1, Ft1r)
        I2tf = backwarp(I2, Ft2r)

        M = F.interpolate(
            M_Ds, scale_factor=2, mode='bilinear', align_corners=False)

        # fuse
        # Time-weighted, mask-weighted blend of the two warped frames.
        It_warp = ((1 - t) * M * I1tf + t * (1 - M) * I2tf) \
            / ((1 - t) * M + t * (1 - M)).clone()

        # static blending
        # In regions classified as static, fall back to a plain cross-fade.
        It_static = (1 - t) * I1 + t * I2
        tmp = torch.cat((I1tf_Ds, I2tf_Ds, G1tf_Ds, G2tf_Ds, I1_Ds, I2_Ds,
                         G1_Ds, G2_Ds, feature_Ds),
                        dim=1)
        M_static_Ds = self.staticnet(tmp)
        # tensor_erode applies a window max (dilation) - grows the mask.
        M_static_dilate = tensor_erode(M_static_Ds)
        M_static_dilate = tensor_erode(M_static_dilate)
        M_static = F.interpolate(
            M_static_dilate,
            scale_factor=2,
            mode='bilinear',
            align_corners=False)

        It_warp = (1 - M_static) * It_warp + M_static * It_static

        if self.is_training:
            return It_warp, Ft1r, Ft2r
        else:
            if self.debug_en:
                return It_warp, M, M_static, I1tf, I2tf, Ft1r, Ft2r
            else:
                return It_warp
|
||||
|
||||
|
||||
class QVI_inter(nn.Module):
    """Given flow, implement Quadratic Video Interpolation.

    Full-resolution counterpart of ``QVI_inter_Ds``: all networks run at the
    input resolution (flow reversal is still done at 1/3 scale internally).
    """

    def __init__(self, debug_en=False, is_training=False):
        super(QVI_inter, self).__init__()
        self.acc = AcFusionLayer()  # quadratic flow fusion from 4 flows
        self.fwarp = FlowReversal()  # forward-warp / flow reversal layer
        self.refinenet = Small_UNet_Ds(20, 8)  # flow residual + filter offsets
        self.masknet = SmallMaskNet(38, 1)  # occlusion/fusion mask logits

        self.staticnet = StaticMaskNet(56, 1)  # static-region mask in [0, 1]
        self.lpfilter = LowPassFilter()  # 7x7 box blur used for hole filling

        self.get_grad = Get_gradient()  # Sobel-style gradient features
        self.debug_en = debug_en  # return intermediates when True (eval only)
        self.is_training = is_training  # also return refined flows when True

    def fill_flow_hole(self, ft, norm, ft_fill):
        """Fill ``ft`` where the reversal weight ``norm`` is 0 (holes).

        Same procedure as QVI_inter_Ds.fill_flow_hole: fallback fill, then a
        blur + 1/4-scale round trip to smooth the filled regions. Mutates
        ``ft`` in place and also returns it.
        """
        (N, C, H, W) = ft.shape
        ft[norm == 0] = ft_fill[norm == 0]

        ft_1 = self.lpfilter(ft.clone())
        ft_ds = torch.nn.functional.interpolate(
            input=ft_1,
            size=(H // 4, W // 4),
            mode='bilinear',
            align_corners=False)
        ft_up = torch.nn.functional.interpolate(
            input=ft_ds, size=(H, W), mode='bilinear', align_corners=False)

        ft[norm == 0] = ft_up[norm == 0]

        return ft

    def forward(self, F10, F12, F21, F23, I1, I2, t):
        """Interpolate the frame at time ``t`` in (0, 1) between I1 and I2.

        F10/F23 may be None, in which case linear flows t*F12 / (1-t)*F21
        are used instead of the quadratic estimate.
        """
        if F12 is None or F21 is None:
            return I1

        if F10 is not None and F23 is not None:
            # Quadratic (acceleration-aware) intermediate flows.
            F1t, F2t = self.acc(F10, F12, F21, F23, t)

        else:
            # Linear fallback.
            F1t = t * F12
            F2t = (1 - t) * F21

        # Flow Reversal
        # Reverse the flows at 1/3 resolution for speed (values scaled too).
        F1t_Ds = F.interpolate(F1t, scale_factor=1.0 / 3, mode='nearest') / 3
        F2t_Ds = F.interpolate(F2t, scale_factor=1.0 / 3, mode='nearest') / 3
        Ft1_Ds, norm1_Ds = self.fwarp(F1t_Ds, F1t_Ds)
        Ft1_Ds = -Ft1_Ds
        Ft2_Ds, norm2_Ds = self.fwarp(F2t_Ds, F2t_Ds)
        Ft2_Ds = -Ft2_Ds

        # Normalize by the splatting weights where any flow landed.
        Ft1_Ds[norm1_Ds > 0] \
            = Ft1_Ds[norm1_Ds > 0] / norm1_Ds[norm1_Ds > 0].clone()
        Ft2_Ds[norm2_Ds > 0] \
            = Ft2_Ds[norm2_Ds > 0] / norm2_Ds[norm2_Ds > 0].clone()
        if 1:  # hole filling for pixels no flow vector reached
            Ft1_fill = -F1t_Ds
            Ft2_fill = -F2t_Ds
            Ft1_Ds = self.fill_flow_hole(Ft1_Ds, norm1_Ds, Ft1_fill)
            Ft2_Ds = self.fill_flow_hole(Ft2_Ds, norm2_Ds, Ft2_fill)

        # Back to full resolution (flow magnitudes rescaled by 3).
        Ft1 = F.interpolate(
            Ft1_Ds, size=[F1t.size(2), F1t.size(3)], mode='nearest') * 3
        Ft2 = F.interpolate(
            Ft2_Ds, size=[F2t.size(2), F2t.size(3)], mode='nearest') * 3

        I1t = backwarp(I1, Ft1)
        I2t = backwarp(I2, Ft2)

        # U-Net predicts flow residuals (ch 0-3) + filter offsets (ch 4-7).
        output, feature = self.refinenet(
            torch.cat([I1, I2, I1t, I2t, F12, F21, Ft1, Ft2], dim=1))

        # Adaptive filtering
        # Sample flow at a learned, +-10 px bounded offset, plus a residual.
        Ft1r = backwarp(Ft1, 10 * torch.tanh(output[:, 4:6])) + output[:, :2]
        Ft2r = backwarp(Ft2, 10 * torch.tanh(output[:, 6:8])) + output[:, 2:4]

        # warping and fusing
        I1tf = backwarp(I1, Ft1r)
        I2tf = backwarp(I2, Ft2r)

        # Per-pixel fusion weight, broadcast to the 3 colour channels.
        M = torch.sigmoid(
            self.masknet(torch.cat([I1tf, I2tf, feature],
                                   dim=1))).repeat(1, 3, 1, 1)

        # Time-weighted, mask-weighted blend of the two warped frames.
        It_warp = ((1 - t) * M * I1tf + t * (1 - M) * I2tf) \
            / ((1 - t) * M + t * (1 - M)).clone()

        # Image-gradient features for the static-region classifier.
        G1 = self.get_grad(I1)
        G2 = self.get_grad(I2)
        G1tf = backwarp(G1, Ft1r)
        G2tf = backwarp(G2, Ft2r)

        # static blending
        # In regions classified as static, fall back to a plain cross-fade.
        It_static = (1 - t) * I1 + t * I2
        M_static = self.staticnet(
            torch.cat([I1tf, I2tf, G1tf, G2tf, I1, I2, G1, G2, feature],
                      dim=1))
        # tensor_erode applies a window max (dilation) - grows the mask.
        M_static_dilate = tensor_erode(M_static)
        M_static_dilate = tensor_erode(M_static_dilate)
        It_warp = (1 - M_static_dilate) * It_warp + M_static_dilate * It_static

        if self.is_training:
            return It_warp, Ft1r, Ft2r
        else:
            if self.debug_en:
                # NOTE(review): returns the pre-dilation mask M_static here,
                # unlike the blend above which uses the dilated one - confirm
                # this is intended for debugging.
                return It_warp, M, M_static, I1tf, I2tf, Ft1r, Ft2r
            else:
                return It_warp
|
||||
|
||||
|
||||
class InterpNetDs(nn.Module):
    """Flow refinement (IFNet) followed by downscaled-path quadratic
    interpolation (QVI_inter_Ds)."""

    def __init__(self, debug_en=False, is_training=False):
        super(InterpNetDs, self).__init__()
        self.ifnet = IFNet()  # refines the coarse bidirectional flows
        self.internet = QVI_inter_Ds(
            debug_en=debug_en, is_training=is_training)

    def forward(self,
                img1,
                img2,
                F10_up,
                F12_up,
                F21_up,
                F23_up,
                UHD=2,
                timestep=0.5):
        # NOTE(review): QVI_inter_Ds.forward takes 9 positional arguments
        # (F10_Ds, F12_Ds, F21_Ds, F23_Ds, I1_Ds, I2_Ds, I1, I2, t) but only
        # 7 are passed here, so this call would raise TypeError as written -
        # confirm whether this class is actually used, or whether the
        # downscaled frames should also be supplied.
        F12, F21 = self.ifnet(img1, img2, F12_up, F21_up, UHD)
        It_warp = self.internet(F10_up, F12, F21, F23_up, img1, img2, timestep)

        return It_warp
|
||||
|
||||
|
||||
class InterpNet(nn.Module):
    """Full interpolation pipeline: IFNet flow refinement followed by
    quadratic video interpolation (QVI_inter)."""

    def __init__(self, debug_en=False, is_training=False):
        super(InterpNet, self).__init__()
        self.ifnet = IFNet()
        self.internet = QVI_inter(debug_en=debug_en, is_training=is_training)

    def forward(self,
                img1,
                img2,
                F10_up,
                F12_up,
                F21_up,
                F23_up,
                UHD=2,
                timestep=0.5):
        """Synthesize the frame at ``timestep`` between img1 and img2."""
        # Refine the inner bidirectional flows first, then interpolate.
        flow_12, flow_21 = self.ifnet(img1, img2, F12_up, F21_up, UHD)
        return self.internet(F10_up, flow_12, flow_21, F23_up, img1, img2,
                             timestep)
|
||||
@@ -0,0 +1,989 @@
|
||||
# The implementation is adopted from VFIformer,
|
||||
# made publicly available at https://github.com/dvlab-research/VFIformer
|
||||
|
||||
# -----------------------------------------------------------------------------------
|
||||
# modified from:
|
||||
# SwinIR: Image Restoration Using Swin Transformer, https://github.com/JingyunLiang/SwinIR
|
||||
# -----------------------------------------------------------------------------------
|
||||
|
||||
import functools
|
||||
import math
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint as checkpoint
|
||||
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
|
||||
|
||||
|
||||
class Mlp(nn.Module):
    """Transformer feed-forward block: fc -> act -> drop -> fc -> drop."""

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_layer=nn.GELU,
                 drop=0.):
        super().__init__()
        # Both sizes default to the input width when not given.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Apply the two-layer MLP token-wise; shape is preserved except for
        the last dimension (in_features -> out_features)."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
||||
|
||||
|
||||
def window_partition(x, window_size):
    """
    Split a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C); H and W must be divisible by ``window_size``
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    n_h, n_w = H // window_size, W // window_size
    tiles = x.view(B, n_h, window_size, n_w, window_size, C)
    # Bring the two window-grid axes together, then flatten them with B.
    tiles = tiles.permute(0, 1, 3, 2, 4, 5).contiguous()
    return tiles.view(-1, window_size, window_size, C)
|
||||
|
||||
|
||||
def window_reverse(windows, window_size, H, W):
    """
    Reassemble windows back into a full feature map (inverse of
    window_partition).

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    grid = windows.view(B, H // window_size, W // window_size, window_size,
                        window_size, -1)
    # Interleave window-grid and intra-window axes back into H and W.
    grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
    return grid.view(B, H, W, -1)
|
||||
|
||||
|
||||
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:,
                                                                None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        # Flatten the 2-D relative offset into a single table index.
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        # Project once to q, k, v: (3, B_, nH, N, C // nH).
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1).contiguous())

        # Add the learned relative-position bias per head.
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window shift mask over batch and heads.
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
|
||||
|
||||
|
||||
class WindowCrossAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    In addition to window self-attention over ``x``, this module attends from
    the same queries to a second (downscaled) feature ``y`` and merges both
    results.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop=0.,
                 proj_drop=0.):

        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        # define a parameter table of relative position bias
        # Separate bias tables for the self (x) and cross (y) branches.
        self.relative_position_bias_table_x = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        self.relative_position_bias_table_y = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :,
                                         None] - coords_flatten[:,
                                                                None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :,
                        0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer('relative_position_index',
                             relative_position_index)

        # qkv serves the self branch; kv projects the cross-source y.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        # Small MLP fusing the self- and cross-attention outputs.
        self.merge1 = nn.Linear(dim * 2, dim)
        self.merge2 = nn.Linear(dim, dim)
        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        self.proj = nn.Linear(dim, dim)

        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table_x, std=.02)
        trunc_normal_(self.relative_position_bias_table_y, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, y, mask_x=None, mask_y=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        # ---- self-attention branch over x ----
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[
            2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1).contiguous())

        relative_position_bias = self.relative_position_bias_table_x[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask_x is not None:
            nW = mask_x.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask_x.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()

        # ---- cross-attention branch: same (already scaled) q attends to y ----
        # NOTE(review): assumes y has the same window count and token count N
        # as x (it is produced from a downsampled-and-unfolded map upstream).
        B_, N, C = y.shape
        kv = self.kv(y).reshape(B_, N, 2, self.num_heads,
                                C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[
            1]  # make torchscript happy (cannot use tensor as tuple)

        attn = (q @ k.transpose(-2, -1).contiguous())

        relative_position_bias = self.relative_position_bias_table_y[
            self.relative_position_index.view(-1)].view(
                self.window_size[0] * self.window_size[1],
                self.window_size[0] * self.window_size[1],
                -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(
            2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask_y is not None:
            nW = mask_y.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N,
                             N) + mask_y.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        y = (attn @ v).transpose(1, 2).reshape(B_, N, C).contiguous()

        # Fuse the two branches with a residual connection onto x.
        x = self.merge2(self.act(self.merge1(torch.cat([x, y], dim=-1)))) + x

        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        # NOTE(review): counts only the self-attention branch; the cross
        # branch (kv/merge) is not included - confirm if exact counts matter.
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
|
||||
|
||||
|
||||
class TFL(nn.Module):
|
||||
|
||||
    def __init__(self,
                 dim,
                 input_resolution,
                 num_heads,
                 window_size=7,
                 shift_size=0,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 act_layer=nn.GELU,
                 norm_layer=nn.LayerNorm,
                 use_crossattn=False):
        """Transformer (Swin-style) layer: windowed (cross-)attention + MLP.

        When ``use_crossattn`` is True, WindowCrossAttention is used and two
        masks (for x and its downscaled counterpart) are precomputed.
        """
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.use_crossattn = use_crossattn

        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, 'shift_size must in 0-window_size'

        self.norm1 = norm_layer(dim)
        if not use_crossattn:
            self.attn = WindowAttention(
                dim,
                window_size=to_2tuple(self.window_size),
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop)
        else:
            self.attn = WindowCrossAttention(
                dim,
                window_size=to_2tuple(self.window_size),
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop)

        # Stochastic depth on the residual branch (identity when rate is 0).
        self.drop_path = DropPath(
            drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop)

        # Precompute shifted-window attention masks for the nominal
        # resolution; forward recomputes them when x_size differs.
        if self.shift_size > 0:
            if not use_crossattn:
                attn_mask = self.calculate_mask(self.input_resolution)
                self.register_buffer('attn_mask', attn_mask)
            else:
                attn_mask_x = self.calculate_mask(self.input_resolution)
                attn_mask_y = self.calculate_mask2(self.input_resolution)
                self.register_buffer('attn_mask_x', attn_mask_x)
                self.register_buffer('attn_mask_y', attn_mask_y)

        else:
            # No shift: masks are None but still registered so attribute
            # access is uniform in forward().
            if not use_crossattn:
                attn_mask = None
                self.register_buffer('attn_mask', attn_mask)
            else:
                attn_mask_x = None
                attn_mask_y = None
                self.register_buffer('attn_mask_x', attn_mask_x)
                self.register_buffer('attn_mask_y', attn_mask_y)
|
||||
|
||||
def calculate_mask(self, x_size):
|
||||
# calculate attention mask for SW-MSA
|
||||
H, W = x_size
|
||||
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
||||
h_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
w_slices = (slice(0, -self.window_size),
|
||||
slice(-self.window_size,
|
||||
-self.shift_size), slice(-self.shift_size, None))
|
||||
cnt = 0
|
||||
for h in h_slices:
|
||||
for w in w_slices:
|
||||
img_mask[:, h, w, :] = cnt
|
||||
cnt += 1
|
||||
|
||||
mask_windows = window_partition(
|
||||
img_mask, self.window_size) # nW, window_size, window_size, 1
|
||||
mask_windows = mask_windows.view(-1,
|
||||
self.window_size * self.window_size)
|
||||
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
||||
attn_mask = attn_mask.masked_fill(attn_mask != 0,
|
||||
float(-100.0)).masked_fill(
|
||||
attn_mask == 0, float(0.0))
|
||||
|
||||
return attn_mask
|
||||
|
||||
    def calculate_mask2(self, x_size):
        """Build the cross-attention mask between full-resolution windows and
        their 2x-downscaled, overlapping counterparts.

        Mirrors the downscale/pad/unfold pipeline used for features in
        forward(), so window i of the mask matches window i of the features.
        Returns (nW, window_size**2, window_size**2) with 0 / -100 entries.
        """
        # calculate attention mask for SW-MSA
        H, W = x_size
        img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
        # Label the 3x3 grid of shift regions with distinct ids.
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size,
                          -self.shift_size), slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(
            img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1,
                                         self.window_size * self.window_size)

        # downscale
        # Same transform as the feature path: bilinear 0.5x, reflect-pad by
        # window_size//4, then unfold into overlapping windows of the same
        # token count as the full-resolution windows.
        img_mask_down = F.interpolate(
            img_mask.permute(0, 3, 1, 2).contiguous(),
            scale_factor=0.5,
            mode='bilinear',
            align_corners=False)
        img_mask_down = F.pad(
            img_mask_down, (self.window_size // 4, self.window_size // 4,
                            self.window_size // 4, self.window_size // 4),
            mode='reflect')
        mask_windows_down = F.unfold(
            img_mask_down,
            kernel_size=self.window_size,
            dilation=1,
            padding=0,
            stride=self.window_size // 2)
        mask_windows_down = mask_windows_down.view(
            self.window_size * self.window_size,
            -1).permute(1, 0).contiguous()  # nW, window_size*window_size

        # Cross pairs with differing region ids are blocked (-100).
        attn_mask = mask_windows_down.unsqueeze(1) - mask_windows.unsqueeze(
            2)  # nW, window_size*window_size, window_size*window_size
        attn_mask = attn_mask.masked_fill(attn_mask != 0,
                                          float(-100.0)).masked_fill(
                                              attn_mask == 0, float(0.0))

        return attn_mask
|
||||
|
||||
def forward(self, x, x_size):
    """Apply one (shifted-)window attention block followed by an MLP.

    Args:
        x (torch.Tensor): Input tokens of shape (B, H*W, C).
        x_size (tuple[int]): Spatial size (H, W) of the token grid.

    Returns:
        torch.Tensor: Output tokens of shape (B, H*W, C).
    """
    H, W = x_size
    B, L, C = x.shape

    shortcut = x
    x = self.norm1(x)
    x = x.view(B, H, W, C)

    # cyclic shift
    if self.shift_size > 0:
        shifted_x = torch.roll(
            x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
    else:
        shifted_x = x

    # partition windows
    x_windows = window_partition(
        shifted_x, self.window_size)  # nW*B, window_size, window_size, C
    x_windows = x_windows.view(-1, self.window_size * self.window_size,
                               C)  # nW*B, window_size*window_size, C

    if not self.use_crossattn:
        # Plain window self-attention; the cached mask is only valid for
        # the training resolution, otherwise recompute it on the fly.
        if self.input_resolution == x_size:
            attn_windows = self.attn(
                x_windows,
                mask=self.attn_mask)  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(
                x_windows, mask=self.calculate_mask(x_size).to(x.device))
    else:
        # Cross-scale attention: keys/values come from a half-resolution
        # copy of the shifted feature map, windowed to align with the
        # full-resolution query windows (same pad/unfold scheme as
        # calculate_mask2).
        shifted_x_down = F.interpolate(
            shifted_x.permute(0, 3, 1, 2).contiguous(),
            scale_factor=0.5,
            mode='bilinear',
            align_corners=False)
        shifted_x_down = F.pad(
            shifted_x_down, (self.window_size // 4, self.window_size // 4,
                             self.window_size // 4, self.window_size // 4),
            mode='reflect')
        x_windows_down = F.unfold(
            shifted_x_down,
            kernel_size=self.window_size,
            dilation=1,
            padding=0,
            stride=self.window_size // 2)
        x_windows_down = x_windows_down.view(
            B, C, self.window_size * self.window_size, -1)
        x_windows_down = x_windows_down.permute(
            0, 3, 2,
            1).contiguous().view(-1, self.window_size * self.window_size,
                                 C)  # nW*B, window_size*window_size, C

        if self.input_resolution == x_size:
            attn_windows = self.attn(
                x_windows,
                x_windows_down,
                mask_x=self.attn_mask_x,
                mask_y=self.attn_mask_y
            )  # nW*B, window_size*window_size, C
        else:
            attn_windows = self.attn(
                x_windows,
                x_windows_down,
                mask_x=self.calculate_mask(x_size).to(x.device),
                mask_y=self.calculate_mask2(x_size).to(x.device))

    # merge windows
    attn_windows = attn_windows.view(-1, self.window_size,
                                     self.window_size, C)
    shifted_x = window_reverse(attn_windows, self.window_size, H,
                               W)  # B H' W' C

    # reverse cyclic shift
    if self.shift_size > 0:
        x = torch.roll(
            shifted_x,
            shifts=(self.shift_size, self.shift_size),
            dims=(1, 2))
    else:
        x = shifted_x

    x = x.view(B, H * W, C)
    # FFN (residual attention output, then residual MLP)
    x = shortcut + self.drop_path(x)
    x = x + self.drop_path(self.mlp(self.norm2(x)))

    return x
|
||||
|
||||
def extra_repr(self) -> str:
    # One-line configuration summary shown in the module's repr().
    return f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, ' \
           f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}'
|
||||
|
||||
def flops(self):
    """Estimate multiply-accumulate count of this block at the configured
    input resolution (norm + window attention + MLP)."""
    flops = 0
    H, W = self.input_resolution
    # norm1
    flops += self.dim * H * W
    # W-MSA/SW-MSA
    nW = H * W / self.window_size / self.window_size
    flops += nW * self.attn.flops(self.window_size * self.window_size)
    # mlp
    flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
    # norm2
    flops += self.dim * H * W
    return flops
|
||||
|
||||
|
||||
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Downsamples a token grid by 2x per spatial axis: the four pixels of each
    2x2 neighbourhood are concatenated channel-wise (4*C), normalised and
    linearly projected down to 2*C channels.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, 'input feature has wrong size'
        assert H % 2 == 0 and W % 2 == 0, f'x size ({H}*{W}) are not even.'

        grid = x.view(B, H, W, C)

        # Gather the four interleaved sub-grids of every 2x2 block
        # (top-left, bottom-left, top-right, bottom-right) and stack them
        # along the channel axis: B H/2 W/2 4*C.
        offsets = ((0, 0), (1, 0), (0, 1), (1, 1))
        quadrants = [grid[:, r::2, c::2, :] for r, c in offsets]
        merged = torch.cat(quadrants, -1)
        merged = merged.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        merged = self.norm(merged)
        return self.reduction(merged)

    def extra_repr(self) -> str:
        return f'input_resolution={self.input_resolution}, dim={self.dim}'

    def flops(self):
        H, W = self.input_resolution
        # LayerNorm over all input tokens + the 4C -> 2C linear projection.
        norm_flops = H * W * self.dim
        reduction_flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return norm_flops + reduction_flops
|
||||
|
||||
|
||||
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        use_crossattn (list[bool] | None, optional): Per-block flags enabling the
            cross-scale attention branch in each TFL block. None means all False.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 use_crossattn=None):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # Default: no cross-attention in any block.
        if use_crossattn is None:
            use_crossattn = [False for i in range(depth)]

        # build blocks
        # Even blocks use regular windows (shift 0); odd blocks use shifted
        # windows (SW-MSA) to mix information across window borders.
        self.blocks = nn.ModuleList([
            TFL(dim=dim,
                input_resolution=input_resolution,
                num_heads=num_heads,
                window_size=window_size,
                shift_size=0 if (i % 2 == 0) else window_size // 2,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                drop=drop,
                attn_drop=attn_drop,
                drop_path=drop_path[i]
                if isinstance(drop_path, list) else drop_path,
                norm_layer=norm_layer,
                use_crossattn=use_crossattn[i]) for i in range(depth)
        ])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x, x_size):
        """Run every block (optionally gradient-checkpointed), then the
        optional downsample layer."""
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x, x_size)
            else:
                x = blk(x, x_size)
        if self.downsample is not None:
            x = self.downsample(x)
        return x

    def extra_repr(self) -> str:
        # Configuration summary for repr().
        return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'

    def flops(self):
        """Total FLOPs: sum over blocks plus the downsample layer if any."""
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
|
||||
|
||||
|
||||
class RTFL(nn.Module):
    """Residual TFL group.

    Wraps a ``BasicLayer`` stack of TFL blocks with a trailing convolution
    (one conv or a bottlenecked three-conv variant, per ``resi_connection``)
    and a residual connection. PatchUnEmbed/PatchEmbed convert between the
    token layout (B, H*W, C) and the image layout (B, C, H, W) around the
    convolution.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of TFL blocks in the group.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool): If True, add a learnable bias to query, key, value.
        qk_scale (float | None): Override default qk scale if set.
        drop (float): Dropout rate.
        attn_drop (float): Attention dropout rate.
        drop_path (float | tuple[float]): Stochastic depth rate.
        norm_layer (nn.Module): Normalization layer.
        downsample (nn.Module | None): Downsample layer at the end.
        use_checkpoint (bool): Use gradient checkpointing to save memory.
        img_size (int): Input image size.
        patch_size (int): Patch size.
        resi_connection (str): '1conv' or '3conv' residual conv variant.
        use_crossattn (list[bool] | None): Per-block cross-attention flags.
    """

    def __init__(self,
                 dim,
                 input_resolution,
                 depth,
                 num_heads,
                 window_size,
                 mlp_ratio=4.,
                 qkv_bias=True,
                 qk_scale=None,
                 drop=0.,
                 attn_drop=0.,
                 drop_path=0.,
                 norm_layer=nn.LayerNorm,
                 downsample=None,
                 use_checkpoint=False,
                 img_size=224,
                 patch_size=4,
                 resi_connection='1conv',
                 use_crossattn=None):
        super(RTFL, self).__init__()

        self.dim = dim
        self.input_resolution = input_resolution
        self.use_crossattn = use_crossattn

        self.residual_group = BasicLayer(
            dim=dim,
            input_resolution=input_resolution,
            depth=depth,
            num_heads=num_heads,
            window_size=window_size,
            mlp_ratio=mlp_ratio,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            drop=drop,
            attn_drop=attn_drop,
            drop_path=drop_path,
            norm_layer=norm_layer,
            downsample=downsample,
            use_checkpoint=use_checkpoint,
            use_crossattn=use_crossattn)

        if resi_connection == '1conv':
            self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
        elif resi_connection == '3conv':
            # to save parameters and memory
            self.conv = nn.Sequential(
                nn.Conv2d(dim, dim // 4, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
                nn.LeakyReLU(negative_slope=0.2, inplace=True),
                nn.Conv2d(dim // 4, dim, 3, 1, 1))

        # in_chans=0: these embed layers only reshape, no projection.
        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None)

        self.patch_unembed = PatchUnEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=0,
            embed_dim=dim,
            norm_layer=None)

    def forward(self, x, x_size):
        """Tokens -> blocks -> image layout -> conv -> tokens, plus the
        residual input."""
        return self.patch_embed(
            self.conv(
                self.patch_unembed(self.residual_group(x, x_size),
                                   x_size))) + x

    def flops(self):
        """FLOPs of the block stack, the 3x3 conv and the (un)embed layers."""
        flops = 0
        flops += self.residual_group.flops()
        H, W = self.input_resolution
        # 3x3 convolution: 9 MACs per output element per channel pair.
        flops += H * W * self.dim * self.dim * 9
        flops += self.patch_embed.flops()
        flops += self.patch_unembed.flops()

        return flops
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Flattens a (B, C, H, W) feature map into a (B, H*W, C) token sequence
    and optionally applies a normalization layer.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]

        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.norm = norm_layer(embed_dim) if norm_layer is not None else None

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C)
        tokens = x.flatten(2).transpose(1, 2).contiguous()  # B Ph*Pw C
        if self.norm is not None:
            tokens = self.norm(tokens)
        return tokens

    def flops(self):
        H, W = self.img_size
        # Only the optional norm contributes; the reshape itself is free.
        if self.norm is not None:
            return H * W * self.embed_dim
        return 0
|
||||
|
||||
|
||||
class PatchUnEmbed(nn.Module):
    r""" Image to Patch Unembedding

    Reshapes a (B, H*W, C) token sequence back into a (B, C, H, W) map.

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self,
                 img_size=224,
                 patch_size=4,
                 in_chans=3,
                 embed_dim=96,
                 norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        grid = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]

        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim

    def forward(self, x, x_size):
        # (B, H*W, C) -> (B, C, H, W); x_size is the target (H, W).
        batch = x.shape[0]
        feat = x.transpose(1, 2).contiguous()
        return feat.view(batch, self.embed_dim, x_size[0], x_size[1])

    def flops(self):
        # Pure reshape: no multiply-accumulates.
        return 0
|
||||
|
||||
|
||||
class Upsample(nn.Sequential):
    """Upsample module.

    Builds a conv + PixelShuffle stack: n chained x2 stages for power-of-two
    scales, or a single x3 stage for scale 3.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.
    """

    def __init__(self, scale, num_feat):
        layers = []
        if (scale & (scale - 1)) == 0:  # scale = 2^n
            for _ in range(int(math.log(scale, 2))):
                layers.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
                layers.append(nn.PixelShuffle(2))
        elif scale == 3:
            layers.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
            layers.append(nn.PixelShuffle(3))
        else:
            raise ValueError(f'scale {scale} is not supported. '
                             'Supported scales: 2^n and 3.')
        super(Upsample, self).__init__(*layers)
|
||||
|
||||
|
||||
class UpsampleOneStep(nn.Sequential):
    """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
    Used in lightweight SR to save parameters.

    Args:
        scale (int): Scale factor. Supported scales: 2^n and 3.
        num_feat (int): Channel number of intermediate features.

    """

    def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
        self.num_feat = num_feat
        self.input_resolution = input_resolution
        conv = nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1)
        shuffle = nn.PixelShuffle(scale)
        super(UpsampleOneStep, self).__init__(conv, shuffle)

    def flops(self):
        # Single 3x3 conv producing scale^2 * num_out_ch channels; the
        # original accounting uses 3 output channels and 9 taps per MAC.
        H, W = self.input_resolution
        return H * W * self.num_feat * 3 * 9
|
||||
@@ -0,0 +1,97 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def calc_hist(img_tensor):
    """Return the normalized 64-bin intensity histogram of *img_tensor*.

    Bins cover [0, 255]; the result sums to 1 (fraction of pixels per bin).
    """
    hist = torch.histc(img_tensor, bins=64, min=0, max=255)
    return hist / img_tensor.numel()


def do_scene_detect(F01_tensor, F10_tensor, img0_tensor, img1_tensor):
    """Detect a scene change (cut) between two consecutive frames.

    Two-stage test:
      1. Global luminance-histogram comparison: a large histogram
         difference means an obvious cut.
      2. Forward/backward optical-flow consistency combined with local
         pixel difference on a coarse grid: if most grid points have both
         inconsistent motion and mismatching pixels, report a cut.

    Args:
        F01_tensor (torch.Tensor): Flow frame0 -> frame1, shape (N, 2, H, W).
        F10_tensor (torch.Tensor): Flow frame1 -> frame0, shape (N, 2, H, W).
        img0_tensor (torch.Tensor): First frame, (N, 3, H, W), values in [0, 255].
        img1_tensor (torch.Tensor): Second frame, same layout.

    Returns:
        bool: True if a scene change is detected.
    """
    device = img0_tensor.device
    scene_change = False
    img0_tensor = img0_tensor.clone()
    img1_tensor = img1_tensor.clone()

    # BT.601 luma weights; quantize to uint8 levels for a stable histogram.
    img0_gray = (0.299 * img0_tensor[:, 0:1] + 0.587 * img0_tensor[:, 1:2]
                 + 0.114 * img0_tensor[:, 2:3])
    img1_gray = (0.299 * img1_tensor[:, 0:1] + 0.587 * img1_tensor[:, 1:2]
                 + 0.114 * img1_tensor[:, 2:3])
    img0_gray = torch.clamp(img0_gray, 0, 255).byte().float().cpu()
    img1_gray = torch.clamp(img1_gray, 0, 255).byte().float().cpu()

    # first stage: global histogram difference
    hist0 = calc_hist(img0_gray)
    hist1 = calc_hist(img1_gray)
    diff = torch.abs(hist0 - hist1)
    diff[diff < 0.01] = 0  # ignore small per-bin noise
    if torch.sum(diff) > 0.8 or diff.max() > 0.4:
        return True
    img0_gray = img0_gray.to(device)
    img1_gray = img1_gray.to(device)

    # second stage: detect mv and pix mismatch

    (n, c, h, w) = F01_tensor.size()
    # Thresholds below are tuned for 1080p; rescale flow magnitudes.
    scale_x = w / 1920
    scale_y = h / 1080

    # compare mv
    # Coarse evaluation grid, 8-pixel steps, 64-pixel border margin.
    # NOTE(review): torch.meshgrid without indexing= keeps the legacy 'ij'
    # behavior (deprecation warning on torch >= 1.10); left unchanged for
    # compatibility with older torch.
    (y, x) = torch.meshgrid(torch.arange(h), torch.arange(w))
    (y_grid, x_grid) = torch.meshgrid(
        torch.arange(64, h - 64, 8), torch.arange(64, w - 64, 8))
    x = x.to(device)
    y = y.to(device)
    y_grid = y_grid.to(device)
    x_grid = x_grid.to(device)
    fx = F01_tensor[0, 0]
    fy = F01_tensor[0, 1]
    # Destination of each pixel under forward flow, rounded to the nearest
    # integer position inside the image.
    x_ = x.float() + fx
    y_ = y.float() + fy
    x_ = torch.clamp(x_ + 0.5, 0, w - 1).long()
    y_ = torch.clamp(y_ + 0.5, 0, h - 1).long()

    grid_fx = fx[y_grid, x_grid]
    grid_fy = fy[y_grid, x_grid]

    x_grid_ = x_[y_grid, x_grid]
    y_grid_ = y_[y_grid, x_grid]

    # Backward flow sampled at the forward destination; for consistent
    # motion, forward + backward should cancel out.
    grid_fx_ = F10_tensor[0, 0, y_grid_, x_grid_]
    grid_fy_ = F10_tensor[0, 1, y_grid_, x_grid_]

    sum_x = grid_fx + grid_fx_
    sum_y = grid_fy + grid_fy_
    distance = torch.sqrt(sum_x**2 + sum_y**2)

    # Adaptive threshold: larger motion tolerates a larger inconsistency.
    fx_len = torch.abs(grid_fx) * scale_x
    fy_len = torch.abs(grid_fy) * scale_y
    ori_len = torch.where(fx_len > fy_len, fx_len, fy_len)

    thres = torch.clamp(0.1 * ori_len + 4, 5, 14)

    # compare pix diff
    ori_img = img0_gray
    ref_img = img1_gray[:, :, y_, x_]

    img_diff = ori_img.float() - ref_img.float()
    img_diff = torch.abs(img_diff)

    # 8x8 box filter on the absolute difference.
    # np.float was removed in NumPy >= 1.24; use the explicit float64 dtype.
    kernel = np.ones([8, 8], np.float64) / 64
    kernel = torch.FloatTensor(kernel).to(device).unsqueeze(0).unsqueeze(0)
    diff = F.conv2d(img_diff, kernel, padding=4)

    diff = diff[0, 0, y_grid, x_grid]

    # A grid point votes for "cut" when flow is inconsistent AND pixels
    # actually mismatch; require a majority of points.
    index = (distance > thres) * (diff > 5)
    if index.sum().float() / distance.numel() > 0.5:
        scene_change = True
    return scene_change
|
||||
@@ -0,0 +1,96 @@
|
||||
# The implementation is adopted from RAFT,
|
||||
# made publicly available under the BSD-3-Clause license at https://github.com/princeton-vl/RAFT
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy import interpolate
|
||||
|
||||
|
||||
class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (8 - self.ht % 8) % 8
        pad_wd = (8 - self.wd % 8) % 8
        left = pad_wd // 2
        if mode == 'sintel':
            # centre the content: split padding between both sides
            top = pad_ht // 2
            self._pad = [left, pad_wd - left, top, pad_ht - top]
        else:
            # keep the top row fixed; put the full height pad at the bottom
            self._pad = [left, pad_wd - left, 0, pad_ht]

    def pad(self, *inputs):
        """Replicate-pad every input tensor to the divisible-by-8 size."""
        return [F.pad(t, self._pad, mode='replicate') for t in inputs]

    def unpad(self, x):
        """Crop the padding added by :meth:`pad` back off."""
        ht, wd = x.shape[-2:]
        top, bottom = self._pad[2], ht - self._pad[3]
        left, right = self._pad[0], wd - self._pad[1]
        return x[..., top:bottom, left:right]
|
||||
|
||||
|
||||
def forward_interpolate(flow):
    """Forward-warp a flow field onto the regular pixel grid.

    Every source pixel is moved by its own flow vector; the flow values are
    then resampled at integer grid positions with nearest-neighbour
    scattered interpolation (zero fill outside the data).

    Args:
        flow (torch.Tensor): Flow field of shape (2, H, W).

    Returns:
        torch.Tensor: Forward-interpolated flow, float32, shape (2, H, W).
    """
    flow_np = flow.detach().cpu().numpy()
    dx, dy = flow_np[0], flow_np[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    # target position of each source pixel
    x1 = (x0 + dx).reshape(-1)
    y1 = (y0 + dy).reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    # keep only points landing strictly inside the image
    inside = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1, y1 = x1[inside], y1[inside]
    dx, dy = dx[inside], dy[inside]

    flow_x = interpolate.griddata((x1, y1),
                                  dx, (x0, y0),
                                  method='nearest',
                                  fill_value=0)
    flow_y = interpolate.griddata((x1, y1),
                                  dy, (x0, y0),
                                  method='nearest',
                                  fill_value=0)

    warped = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(warped).float()
|
||||
|
||||
|
||||
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # map pixel coordinates to grid_sample's [-1, 1] normalized range
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    sampled = F.grid_sample(img, grid, align_corners=True)

    if mask:
        # valid = coordinates strictly inside the image
        inside = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return sampled, inside.float()

    return sampled
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd, device):
    """Build a (batch, 2, ht, wd) grid of (x, y) pixel coordinates."""
    ys, xs = torch.meshgrid(
        torch.arange(ht, device=device), torch.arange(wd, device=device))
    # channel 0 = x, channel 1 = y
    grid = torch.stack((xs, ys), dim=0).float()
    return grid[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
|
||||
def upflow8(flow, mode='bilinear'):
    """Upsample a flow field 8x spatially, scaling the vectors by 8 too."""
    ht, wd = flow.shape[2], flow.shape[3]
    up = F.interpolate(
        flow, size=(8 * ht, 8 * wd), mode=mode, align_corners=True)
    return 8 * up
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING

from modelscope.utils.import_utils import LazyImportModule

if TYPE_CHECKING:
    # Static imports for type checkers / IDEs only; not executed at runtime.
    from .video_frame_interpolation_dataset import VideoFrameInterpolationDataset

else:
    # Map submodule name -> exported symbols for lazy loading.
    _import_structure = {
        'video_frame_interpolation_dataset':
        ['VideoFrameInterpolationDataset'],
    }

    import sys

    # Replace this module with a lazy proxy that imports submodules on
    # first attribute access.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )
|
||||
@@ -0,0 +1,41 @@
|
||||
# ------------------------------------------------------------------------
|
||||
# Modified from BasicSR (https://github.com/xinntao/BasicSR)
|
||||
# Copyright 2018-2020 BasicSR Authors
|
||||
# ------------------------------------------------------------------------
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def img2tensor(imgs, bgr2rgb=True, float32=True):
    """Numpy array to tensor.
    Args:
        imgs (list[ndarray] | ndarray): Input images.
        bgr2rgb (bool): Whether to change bgr to rgb.
        float32 (bool): Whether to change to float32.
    Returns:
        list[tensor] | tensor: Tensor images. If returned results only have
            one element, just return tensor.
    """

    def _convert(img, to_rgb, to_float):
        # Only convert colour order for 3-channel images.
        if img.shape[2] == 3 and to_rgb:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        tensor = torch.from_numpy(img.transpose(2, 0, 1))  # HWC -> CHW
        return tensor.float() if to_float else tensor

    if isinstance(imgs, list):
        return [_convert(img, bgr2rgb, float32) for img in imgs]
    return _convert(imgs, bgr2rgb, float32)
|
||||
|
||||
|
||||
def img_padding(img_tensor, height, width, pad_num=32):
    """Zero-pad on the right/bottom so H and W become multiples of pad_num."""
    target_h = ((height - 1) // pad_num + 1) * pad_num
    target_w = ((width - 1) // pad_num + 1) * pad_num
    pad = (0, target_w - width, 0, target_h - height)
    return F.pad(img_tensor, pad)
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.msdatasets.task_datasets.builder import TASK_DATASETS
|
||||
from modelscope.msdatasets.task_datasets.torch_base_dataset import \
|
||||
TorchTaskDataset
|
||||
from modelscope.msdatasets.task_datasets.video_frame_interpolation.data_utils import (
|
||||
img2tensor, img_padding)
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
def default_loader(path):
    """Read an image from *path* unchanged (keeps alpha/bit depth) and
    return it as a float32 numpy array (HWC; BGR order for color images).
    Values stay in the original [0, 255] range."""
    return cv2.imread(path, cv2.IMREAD_UNCHANGED).astype(np.float32)
|
||||
|
||||
|
||||
@TASK_DATASETS.register_module(
    Tasks.video_frame_interpolation,
    module_name=Models.video_frame_interpolation)
class VideoFrameInterpolationDataset(TorchTaskDataset):
    """Dataset for video frame-interpolation.

    Each item provides four consecutive input frames and the ground-truth
    middle frame via file-path columns 'Input1:FILE' .. 'Input4:FILE' and
    'Output:FILE'.
    """

    def __init__(self, dataset, opt):
        # dataset: indexable source mapping column names to image paths.
        # opt: config options; stored but not used here — presumably
        # consumed by callers/subclasses (TODO confirm).
        self.dataset = dataset
        self.opt = opt

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        """Return {'input': (12, H', W') frames in [0, 255], zero-padded so
        H'/W' are multiples of 32; 'target': GT frame scaled to [0, 1]}."""

        # Load frames. Dimension order: HWC; channel order: BGR;
        # loader output is float32 in [0, 255]; only the target is
        # normalized to [0, 1] below.
        item_dict = self.dataset[index]
        img0 = default_loader(item_dict['Input1:FILE'])
        img1 = default_loader(item_dict['Input2:FILE'])
        img2 = default_loader(item_dict['Input3:FILE'])
        img3 = default_loader(item_dict['Input4:FILE'])
        gt = default_loader(item_dict['Output:FILE'])

        img0, img1, img2, img3, gt = img2tensor([img0, img1, img2, img3, gt],
                                                bgr2rgb=False,
                                                float32=True)

        # Stack the four frames along the channel axis (C = 4 * 3).
        imgs = torch.cat((img0, img1, img2, img3), dim=0)
        height, width = imgs.size(1), imgs.size(2)
        imgs = img_padding(imgs, height, width, pad_num=32)
        return {'input': imgs, 'target': gt / 255.0}
|
||||
@@ -305,6 +305,7 @@ TASK_OUTPUTS = {
|
||||
|
||||
# video editing task result for a single video
|
||||
# {"output_video": "path_to_rendered_video"}
|
||||
Tasks.video_frame_interpolation: [OutputKeys.OUTPUT_VIDEO],
|
||||
Tasks.video_super_resolution: [OutputKeys.OUTPUT_VIDEO],
|
||||
|
||||
# live category recognition result for single video
|
||||
|
||||
@@ -226,6 +226,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
'damo/cv_video-inpainting'),
|
||||
Tasks.video_human_matting: (Pipelines.video_human_matting,
|
||||
'damo/cv_effnetv2_video-human-matting'),
|
||||
Tasks.video_frame_interpolation:
|
||||
(Pipelines.video_frame_interpolation,
|
||||
'damo/cv_raft_video-frame-interpolation'),
|
||||
Tasks.human_wholebody_keypoint:
|
||||
(Pipelines.human_wholebody_keypoint,
|
||||
'damo/cv_hrnetw48_human-wholebody-keypoint_image'),
|
||||
@@ -247,9 +250,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.translation_evaluation:
|
||||
(Pipelines.translation_evaluation,
|
||||
'damo/nlp_unite_mup_translation_evaluation_multilingual_large'),
|
||||
Tasks.video_object_segmentation:
|
||||
(Pipelines.video_object_segmentation,
|
||||
'damo/cv_rdevos_video-object-segmentation'),
|
||||
Tasks.video_object_segmentation: (
|
||||
Pipelines.video_object_segmentation,
|
||||
'damo/cv_rdevos_video-object-segmentation'),
|
||||
Tasks.video_multi_object_tracking: (
|
||||
Pipelines.video_multi_object_tracking,
|
||||
'damo/cv_yolov5_video-multi-object-tracking_fairmot'),
|
||||
|
||||
@@ -68,6 +68,7 @@ if TYPE_CHECKING:
|
||||
from .hand_static_pipeline import HandStaticPipeline
|
||||
from .referring_video_object_segmentation_pipeline import ReferringVideoObjectSegmentationPipeline
|
||||
from .language_guided_video_summarization_pipeline import LanguageGuidedVideoSummarizationPipeline
|
||||
from .video_frame_interpolation_pipeline import VideoFrameInterpolationPipeline
|
||||
from .image_skychange_pipeline import ImageSkychangePipeline
|
||||
from .video_object_segmentation_pipeline import VideoObjectSegmentationPipeline
|
||||
from .video_stabilization_pipeline import VideoStabilizationPipeline
|
||||
@@ -164,6 +165,9 @@ else:
|
||||
'language_guided_video_summarization_pipeline': [
|
||||
'LanguageGuidedVideoSummarizationPipeline'
|
||||
],
|
||||
'video_frame_interpolation_pipeline': [
|
||||
'VideoFrameInterpolationPipeline'
|
||||
],
|
||||
'image_skychange_pipeline': ['ImageSkychangePipeline'],
|
||||
'video_object_segmentation_pipeline': [
|
||||
'VideoObjectSegmentationPipeline'
|
||||
|
||||
613
modelscope/pipelines/cv/video_frame_interpolation_pipeline.py
Normal file
613
modelscope/pipelines/cv/video_frame_interpolation_pipeline.py
Normal file
@@ -0,0 +1,613 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import glob
|
||||
import math
|
||||
import os
|
||||
import os.path as osp
|
||||
import subprocess
|
||||
import tempfile
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.models.cv.video_frame_interpolation.utils.scene_change_detection import \
|
||||
do_scene_detect
|
||||
from modelscope.models.cv.video_frame_interpolation.VFINet_for_video_frame_interpolation import \
|
||||
VFINetForVideoFrameInterpolation
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.preprocessors import LoadImage
|
||||
from modelscope.preprocessors.cv import VideoReader
|
||||
from modelscope.utils.config import Config
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
VIDEO_EXTENSIONS = ('.mp4', '.mov')
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def img_trans(img_tensor):  # in format of RGB
    """Normalize an RGB image tensor: scale [0, 255] -> [0, 1] and subtract
    the per-channel mean (0.429, 0.431, 0.397). The caller's tensor is not
    modified."""
    scaled = img_tensor / 255.0
    mean = torch.Tensor([0.429, 0.431,
                         0.397]).view(1, 3, 1, 1).type_as(scaled)
    return scaled - mean
|
||||
|
||||
|
||||
def add_mean(x):
    """Inverse of img_trans's mean subtraction: add the RGB mean back."""
    mean = torch.Tensor([0.429, 0.431, 0.397]).view(1, 3, 1, 1).type_as(x)
    return mean + x
|
||||
|
||||
|
||||
def img_padding(img_tensor, height, width, pad_num=32):
    """Zero-pad the right and bottom edges so that both spatial dimensions
    become multiples of ``pad_num``."""
    padded_h = ((height - 1) // pad_num + 1) * pad_num
    padded_w = ((width - 1) // pad_num + 1) * pad_num
    return F.pad(img_tensor, (0, padded_w - width, 0, padded_h - height))
|
||||
|
||||
|
||||
def do_inference_lowers(flow_10,
                        flow_12,
                        flow_21,
                        flow_23,
                        img1,
                        img2,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=False):
    """Generate interpolated frames between img1 and img2 (low resolution).

    Steps interpolation time ``inter_count`` forward by ``delta`` while it is
    covered by the already-read frames. For times within delta/2 of an
    endpoint, the endpoint frame is reused; otherwise ``inter_model``
    predicts the in-between frame from the four flows. Results are
    de-normalized (mean added back, scaled by 255) and appended to
    ``outputs``.

    Args:
        flow_10/flow_12/flow_21/flow_23: Optical flows among four
            consecutive frames (may be None when unavailable, e.g. at
            sequence boundaries or on a detected scene change).
        img1, img2: The pair of normalized frames to interpolate between.
        inter_model: Callable frame-interpolation model.
        read_count (int): Number of frames read so far.
        inter_count (float): Next interpolation timestamp.
        delta (float): Timestep between interpolated frames (in_fps/out_fps).
        outputs (list): Accumulator for output frames.
        start_end_flag (bool): True when img1/img2 are the first or last
            frame pair (only one frame of lookahead available).

    Returns:
        tuple: (outputs, inter_count) updated.
    """
    # given frame1, frame2 and optical flow, predict frame_t
    if start_end_flag:
        read_count -= 1
    else:
        read_count -= 2
    while inter_count <= read_count:
        # t in (0, 1) is the relative position between img1 and img2.
        t = inter_count + 1 - read_count
        t = round(t, 2)
        if (t - 0) < delta / 2:
            output = img1
        elif (1 - t) < delta / 2:
            output = img2
        else:
            output = inter_model(flow_10, flow_12, flow_21, flow_23, img1,
                                 img2, t)

        # undo normalization: add channel mean back and rescale to [0, 255]
        output = 255 * add_mean(output)
        outputs.append(output)
        inter_count += delta

    return outputs, inter_count
|
||||
|
||||
|
||||
def do_inference_highers(flow_10,
                         flow_12,
                         flow_21,
                         flow_23,
                         img1,
                         img2,
                         img1_up,
                         img2_up,
                         inter_model,
                         read_count,
                         inter_count,
                         delta,
                         outputs,
                         start_end_flag=False):
    """Generate interpolated frames for high-resolution (2K+) videos.

    Same scheme as ``do_inference_lowers`` but the model additionally takes
    the full-resolution frames (``img1_up``/``img2_up``) alongside the
    downscaled ones used for flow; endpoint times reuse the full-resolution
    frames directly.

    Args:
        flow_10/flow_12/flow_21/flow_23: Optical flows among four
            consecutive (downscaled) frames; may be None at boundaries.
        img1, img2: Downscaled normalized frames.
        img1_up, img2_up: Full-resolution normalized frames.
        inter_model: Callable frame-interpolation model.
        read_count (int): Number of frames read so far.
        inter_count (float): Next interpolation timestamp.
        delta (float): Timestep between interpolated frames.
        outputs (list): Accumulator for output frames.
        start_end_flag (bool): True for the first/last frame pair.

    Returns:
        tuple: (outputs, inter_count) updated.
    """
    # given frame1, frame2 and optical flow, predict frame_t. For videos with a resolution of 2k and above
    if start_end_flag:
        read_count -= 1
    else:
        read_count -= 2
    while inter_count <= read_count:
        # t in (0, 1) is the relative position between img1 and img2.
        t = inter_count + 1 - read_count
        t = round(t, 2)
        if (t - 0) < delta / 2:
            output = img1_up
        elif (1 - t) < delta / 2:
            output = img2_up
        else:
            output = inter_model(flow_10, flow_12, flow_21, flow_23, img1,
                                 img2, img1_up, img2_up, t)

        # undo normalization: add channel mean back and rescale to [0, 255]
        output = 255 * add_mean(output)
        outputs.append(output)
        inter_count += delta

    return outputs, inter_count
|
||||
|
||||
|
||||
def inference_lowers(flow_model, refine_model, inter_model, video_len,
                     read_count, inter_count, delta, scene_change_flag,
                     img_tensor_list, img_ori_list, inputs, outputs):
    """Interpolate a whole video whose resolution is below 2k.

    A sliding window of four consecutive frames is maintained; for each
    window, bidirectional optical flows are estimated with ``flow_model``,
    refined with ``refine_model``, scene cuts are detected, and the
    in-between frames are synthesized via :func:`do_inference_lowers`.

    Args:
        flow_model: optical-flow network, called as
            ``flow_model(a, b, iters=12, test_mode=True)``.
        refine_model: refines a pair of bidirectional flows.
        inter_model: frame-synthesis network.
        video_len: number of input frames.
        read_count: index of the next frame to read (expected 0).
        inter_count: interpolation cursor, advanced by ``delta``.
        delta: time step between output frames (input fps / output fps).
        scene_change_flag: 3-element list of scene-cut flags for the
            window; mutated in place.
        img_tensor_list: sliding window of normalized frame tensors.
        img_ori_list: sliding window of padded original frame tensors.
        inputs: list of 1x3xHxW frame tensors.
        outputs: list collecting the interpolated frames.

    Returns:
        The ``outputs`` list filled with interpolated frames.
    """
    height, width = inputs[read_count].size(2), inputs[read_count].size(3)
    # We use four consecutive frames to do frame interpolation. flow_10 represents
    # optical flow from frame0 to frame1. The similar goes for flow_12, flow_21 and
    # flow_23.
    flow_10 = None
    flow_12 = None
    flow_21 = None
    flow_23 = None
    with torch.no_grad():
        while (read_count < video_len):
            img = inputs[read_count]
            img = img_padding(img, height, width)
            img_ori_list.append(img)
            img_tensor_list.append(img_trans(img))
            read_count += 1
            if len(img_tensor_list) == 2:
                # First frame pair of the clip: estimate bidirectional flow
                # and interpolate the leading segment (start_end_flag=True).
                img0 = img_tensor_list[0]
                img1 = img_tensor_list[1]
                img0_ori = img_ori_list[0]
                img1_ori = img_ori_list[1]
                _, flow_01_up = flow_model(
                    img0_ori, img1_ori, iters=12, test_mode=True)
                _, flow_10_up = flow_model(
                    img1_ori, img0_ori, iters=12, test_mode=True)
                flow_01, flow_10 = refine_model(img0, img1, flow_01_up,
                                                flow_10_up, 2)
                # Scene detection runs on the unpadded region only.
                scene_change_flag[0] = do_scene_detect(
                    flow_01[:, :, 0:height, 0:width],
                    flow_10[:, :, 0:height, 0:width],
                    img_ori_list[0][:, :, 0:height, 0:width],
                    img_ori_list[1][:, :, 0:height, 0:width])
                if scene_change_flag[0]:
                    # Scene cut: call the interpolator without any flows.
                    outputs, inter_count = do_inference_lowers(
                        None,
                        None,
                        None,
                        None,
                        img0,
                        img1,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)
                else:
                    outputs, inter_count = do_inference_lowers(
                        None,
                        flow_01,
                        flow_10,
                        None,
                        img0,
                        img1,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)

            if len(img_tensor_list) == 4:
                if flow_12 is None or flow_21 is None:
                    # Middle-pair flows were not carried over from the
                    # previous window; compute them now.
                    img2 = img_tensor_list[2]
                    img2_ori = img_ori_list[2]
                    _, flow_12_up = flow_model(
                        img1_ori, img2_ori, iters=12, test_mode=True)
                    _, flow_21_up = flow_model(
                        img2_ori, img1_ori, iters=12, test_mode=True)
                    flow_12, flow_21 = refine_model(img1, img2, flow_12_up,
                                                    flow_21_up, 2)
                    scene_change_flag[1] = do_scene_detect(
                        flow_12[:, :, 0:height, 0:width],
                        flow_21[:, :, 0:height, 0:width],
                        img_ori_list[1][:, :, 0:height, 0:width],
                        img_ori_list[2][:, :, 0:height, 0:width])

                img3 = img_tensor_list[3]
                img3_ori = img_ori_list[3]
                _, flow_23_up = flow_model(
                    img2_ori, img3_ori, iters=12, test_mode=True)
                _, flow_32_up = flow_model(
                    img3_ori, img2_ori, iters=12, test_mode=True)
                flow_23, flow_32 = refine_model(img2, img3, flow_23_up,
                                                flow_32_up, 2)
                scene_change_flag[2] = do_scene_detect(
                    flow_23[:, :, 0:height, 0:width],
                    flow_32[:, :, 0:height, 0:width],
                    img_ori_list[2][:, :, 0:height, 0:width],
                    img_ori_list[3][:, :, 0:height, 0:width])

                if scene_change_flag[1]:
                    # Cut inside the middle pair: no flows at all.
                    outputs, inter_count = do_inference_lowers(
                        None, None, None, None, img1, img2, inter_model,
                        read_count, inter_count, delta, outputs)
                elif scene_change_flag[0] or scene_change_flag[2]:
                    # Cut on an outer pair: drop the outer flows, keep the
                    # middle pair's.
                    outputs, inter_count = do_inference_lowers(
                        None, flow_12, flow_21, None, img1, img2, inter_model,
                        read_count, inter_count, delta, outputs)
                else:
                    outputs, inter_count = do_inference_lowers(
                        flow_10_up, flow_12, flow_21, flow_23_up, img1, img2,
                        inter_model, read_count, inter_count, delta, outputs)

                img_tensor_list.pop(0)
                img_ori_list.pop(0)

                # for next group
                img1 = img2
                img2 = img3
                img1_ori = img2_ori
                img2_ori = img3_ori
                flow_10 = flow_21
                flow_12 = flow_23
                flow_21 = flow_32

                flow_10_up = flow_21_up
                flow_12_up = flow_23_up
                flow_21_up = flow_32_up

                # save scene change flag for next group
                scene_change_flag[0] = scene_change_flag[1]
                scene_change_flag[1] = scene_change_flag[2]
                scene_change_flag[2] = False

        if read_count > 0:  # the last remaining 3 images
            img_ori_list.pop(0)
            img_tensor_list.pop(0)
            assert (len(img_tensor_list) == 2)

            if scene_change_flag[1]:
                outputs, inter_count = do_inference_lowers(
                    None,
                    None,
                    None,
                    None,
                    img1,
                    img2,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)
            else:
                outputs, inter_count = do_inference_lowers(
                    None,
                    flow_12,
                    flow_21,
                    None,
                    img1,
                    img2,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)

    return outputs
|
||||
|
||||
|
||||
def inference_highers(flow_model, refine_model, inter_model, video_len,
                      read_count, inter_count, delta, scene_change_flag,
                      img_tensor_list, img_ori_list, inputs, outputs):
    """Interpolate a whole video with a resolution of 2k or above.

    Same sliding-window scheme as :func:`inference_lowers`, but flow and
    scene detection run on a 0.5x-downscaled copy of each frame while the
    full-resolution ("_up") tensors are kept alongside so
    :func:`do_inference_highers` can synthesize full-resolution output.

    Args:
        flow_model: optical-flow network, called as
            ``flow_model(a, b, iters=12, test_mode=True)``.
        refine_model: refines a pair of bidirectional flows.
        inter_model: frame-synthesis network (high-resolution variant).
        video_len: number of input frames.
        read_count: index of the next frame to read (expected 0).
        inter_count: interpolation cursor, advanced by ``delta``.
        delta: time step between output frames (input fps / output fps).
        scene_change_flag: 3-element list of scene-cut flags; mutated.
        img_tensor_list: sliding window of normalized downscaled tensors.
        img_ori_list: sliding window of downscaled original tensors.
        inputs: list of 1x3xHxW frame tensors; H and W must be even.
        outputs: list collecting the interpolated frames.

    Returns:
        The ``outputs`` list filled with interpolated frames.

    Raises:
        RuntimeError: if the video width or height is odd.
    """
    if inputs[read_count].size(2) % 2 != 0 or inputs[read_count].size(
            3) % 2 != 0:
        raise RuntimeError('Video width and height must be even')

    # height/width here are the DOWNSCALED dimensions (half resolution).
    height, width = inputs[read_count].size(2) // 2, inputs[read_count].size(
        3) // 2
    # We use four consecutive frames to do frame interpolation. flow_10 represents
    # optical flow from frame0 to frame1. The similar goes for flow_12, flow_21 and
    # flow_23.
    flow_10 = None
    flow_12 = None
    flow_21 = None
    flow_23 = None
    img_up_list = []
    with torch.no_grad():
        while (read_count < video_len):
            # Pad the full-resolution frame to a multiple of 64, then build
            # the half-resolution copy used for flow and scene detection.
            img_up = inputs[read_count]
            img_up = img_padding(img_up, height * 2, width * 2, pad_num=64)
            img = F.interpolate(
                img_up, scale_factor=0.5, mode='bilinear', align_corners=False)

            img_up_list.append(img_trans(img_up))
            img_ori_list.append(img)
            img_tensor_list.append(img_trans(img))
            read_count += 1
            if len(img_tensor_list) == 2:
                # First frame pair of the clip: estimate bidirectional flow
                # and interpolate the leading segment (start_end_flag=True).
                img0 = img_tensor_list[0]
                img1 = img_tensor_list[1]
                img0_ori = img_ori_list[0]
                img1_ori = img_ori_list[1]
                img0_up = img_up_list[0]
                img1_up = img_up_list[1]
                _, flow_01_up = flow_model(
                    img0_ori, img1_ori, iters=12, test_mode=True)
                _, flow_10_up = flow_model(
                    img1_ori, img0_ori, iters=12, test_mode=True)
                flow_01, flow_10 = refine_model(img0, img1, flow_01_up,
                                                flow_10_up, 2)
                # Scene detection runs on the unpadded downscaled region.
                scene_change_flag[0] = do_scene_detect(
                    flow_01[:, :, 0:height, 0:width],
                    flow_10[:, :, 0:height, 0:width],
                    img_ori_list[0][:, :, 0:height, 0:width],
                    img_ori_list[1][:, :, 0:height, 0:width])
                if scene_change_flag[0]:
                    # Scene cut: call the interpolator without any flows.
                    outputs, inter_count = do_inference_highers(
                        None,
                        None,
                        None,
                        None,
                        img0,
                        img1,
                        img0_up,
                        img1_up,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)
                else:
                    outputs, inter_count = do_inference_highers(
                        None,
                        flow_01,
                        flow_10,
                        None,
                        img0,
                        img1,
                        img0_up,
                        img1_up,
                        inter_model,
                        read_count,
                        inter_count,
                        delta,
                        outputs,
                        start_end_flag=True)

            if len(img_tensor_list) == 4:
                if flow_12 is None or flow_21 is None:
                    # Middle-pair flows were not carried over from the
                    # previous window; compute them now.
                    img2 = img_tensor_list[2]
                    img2_ori = img_ori_list[2]
                    img2_up = img_up_list[2]
                    _, flow_12_up = flow_model(
                        img1_ori, img2_ori, iters=12, test_mode=True)
                    _, flow_21_up = flow_model(
                        img2_ori, img1_ori, iters=12, test_mode=True)
                    flow_12, flow_21 = refine_model(img1, img2, flow_12_up,
                                                    flow_21_up, 2)
                    scene_change_flag[1] = do_scene_detect(
                        flow_12[:, :, 0:height, 0:width],
                        flow_21[:, :, 0:height, 0:width],
                        img_ori_list[1][:, :, 0:height, 0:width],
                        img_ori_list[2][:, :, 0:height, 0:width])

                img3 = img_tensor_list[3]
                img3_ori = img_ori_list[3]
                img3_up = img_up_list[3]
                _, flow_23_up = flow_model(
                    img2_ori, img3_ori, iters=12, test_mode=True)
                _, flow_32_up = flow_model(
                    img3_ori, img2_ori, iters=12, test_mode=True)
                flow_23, flow_32 = refine_model(img2, img3, flow_23_up,
                                                flow_32_up, 2)
                scene_change_flag[2] = do_scene_detect(
                    flow_23[:, :, 0:height, 0:width],
                    flow_32[:, :, 0:height, 0:width],
                    img_ori_list[2][:, :, 0:height, 0:width],
                    img_ori_list[3][:, :, 0:height, 0:width])

                if scene_change_flag[1]:
                    # Cut inside the middle pair: no flows at all.
                    outputs, inter_count = do_inference_highers(
                        None, None, None, None, img1, img2, img1_up, img2_up,
                        inter_model, read_count, inter_count, delta, outputs)
                elif scene_change_flag[0] or scene_change_flag[2]:
                    # Cut on an outer pair: drop the outer flows, keep the
                    # middle pair's.
                    outputs, inter_count = do_inference_highers(
                        None, flow_12, flow_21, None, img1, img2, img1_up,
                        img2_up, inter_model, read_count, inter_count, delta,
                        outputs)
                else:
                    outputs, inter_count = do_inference_highers(
                        flow_10_up, flow_12, flow_21, flow_23_up, img1, img2,
                        img1_up, img2_up, inter_model, read_count, inter_count,
                        delta, outputs)

                img_up_list.pop(0)
                img_tensor_list.pop(0)
                img_ori_list.pop(0)

                # for next group
                img1 = img2
                img2 = img3
                img1_ori = img2_ori
                img2_ori = img3_ori
                img1_up = img2_up
                img2_up = img3_up
                flow_10 = flow_21
                flow_12 = flow_23
                flow_21 = flow_32

                flow_10_up = flow_21_up
                flow_12_up = flow_23_up
                flow_21_up = flow_32_up

                # save scene change flag for next group
                scene_change_flag[0] = scene_change_flag[1]
                scene_change_flag[1] = scene_change_flag[2]
                scene_change_flag[2] = False

        if read_count > 0:  # the last remaining 3 images
            img_ori_list.pop(0)
            img_tensor_list.pop(0)
            assert (len(img_tensor_list) == 2)

            if scene_change_flag[1]:
                outputs, inter_count = do_inference_highers(
                    None,
                    None,
                    None,
                    None,
                    img1,
                    img2,
                    img1_up,
                    img2_up,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)
            else:
                outputs, inter_count = do_inference_highers(
                    None,
                    flow_12,
                    flow_21,
                    None,
                    img1,
                    img2,
                    img1_up,
                    img2_up,
                    inter_model,
                    read_count,
                    inter_count,
                    delta,
                    outputs,
                    start_end_flag=True)

    return outputs
|
||||
|
||||
|
||||
def convert(param):
    """Strip the ``module.`` prefix from checkpoint parameter names.

    Checkpoints saved from a ``DataParallel``-wrapped model carry a
    ``module.`` prefix on every key; this rebuilds the state dict with
    that prefix removed. Keys that do not contain ``module.`` are
    dropped entirely.

    Args:
        param: mapping of parameter names to tensors (a state dict).

    Returns:
        A new dict with ``module.`` removed from every retained key.
    """
    stripped = {}
    for key, value in param.items():
        if 'module.' in key:
            stripped[key.replace('module.', '')] = value
    return stripped
|
||||
|
||||
|
||||
# Public API of this module.
__all__ = ['VideoFrameInterpolationPipeline']
|
||||
|
||||
|
||||
@PIPELINES.register_module(
    Tasks.video_frame_interpolation,
    module_name=Pipelines.video_frame_interpolation)
class VideoFrameInterpolationPipeline(Pipeline):
    """ Video Frame Interpolation Pipeline.

    Reads a video (or a directory of frames), doubles the frame rate by
    default (or resamples to a caller-supplied ``out_fps``), and writes the
    interpolated result to an mp4 file.

    Example:

    ```python
    >>> from modelscope.pipelines import pipeline
    >>> from modelscope.utils.constant import Tasks
    >>> from modelscope.outputs import OutputKeys

    >>> video = 'https://modelscope.oss-cn-beijing.aliyuncs.com/test/videos/video_frame_interpolation_test.mp4'
    >>> video_frame_interpolation_pipeline = pipeline(Tasks.video_frame_interpolation,
            'damo/cv_raft_video-frame-interpolation')
    >>> result = video_frame_interpolation_pipeline(video)[OutputKeys.OUTPUT_VIDEO]
    >>> print('pipeline: the output video path is {}'.format(result))
    ```
    """

    def __init__(self,
                 model: Union[VFINetForVideoFrameInterpolation, str],
                 preprocessor=None,
                 **kwargs):
        """Create the pipeline and move the network to GPU if available.

        Args:
            model: a loaded ``VFINetForVideoFrameInterpolation`` instance
                or a model id / local path understood by ``Pipeline``.
            preprocessor: unused; frames are prepared in ``preprocess``.
        """
        super().__init__(model=model, preprocessor=preprocessor, **kwargs)
        if torch.cuda.is_available():
            self._device = torch.device('cuda')
        else:
            self._device = torch.device('cpu')
        # `self.model` is the ModelScope wrapper; its `.model` attribute is
        # the raw torch network that forward() drives directly.
        self.net = self.model.model
        self.net.to(self._device)
        self.net.eval()
        logger.info('load video frame-interpolation done')

    def preprocess(self, input: Input, out_fps: float = 0) -> Dict[str, Any]:
        """Load frames from a video file or a directory of images.

        Args:
            input: path to a video file (extension in VIDEO_EXTENSIONS) or
                to a directory of frame images (no extension).
            out_fps: desired output frame rate; 0 means double the input.

        Returns:
            Dict with 'video' (list of 1x3xHxW float tensors), 'fps' and
            'out_fps'.

        Raises:
            ValueError: if `input` is neither a video nor a directory.
        """
        # read images
        file_extension = os.path.splitext(input)[1]
        if file_extension in VIDEO_EXTENSIONS:  # input is a video file
            video_reader = VideoReader(input)
            inputs = []
            for frame in video_reader:
                inputs.append(frame)
            fps = video_reader.fps
        elif file_extension == '':  # input is a directory
            inputs = []
            input_paths = sorted(glob.glob(f'{input}/*'))
            for input_path in input_paths:
                img = LoadImage(input_path, mode='rgb')
                inputs.append(img)
            fps = 25  # default fps
        else:
            raise ValueError('"input" can only be a video or a directory.')

        # HWC uint8 numpy -> 1x3xHxW float tensor.
        for i, img in enumerate(inputs):
            img = torch.from_numpy(img.copy()).permute(2, 0, 1).float()
            inputs[i] = img.unsqueeze(0)

        if out_fps == 0:
            out_fps = 2 * fps
        return {'video': inputs, 'fps': fps, 'out_fps': out_fps}

    def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
        """Run frame interpolation over the preprocessed frame list.

        Selects the high-resolution model branch for 2k-and-above inputs,
        then delegates to `inference_highers` / `inference_lowers` and
        crops each output back to the original frame size.
        """
        inputs = input['video']
        fps = input['fps']
        out_fps = input['out_fps']
        video_len = len(inputs)

        flow_model = self.net.flownet
        refine_model = self.net.internet.ifnet

        read_count = 0
        inter_count = 0
        # Time step between output frames, in input-frame units.
        delta = fps / out_fps
        scene_change_flag = [False, False, False]
        img_tensor_list = []
        img_ori_list = []
        outputs = []
        height, width = inputs[read_count].size(2), inputs[read_count].size(3)
        if height >= 1440 or width >= 2560:
            # 2k and above: use the downscaled ("Ds") interpolation branch.
            inter_model = self.net.internet_Ds.internet
            outputs = inference_highers(flow_model, refine_model, inter_model,
                                        video_len, read_count, inter_count,
                                        delta, scene_change_flag,
                                        img_tensor_list, img_ori_list, inputs,
                                        outputs)
        else:
            inter_model = self.net.internet.internet
            outputs = inference_lowers(flow_model, refine_model, inter_model,
                                       video_len, read_count, inter_count,
                                       delta, scene_change_flag,
                                       img_tensor_list, img_ori_list, inputs,
                                       outputs)

        # Crop away the padding that was added during inference.
        for i in range(len(outputs)):
            outputs[i] = outputs[i][:, :, 0:height, 0:width]
        return {'output': outputs, 'fps': out_fps}

    def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
        """Encode the interpolated frames into an mp4 file.

        Keyword Args:
            output_video: target path; a temporary file is used if omitted.
            demo_service: when True (default), additionally re-encode with
                ffmpeg/h264 for browser playback and return that path.

        Returns:
            Dict mapping ``OutputKeys.OUTPUT_VIDEO`` to the video path.
        """
        output_video_path = kwargs.get('output_video', None)
        demo_service = kwargs.get('demo_service', True)
        if output_video_path is None:
            # NOTE(review): the NamedTemporaryFile object is discarded
            # immediately (on POSIX the file is deleted); only its name is
            # reused by VideoWriter below — confirm this is intended.
            output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
        h, w = inputs['output'][0].shape[-2:]
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        video_writer = cv2.VideoWriter(output_video_path, fourcc,
                                       inputs['fps'], (w, h))
        for i in range(len(inputs['output'])):
            img = inputs['output'][i]
            # CHW tensor -> HWC uint8 numpy for OpenCV.
            img = img[0].permute(1, 2, 0).byte().cpu().numpy()
            # NOTE(review): cv2.VideoWriter expects BGR channel order while
            # preprocess() loads frames as RGB — verify channel handling.
            video_writer.write(img.astype(np.uint8))

        video_writer.release()
        if demo_service:
            # Re-encode with h264 so the result plays in web browsers.
            assert os.system(
                'ffmpeg -version') == 0, 'ffmpeg is not installed correctly!'
            output_video_path_for_web = output_video_path[:-4] + '_web.mp4'
            convert_cmd = f'ffmpeg -i {output_video_path} -vcodec h264 -crf 5 {output_video_path_for_web}'
            subprocess.call(convert_cmd, shell=True)
            return {OutputKeys.OUTPUT_VIDEO: output_video_path_for_web}
        else:
            return {OutputKeys.OUTPUT_VIDEO: output_video_path}
|
||||
@@ -98,6 +98,7 @@ class CVTasks(object):
|
||||
|
||||
# video editing
|
||||
video_inpainting = 'video-inpainting'
|
||||
video_frame_interpolation = 'video-frame-interpolation'
|
||||
video_stabilization = 'video-stabilization'
|
||||
video_super_resolution = 'video-super-resolution'
|
||||
|
||||
|
||||
52
tests/pipelines/test_video_frame_interpolation.py
Normal file
52
tests/pipelines/test_video_frame_interpolation.py
Normal file
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import sys
|
||||
import unittest
|
||||
|
||||
from modelscope.hub.snapshot_download import snapshot_download
|
||||
from modelscope.models import Model
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.pipelines.cv import VideoFrameInterpolationPipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class VideoFrameInterpolationTest(unittest.TestCase, DemoCompatibilityCheck):
    """Integration tests for the video frame-interpolation pipeline."""

    def setUp(self) -> None:
        # Task name, hub model id and the sample video shared by all cases.
        self.task = Tasks.video_frame_interpolation
        self.model_id = 'damo/cv_raft_video-frame-interpolation'
        self.test_video = 'data/test/videos/video_frame_interpolation_test.mp4'

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_by_direct_model_download(self):
        # Download the model snapshot explicitly and construct the pipeline
        # class directly from the local cache path.
        cache_path = snapshot_download(self.model_id)
        pipeline = VideoFrameInterpolationPipeline(cache_path)
        pipeline.group_key = self.task
        out_video_path = pipeline(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_model_from_modelhub(self):
        # Build the pipeline through the public pipeline() factory with an
        # explicit model id.
        pipeline_ins = pipeline(
            task=Tasks.video_frame_interpolation, model=self.model_id)
        out_video_path = pipeline_ins(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_default_model(self):
        # Rely on the task's default model registration.
        pipeline_ins = pipeline(task=Tasks.video_frame_interpolation)
        out_video_path = pipeline_ins(
            input=self.test_video)[OutputKeys.OUTPUT_VIDEO]
        print('pipeline: the output video path is {}'.format(out_video_path))

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        self.compatibility_check()
|
||||
|
||||
|
||||
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()
|
||||
Reference in New Issue
Block a user