mirror of
https://github.com/gaomingqi/Track-Anything.git
synced 2026-02-24 20:19:59 +01:00
150 lines
5.9 KiB
Python
150 lines
5.9 KiB
Python
"""
|
|
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
|
|
"""
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
|
|
from mmengine.model import constant_init
|
|
|
|
from inpainter.model.modules.flow_comp import flow_warp
|
|
|
|
|
|
class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
    """Second-order deformable alignment module.

    Predicts per-pixel deformable-conv offsets and modulation masks from a
    conditioning feature plus two optical flows, then applies a modulated
    deformable convolution to align ``x``. The learned offsets are bounded
    residuals added on top of the flows.
    """

    def __init__(self, *args, **kwargs):
        # Cap on the tanh-scaled offset residual added to the flows.
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)

        super().__init__(*args, **kwargs)

        # Offset/mask head. Input channels: three feature maps stacked by the
        # caller (3 * out_channels) plus two 2-channel flows (+4). Output:
        # 27 channels per deform group = 2 * 9 offsets + 9 masks for a 3x3
        # kernel, doubled for the two flow branches via torch.chunk below.
        self.conv_offset = nn.Sequential(
            nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
        )

        self.init_offset()

    def init_offset(self):
        """Zero-initialize the final offset conv so that, at start of
        training, alignment reduces to plain flow warping."""
        constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        """Align ``x`` using ``extra_feat`` and the two flows as conditioning.

        Args:
            x: features to be aligned.
            extra_feat: conditioning features, concatenated with the flows.
            flow_1, flow_2: first- and second-order optical flows,
                assumed channel order (flow_x, flow_y) before the flip —
                TODO confirm against flow_comp conventions.
        """
        head_in = torch.cat([extra_feat, flow_1, flow_2], dim=1)
        head_out = self.conv_offset(head_in)
        raw_o1, raw_o2, raw_mask = torch.chunk(head_out, 3, dim=1)

        # Bounded residual offsets; tanh keeps them in
        # [-max_residue_magnitude, max_residue_magnitude].
        residue = self.max_residue_magnitude * torch.tanh(
            torch.cat((raw_o1, raw_o2), dim=1))
        offset_1, offset_2 = torch.chunk(residue, 2, dim=1)

        # Add the flows as base offsets. flip(1) swaps the flow's channel
        # order to match the (y, x) layout expected by deformable conv;
        # repeat tiles the 2-channel flow over every kernel sample.
        offset_1 = offset_1 + flow_1.flip(1).repeat(
            1, offset_1.size(1) // 2, 1, 1)
        offset_2 = offset_2 + flow_2.flip(1).repeat(
            1, offset_2.size(1) // 2, 1, 1)
        offset = torch.cat([offset_1, offset_2], dim=1)

        # Modulation mask in (0, 1).
        mask = torch.sigmoid(raw_mask)

        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)
|
|
|
|
|
|
class BidirectionalPropagation(nn.Module):
    """Second-order bidirectional feature propagation (BasicVSR++-style).

    Features are propagated backward through time, then forward; each step
    is aligned with a SecondOrderDeformableAlignment module driven by the
    corresponding optical flows, and the two propagated streams are fused
    per frame with a 1x1 conv.
    """

    def __init__(self, channel):
        super(BidirectionalPropagation, self).__init__()
        # Trailing underscore avoids clashing with nn.Module attributes.
        modules = ['backward_', 'forward_']
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        # Per-frame feature channel count.
        self.channel = channel

        for i, module in enumerate(modules):
            # Aligns the concatenation [feat_prop, feat_n2] (2 * channel)
            # back down to `channel` using the flows as guidance.
            self.deform_align[module] = SecondOrderDeformableAlignment(
                2 * channel, channel, 3, padding=1, deform_groups=16)

            # Residual refinement. Input width is (2 + i) * channel: the
            # forward branch (i == 1) additionally concatenates the already
            # computed backward features (see the `feat` list in forward()).
            self.backbone[module] = nn.Sequential(
                nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(channel, channel, 3, 1, 1),
            )

        # Fuses the backward + forward streams (2 * channel) per frame.
        self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)

    def forward(self, x, flows_backward, flows_forward):
        """
        x shape : [b, t, c, h, w]
        return [b, t, c, h, w]
        """
        # NOTE(review): flows_backward/flows_forward look like
        # [b, t-1, 2, h, w] given the flow_idx indexing below — confirm
        # against the flow-completion caller.
        b, t, c, h, w = x.shape
        feats = {}
        # Per-frame input features; also keyed so later comprehensions can
        # gather "all streams except the current one" by dict key.
        feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]

        # Dict insertion order matters: the forward pass sees the finished
        # 'backward_' stream via the `feat` gathering comprehension.
        for module_name in ['backward_', 'forward_']:

            feats[module_name] = []

            frame_idx = range(0, t)
            # flow_idx[i] is the flow between step i-1 and i; index -1 at
            # i == 0 is never used because of the `if i > 0` guard.
            flow_idx = range(-1, t - 1)
            # Mirrored index list so the reversed (backward) frame order
            # still maps into feats['spatial'] correctly.
            mapping_idx = list(range(0, len(feats['spatial'])))
            mapping_idx += mapping_idx[::-1]

            if 'backward' in module_name:
                frame_idx = frame_idx[::-1]
                flows = flows_backward
            else:
                flows = flows_forward

            # Propagated state, zero at the sequence boundary.
            feat_prop = x.new_zeros(b, self.channel, h, w)
            for i, idx in enumerate(frame_idx):
                feat_current = feats['spatial'][mapping_idx[idx]]

                if i > 0:
                    # First-order: warp the previous propagated feature to
                    # the current frame (flow_warp expects [b, h, w, 2]).
                    flow_n1 = flows[:, flow_idx[i], :, :, :]
                    cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))

                    # initialize second-order features
                    feat_n2 = torch.zeros_like(feat_prop)
                    flow_n2 = torch.zeros_like(flow_n1)
                    cond_n2 = torch.zeros_like(cond_n1)
                    if i > 1:
                        # Second-order: reach two steps back; compose the
                        # two flows by warping the older flow with the
                        # newer one before warping the older feature.
                        feat_n2 = feats[module_name][-2]
                        flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
                        flow_n2 = flow_n1 + flow_warp(
                            flow_n2, flow_n1.permute(0, 2, 3, 1))
                        cond_n2 = flow_warp(feat_n2,
                                            flow_n2.permute(0, 2, 3, 1))

                    # Condition the alignment on warped neighbors plus the
                    # current frame; align [feat_prop, feat_n2] jointly.
                    cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
                    feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
                    feat_prop = self.deform_align[module_name](feat_prop, cond,
                                                               flow_n1,
                                                               flow_n2)

                # Gather current frame + every previously finished stream
                # (empty for 'backward_', the backward stream for
                # 'forward_') + the freshly aligned propagation.
                feat = [feat_current] + [
                    feats[k][idx]
                    for k in feats if k not in ['spatial', module_name]
                ] + [feat_prop]

                feat = torch.cat(feat, dim=1)
                # Residual refinement of the propagated state.
                feat_prop = feat_prop + self.backbone[module_name](feat)
                feats[module_name].append(feat_prop)

            # Backward stream was built in reverse time order; flip it so
            # every stream is indexed 0..t-1.
            if 'backward' in module_name:
                feats[module_name] = feats[module_name][::-1]

        outputs = []
        for i in range(0, t):
            # pop(0) consumes each stream front-to-back; after this loop the
            # per-stream lists are empty.
            align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
            align_feats = torch.cat(align_feats, dim=1)
            outputs.append(self.fusion(align_feats))

        # Global residual connection over the input clip.
        return torch.stack(outputs, dim=1) + x