add inpainting

gaomingqi
2023-04-18 00:30:11 +08:00
parent 113299d99c
commit 0a6dd84ba8
11 changed files with 2876 additions and 0 deletions

160
inpainter/base_inpainter.py Normal file

@@ -0,0 +1,160 @@
import os
import glob
from PIL import Image
import torch
import yaml
import cv2
import importlib
import numpy as np
from util.tensor_util import resize_frames, resize_masks
class BaseInpainter:
def __init__(self, E2FGVI_checkpoint, device) -> None:
"""
E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
"""
net = importlib.import_module('model.e2fgvi_hq')
self.model = net.InpaintGenerator().to(device)
self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
self.model.eval()
self.device = device
# load configurations
with open("inpainter/config/config.yaml", 'r') as stream:
config = yaml.safe_load(stream)
self.neighbor_stride = config['neighbor_stride']
self.num_ref = config['num_ref']
self.step = config['step']
# sample reference frames from the whole video
def get_ref_index(self, f, neighbor_ids, length):
ref_index = []
if self.num_ref == -1:
for i in range(0, length, self.step):
if i not in neighbor_ids:
ref_index.append(i)
else:
start_idx = max(0, f - self.step * (self.num_ref // 2))
end_idx = min(length, f + self.step * (self.num_ref // 2))
for i in range(start_idx, end_idx + 1, self.step):
if i not in neighbor_ids:
if len(ref_index) > self.num_ref:
break
ref_index.append(i)
return ref_index
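    # Illustrative example (values from inpainter/config/config.yaml):
    # with num_ref=-1 and step=10, a 50-frame clip yields global references
    # [0, 10, 20, 30, 40] minus any index already in neighbor_ids; with
    # num_ref > 0 only a small set of references centred on frame f is kept.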
def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
"""
frames: numpy array, T, H, W, 3
masks: numpy array, T, H, W
dilate_radius: radius when applying dilation on masks
ratio: down-sample ratio
Output:
inpainted_frames: numpy array, T, H, W, 3
"""
        assert frames.shape[:3] == masks.shape, 'frames and masks have different shapes'
        assert 0 < ratio <= 1, 'ratio must be in (0, 1]'
masks = masks.copy()
masks = np.clip(masks, 0, 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius, dilate_radius))
masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)
T, H, W = masks.shape
# size: (w, h)
if ratio == 1:
size = None
else:
size = (int(W*ratio), int(H*ratio))
masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1
binary_masks = resize_masks(masks, size)
frames = resize_frames(frames, size) # T, H, W, 3
# frames and binary_masks are numpy arrays
h, w = frames.shape[1:3]
video_length = T
# convert to tensor
imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)
imgs, masks = imgs.to(self.device), masks.to(self.device)
comp_frames = [None] * video_length
for f in range(0, video_length, self.neighbor_stride):
neighbor_ids = [
i for i in range(max(0, f - self.neighbor_stride),
min(video_length, f + self.neighbor_stride + 1))
]
ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
with torch.no_grad():
masked_imgs = selected_imgs * (1 - selected_masks)
mod_size_h = 60
mod_size_w = 108
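                # Note: 60 and 108 correspond to E2FGVI's 240x432 training
                # resolution (a 60x108 feature grid after 4x downsampling);
                # H and W are mirror-padded below to multiples of these values
                # and the prediction is cropped back to (h, w) afterwards.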
h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
masked_imgs = torch.cat(
[masked_imgs, torch.flip(masked_imgs, [3])],
3)[:, :, :, :h + h_pad, :]
masked_imgs = torch.cat(
[masked_imgs, torch.flip(masked_imgs, [4])],
4)[:, :, :, :, :w + w_pad]
pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
pred_imgs = pred_imgs[:, :, :h, :w]
pred_imgs = (pred_imgs + 1) / 2
pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
for i in range(len(neighbor_ids)):
idx = neighbor_ids[i]
img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
1 - binary_masks[idx])
if comp_frames[idx] is None:
comp_frames[idx] = img
else:
comp_frames[idx] = comp_frames[idx].astype(
np.float32) * 0.5 + img.astype(np.float32) * 0.5
inpainted_frames = np.stack(comp_frames, 0)
return inpainted_frames.astype(np.uint8)
if __name__ == '__main__':
frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
frame_path.sort()
mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
mask_path.sort()
save_path = '/ssd1/gaomingqi/results/inpainting/parkour'
if not os.path.exists(save_path):
os.mkdir(save_path)
frames = []
masks = []
for fid, mid in zip(frame_path, mask_path):
frames.append(Image.open(fid).convert('RGB'))
masks.append(Image.open(mid).convert('P'))
frames = np.stack(frames, 0)
masks = np.stack(masks, 0)
# ----------------------------------------------
# how to use
# ----------------------------------------------
# 1/3: set checkpoint and device
checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
device = 'cuda:2'
# 2/3: initialise inpainter
base_inpainter = BaseInpainter(checkpoint, device)
# 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
    # ratio: down-sampling ratio in (0, 1]; default is 1
inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.5) # numpy array, T, H, W, 3
# ----------------------------------------------
# end
# ----------------------------------------------
# save
for ti, inpainted_frame in enumerate(inpainted_frames):
frame = Image.fromarray(inpainted_frame).convert('RGB')
frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))

4
inpainter/config/config.yaml Normal file

@@ -0,0 +1,4 @@
# config info for E2FGVI
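# notes (inferred from base_inpainter.py):
# neighbor_stride: stride of the local temporal window processed per pass
# num_ref: number of non-local reference frames (-1 = every `step`-th frame)
# step: sampling interval for reference frames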
neighbor_stride: 5
num_ref: -1
step: 10

350
inpainter/model/e2fgvi.py Normal file

@@ -0,0 +1,350 @@
''' Towards An End-to-End Framework for Flow-Guided Video Inpainting
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.modules.flow_comp import SPyNet
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
from model.modules.spectral_norm import spectral_norm as _spectral_norm
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
def print_network(self):
if isinstance(self, list):
self = self[0]
num_params = 0
for param in self.parameters():
num_params += param.numel()
print(
'Network [%s] was created. Total number of parameters: %.1f million. '
'To see the architecture, do print(network).' %
(type(self).__name__, num_params / 1000000))
def init_weights(self, init_type='normal', gain=0.02):
'''
initialize network's weights
init_type: normal | xavier | kaiming | orthogonal
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
'''
def init_func(m):
classname = m.__class__.__name__
if classname.find('InstanceNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
nn.init.constant_(m.weight.data, 1.0)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
or classname.find('Linear') != -1):
if init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
nn.init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError(
'initialization method [%s] is not implemented' %
init_type)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.group = [1, 2, 4, 8, 1]
self.layers = nn.ModuleList([
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
nn.LeakyReLU(0.2, inplace=True)
])
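        # Note: the forward pass below saves the 256-channel feature (x0)
        # produced before layer index 8 and re-concatenates it group-wise with
        # later outputs, which is why the remaining convolutions take
        # 640, 768, 640 and 512 input channels.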
def forward(self, x):
bt, c, h, w = x.size()
h, w = h // 4, w // 4
out = x
for i, layer in enumerate(self.layers):
if i == 8:
x0 = out
if i > 8 and i % 2 == 0:
g = self.group[(i - 8) // 2]
x = x0.view(bt, g, -1, h, w)
o = out.view(bt, g, -1, h, w)
out = torch.cat([x, o], 2).view(bt, -1, h, w)
out = layer(out)
return out
class deconv(nn.Module):
def __init__(self,
input_channel,
output_channel,
kernel_size=3,
padding=0):
super().__init__()
self.conv = nn.Conv2d(input_channel,
output_channel,
kernel_size=kernel_size,
stride=1,
padding=padding)
def forward(self, x):
x = F.interpolate(x,
scale_factor=2,
mode='bilinear',
align_corners=True)
return self.conv(x)
class InpaintGenerator(BaseNetwork):
def __init__(self, init_weights=True):
super(InpaintGenerator, self).__init__()
channel = 256
hidden = 512
# encoder
self.encoder = Encoder()
# decoder
self.decoder = nn.Sequential(
deconv(channel // 2, 128, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
deconv(64, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
# feature propagation module
self.feat_prop_module = BidirectionalPropagation(channel // 2)
# soft split and soft composition
kernel_size = (7, 7)
padding = (3, 3)
stride = (3, 3)
output_size = (60, 108)
t2t_params = {
'kernel_size': kernel_size,
'stride': stride,
'padding': padding,
'output_size': output_size
}
self.ss = SoftSplit(channel // 2,
hidden,
kernel_size,
stride,
padding,
t2t_param=t2t_params)
self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
stride, padding)
n_vecs = 1
for i, d in enumerate(kernel_size):
n_vecs *= int((output_size[i] + 2 * padding[i] -
(d - 1) - 1) / stride[i] + 1)
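        # With the defaults above (output_size 60x108, kernel 7x7, stride 3,
        # padding 3) this gives n_vecs = 20 * 36 = 720 soft tokens per frame.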
blocks = []
depths = 8
num_heads = [4] * depths
window_size = [(5, 9)] * depths
focal_windows = [(5, 9)] * depths
focal_levels = [2] * depths
pool_method = "fc"
for i in range(depths):
blocks.append(
TemporalFocalTransformerBlock(dim=hidden,
num_heads=num_heads[i],
window_size=window_size[i],
focal_level=focal_levels[i],
focal_window=focal_windows[i],
n_vecs=n_vecs,
t2t_params=t2t_params,
pool_method=pool_method))
self.transformer = nn.Sequential(*blocks)
if init_weights:
self.init_weights()
        # Need to initialize the weights of MSDeformAttn specifically
for m in self.modules():
if isinstance(m, SecondOrderDeformableAlignment):
m.init_offset()
# flow completion network
self.update_spynet = SPyNet()
def forward_bidirect_flow(self, masked_local_frames):
b, l_t, c, h, w = masked_local_frames.size()
# compute forward and backward flows of masked frames
masked_local_frames = F.interpolate(masked_local_frames.view(
-1, c, h, w),
scale_factor=1 / 4,
mode='bilinear',
align_corners=True,
recompute_scale_factor=True)
masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
w // 4)
mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
-1, c, h // 4, w // 4)
mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
-1, c, h // 4, w // 4)
pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
w // 4)
pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
w // 4)
return pred_flows_forward, pred_flows_backward
def forward(self, masked_frames, num_local_frames):
l_t = num_local_frames
b, t, ori_c, ori_h, ori_w = masked_frames.size()
# normalization before feeding into the flow completion module
masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
pred_flows = self.forward_bidirect_flow(masked_local_frames)
# extracting features and performing the feature propagation on local features
enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
_, c, h, w = enc_feat.size()
local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
local_feat = self.feat_prop_module(local_feat, pred_flows[0],
pred_flows[1])
enc_feat = torch.cat((local_feat, ref_feat), dim=1)
# content hallucination through stacking multiple temporal focal transformer blocks
trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
trans_feat = self.transformer(trans_feat)
trans_feat = self.sc(trans_feat, t)
trans_feat = trans_feat.view(b, t, -1, h, w)
enc_feat = enc_feat + trans_feat
# decode frames from features
output = self.decoder(enc_feat.view(b * t, c, h, w))
output = torch.tanh(output)
return output, pred_flows
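    # Illustrative call (mirrors base_inpainter.py; shapes inferred from the
    # code above, so treat this as a sketch rather than a spec):
    #   masked_frames: (B, T, 3, H, W) in [-1, 1], where the first
    #   `num_local_frames` entries are the local window and the rest are
    #   reference frames
    #   pred_imgs, pred_flows = model(masked_frames, num_local_frames)
    #   pred_imgs: (B*T, 3, H, W) in [-1, 1]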
# ######################################################################
# Discriminator for Temporal Patch GAN
# ######################################################################
class Discriminator(BaseNetwork):
def __init__(self,
in_channels=3,
use_sigmoid=False,
use_spectral_norm=True,
init_weights=True):
super(Discriminator, self).__init__()
self.use_sigmoid = use_sigmoid
nf = 32
self.conv = nn.Sequential(
spectral_norm(
nn.Conv3d(in_channels=in_channels,
out_channels=nf * 1,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=1,
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(64, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 1,
nf * 2,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(128, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 2,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 4,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 4,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(nf * 4,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2)))
if init_weights:
self.init_weights()
def forward(self, xs):
# T, C, H, W = xs.shape (old)
# B, T, C, H, W (new)
xs_t = torch.transpose(xs, 1, 2)
feat = self.conv(xs_t)
if self.use_sigmoid:
feat = torch.sigmoid(feat)
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
return out
def spectral_norm(module, mode=True):
if mode:
return _spectral_norm(module)
return module

350
inpainter/model/e2fgvi_hq.py Normal file

@@ -0,0 +1,350 @@
''' Towards An End-to-End Framework for Flow-Guided Video Inpainting
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.modules.flow_comp import SPyNet
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
from model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
from model.modules.spectral_norm import spectral_norm as _spectral_norm
class BaseNetwork(nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
def print_network(self):
if isinstance(self, list):
self = self[0]
num_params = 0
for param in self.parameters():
num_params += param.numel()
print(
'Network [%s] was created. Total number of parameters: %.1f million. '
'To see the architecture, do print(network).' %
(type(self).__name__, num_params / 1000000))
def init_weights(self, init_type='normal', gain=0.02):
'''
initialize network's weights
init_type: normal | xavier | kaiming | orthogonal
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
'''
def init_func(m):
classname = m.__class__.__name__
if classname.find('InstanceNorm2d') != -1:
if hasattr(m, 'weight') and m.weight is not None:
nn.init.constant_(m.weight.data, 1.0)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
elif hasattr(m, 'weight') and (classname.find('Conv') != -1
or classname.find('Linear') != -1):
if init_type == 'normal':
nn.init.normal_(m.weight.data, 0.0, gain)
elif init_type == 'xavier':
nn.init.xavier_normal_(m.weight.data, gain=gain)
elif init_type == 'xavier_uniform':
nn.init.xavier_uniform_(m.weight.data, gain=1.0)
elif init_type == 'kaiming':
nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
elif init_type == 'orthogonal':
nn.init.orthogonal_(m.weight.data, gain=gain)
elif init_type == 'none': # uses pytorch's default init method
m.reset_parameters()
else:
raise NotImplementedError(
'initialization method [%s] is not implemented' %
init_type)
if hasattr(m, 'bias') and m.bias is not None:
nn.init.constant_(m.bias.data, 0.0)
self.apply(init_func)
# propagate to children
for m in self.children():
if hasattr(m, 'init_weights'):
m.init_weights(init_type, gain)
class Encoder(nn.Module):
def __init__(self):
super(Encoder, self).__init__()
self.group = [1, 2, 4, 8, 1]
self.layers = nn.ModuleList([
nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
nn.LeakyReLU(0.2, inplace=True)
])
def forward(self, x):
bt, c, _, _ = x.size()
# h, w = h//4, w//4
out = x
for i, layer in enumerate(self.layers):
if i == 8:
x0 = out
_, _, h, w = x0.size()
if i > 8 and i % 2 == 0:
g = self.group[(i - 8) // 2]
x = x0.view(bt, g, -1, h, w)
o = out.view(bt, g, -1, h, w)
out = torch.cat([x, o], 2).view(bt, -1, h, w)
out = layer(out)
return out
class deconv(nn.Module):
def __init__(self,
input_channel,
output_channel,
kernel_size=3,
padding=0):
super().__init__()
self.conv = nn.Conv2d(input_channel,
output_channel,
kernel_size=kernel_size,
stride=1,
padding=padding)
def forward(self, x):
x = F.interpolate(x,
scale_factor=2,
mode='bilinear',
align_corners=True)
return self.conv(x)
class InpaintGenerator(BaseNetwork):
def __init__(self, init_weights=True):
super(InpaintGenerator, self).__init__()
channel = 256
hidden = 512
# encoder
self.encoder = Encoder()
# decoder
self.decoder = nn.Sequential(
deconv(channel // 2, 128, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
nn.LeakyReLU(0.2, inplace=True),
deconv(64, 64, kernel_size=3, padding=1),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
# feature propagation module
self.feat_prop_module = BidirectionalPropagation(channel // 2)
# soft split and soft composition
kernel_size = (7, 7)
padding = (3, 3)
stride = (3, 3)
output_size = (60, 108)
t2t_params = {
'kernel_size': kernel_size,
'stride': stride,
'padding': padding
}
self.ss = SoftSplit(channel // 2,
hidden,
kernel_size,
stride,
padding,
t2t_param=t2t_params)
self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)
n_vecs = 1
for i, d in enumerate(kernel_size):
n_vecs *= int((output_size[i] + 2 * padding[i] -
(d - 1) - 1) / stride[i] + 1)
blocks = []
depths = 8
num_heads = [4] * depths
window_size = [(5, 9)] * depths
focal_windows = [(5, 9)] * depths
focal_levels = [2] * depths
pool_method = "fc"
for i in range(depths):
blocks.append(
TemporalFocalTransformerBlock(dim=hidden,
num_heads=num_heads[i],
window_size=window_size[i],
focal_level=focal_levels[i],
focal_window=focal_windows[i],
n_vecs=n_vecs,
t2t_params=t2t_params,
pool_method=pool_method))
self.transformer = nn.Sequential(*blocks)
if init_weights:
self.init_weights()
        # Need to initialize the weights of MSDeformAttn specifically
for m in self.modules():
if isinstance(m, SecondOrderDeformableAlignment):
m.init_offset()
# flow completion network
self.update_spynet = SPyNet()
def forward_bidirect_flow(self, masked_local_frames):
b, l_t, c, h, w = masked_local_frames.size()
# compute forward and backward flows of masked frames
masked_local_frames = F.interpolate(masked_local_frames.view(
-1, c, h, w),
scale_factor=1 / 4,
mode='bilinear',
align_corners=True,
recompute_scale_factor=True)
masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
w // 4)
mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
-1, c, h // 4, w // 4)
mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
-1, c, h // 4, w // 4)
pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
pred_flows_backward = self.update_spynet(mlf_2, mlf_1)
pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
w // 4)
pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
w // 4)
return pred_flows_forward, pred_flows_backward
def forward(self, masked_frames, num_local_frames):
l_t = num_local_frames
b, t, ori_c, ori_h, ori_w = masked_frames.size()
# normalization before feeding into the flow completion module
masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
pred_flows = self.forward_bidirect_flow(masked_local_frames)
# extracting features and performing the feature propagation on local features
enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
_, c, h, w = enc_feat.size()
fold_output_size = (h, w)
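        # Unlike e2fgvi.py, this HQ variant derives the fold/unfold output size
        # from the actual encoder feature map rather than a fixed (60, 108),
        # which is what enables the multi-resolution support mentioned in
        # base_inpainter.py.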
local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
local_feat = self.feat_prop_module(local_feat, pred_flows[0],
pred_flows[1])
enc_feat = torch.cat((local_feat, ref_feat), dim=1)
# content hallucination through stacking multiple temporal focal transformer blocks
trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
trans_feat = self.transformer([trans_feat, fold_output_size])
trans_feat = self.sc(trans_feat[0], t, fold_output_size)
trans_feat = trans_feat.view(b, t, -1, h, w)
enc_feat = enc_feat + trans_feat
# decode frames from features
output = self.decoder(enc_feat.view(b * t, c, h, w))
output = torch.tanh(output)
return output, pred_flows
# ######################################################################
# Discriminator for Temporal Patch GAN
# ######################################################################
class Discriminator(BaseNetwork):
def __init__(self,
in_channels=3,
use_sigmoid=False,
use_spectral_norm=True,
init_weights=True):
super(Discriminator, self).__init__()
self.use_sigmoid = use_sigmoid
nf = 32
self.conv = nn.Sequential(
spectral_norm(
nn.Conv3d(in_channels=in_channels,
out_channels=nf * 1,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=1,
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(64, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 1,
nf * 2,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(128, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 2,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 4,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
spectral_norm(
nn.Conv3d(nf * 4,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2),
bias=not use_spectral_norm), use_spectral_norm),
# nn.InstanceNorm2d(256, track_running_stats=False),
nn.LeakyReLU(0.2, inplace=True),
nn.Conv3d(nf * 4,
nf * 4,
kernel_size=(3, 5, 5),
stride=(1, 2, 2),
padding=(1, 2, 2)))
if init_weights:
self.init_weights()
def forward(self, xs):
# T, C, H, W = xs.shape (old)
# B, T, C, H, W (new)
xs_t = torch.transpose(xs, 1, 2)
feat = self.conv(xs_t)
if self.use_sigmoid:
feat = torch.sigmoid(feat)
out = torch.transpose(feat, 1, 2) # B, T, C, H, W
return out
def spectral_norm(module, mode=True):
if mode:
return _spectral_norm(module)
return module

149
inpainter/model/modules/feat_prop.py Normal file

@@ -0,0 +1,149 @@
"""
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
"""
import torch
import torch.nn as nn
from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
from mmengine.model import constant_init
from model.modules.flow_comp import flow_warp
class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
"""Second-order deformable alignment module."""
def __init__(self, *args, **kwargs):
self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
self.conv_offset = nn.Sequential(
nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
)
self.init_offset()
def init_offset(self):
constant_init(self.conv_offset[-1], val=0, bias=0)
def forward(self, x, extra_feat, flow_1, flow_2):
extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
out = self.conv_offset(extra_feat)
o1, o2, mask = torch.chunk(out, 3, dim=1)
# offset
offset = self.max_residue_magnitude * torch.tanh(
torch.cat((o1, o2), dim=1))
offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
offset_1 = offset_1 + flow_1.flip(1).repeat(1,
offset_1.size(1) // 2, 1,
1)
offset_2 = offset_2 + flow_2.flip(1).repeat(1,
offset_2.size(1) // 2, 1,
1)
offset = torch.cat([offset_1, offset_2], dim=1)
# mask
mask = torch.sigmoid(mask)
return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
self.stride, self.padding,
self.dilation, self.groups,
self.deform_groups)
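    # Note: the predicted offsets are residuals added to the (channel-flipped)
    # optical flows and bounded by max_residue_magnitude via tanh, so the
    # deformable kernels sample near the flow-warped positions.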
class BidirectionalPropagation(nn.Module):
def __init__(self, channel):
super(BidirectionalPropagation, self).__init__()
modules = ['backward_', 'forward_']
self.deform_align = nn.ModuleDict()
self.backbone = nn.ModuleDict()
self.channel = channel
for i, module in enumerate(modules):
self.deform_align[module] = SecondOrderDeformableAlignment(
2 * channel, channel, 3, padding=1, deform_groups=16)
self.backbone[module] = nn.Sequential(
nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
nn.LeakyReLU(negative_slope=0.1, inplace=True),
nn.Conv2d(channel, channel, 3, 1, 1),
)
self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
def forward(self, x, flows_backward, flows_forward):
"""
x shape : [b, t, c, h, w]
return [b, t, c, h, w]
"""
b, t, c, h, w = x.shape
feats = {}
feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
for module_name in ['backward_', 'forward_']:
feats[module_name] = []
frame_idx = range(0, t)
flow_idx = range(-1, t - 1)
mapping_idx = list(range(0, len(feats['spatial'])))
mapping_idx += mapping_idx[::-1]
if 'backward' in module_name:
frame_idx = frame_idx[::-1]
flows = flows_backward
else:
flows = flows_forward
feat_prop = x.new_zeros(b, self.channel, h, w)
for i, idx in enumerate(frame_idx):
feat_current = feats['spatial'][mapping_idx[idx]]
if i > 0:
flow_n1 = flows[:, flow_idx[i], :, :, :]
cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))
# initialize second-order features
feat_n2 = torch.zeros_like(feat_prop)
flow_n2 = torch.zeros_like(flow_n1)
cond_n2 = torch.zeros_like(cond_n1)
if i > 1:
feat_n2 = feats[module_name][-2]
flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
flow_n2 = flow_n1 + flow_warp(
flow_n2, flow_n1.permute(0, 2, 3, 1))
cond_n2 = flow_warp(feat_n2,
flow_n2.permute(0, 2, 3, 1))
cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
feat_prop = self.deform_align[module_name](feat_prop, cond,
flow_n1,
flow_n2)
feat = [feat_current] + [
feats[k][idx]
for k in feats if k not in ['spatial', module_name]
] + [feat_prop]
feat = torch.cat(feat, dim=1)
feat_prop = feat_prop + self.backbone[module_name](feat)
feats[module_name].append(feat_prop)
if 'backward' in module_name:
feats[module_name] = feats[module_name][::-1]
outputs = []
for i in range(0, t):
align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
align_feats = torch.cat(align_feats, dim=1)
outputs.append(self.fusion(align_feats))
return torch.stack(outputs, dim=1) + x

450
inpainter/model/modules/flow_comp.py Normal file

@@ -0,0 +1,450 @@
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
from mmcv.cnn import ConvModule
from mmengine.runner import load_checkpoint
class FlowCompletionLoss(nn.Module):
"""Flow completion loss"""
def __init__(self):
super().__init__()
self.fix_spynet = SPyNet()
for p in self.fix_spynet.parameters():
p.requires_grad = False
self.l1_criterion = nn.L1Loss()
def forward(self, pred_flows, gt_local_frames):
b, l_t, c, h, w = gt_local_frames.size()
with torch.no_grad():
# compute gt forward and backward flows
gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
scale_factor=1 / 4,
mode='bilinear',
align_corners=True,
recompute_scale_factor=True)
gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
-1, c, h // 4, w // 4)
gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
-1, c, h // 4, w // 4)
gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)
# calculate loss for flow completion
forward_flow_loss = self.l1_criterion(
pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
backward_flow_loss = self.l1_criterion(
pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
flow_loss = forward_flow_loss + backward_flow_loss
return flow_loss
class SPyNet(nn.Module):
"""SPyNet network structure.
The difference to the SPyNet in [tof.py] is that
1. more SPyNetBasicModule is used in this version, and
2. no batch normalization is used in this version.
Paper:
Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
Args:
pretrained (str): path for pre-trained SPyNet. Default: None.
"""
def __init__(
self,
use_pretrain=True,
pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
):
super().__init__()
self.basic_module = nn.ModuleList(
[SPyNetBasicModule() for _ in range(6)])
if use_pretrain:
if isinstance(pretrained, str):
print("load pretrained SPyNet...")
load_checkpoint(self, pretrained, strict=True)
elif pretrained is not None:
raise TypeError('[pretrained] should be str or None, '
f'but got {type(pretrained)}.')
self.register_buffer(
'mean',
torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer(
'std',
torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
def compute_flow(self, ref, supp):
"""Compute flow from ref to supp.
Note that in this function, the images are already resized to a
multiple of 32.
Args:
ref (Tensor): Reference image with shape of (n, 3, h, w).
supp (Tensor): Supporting image with shape of (n, 3, h, w).
Returns:
Tensor: Estimated optical flow: (n, 2, h, w).
"""
n, _, h, w = ref.size()
# normalize the input images
ref = [(ref - self.mean) / self.std]
supp = [(supp - self.mean) / self.std]
# generate downsampled frames
for level in range(5):
ref.append(
F.avg_pool2d(input=ref[-1],
kernel_size=2,
stride=2,
count_include_pad=False))
supp.append(
F.avg_pool2d(input=supp[-1],
kernel_size=2,
stride=2,
count_include_pad=False))
ref = ref[::-1]
supp = supp[::-1]
# flow computation
flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
for level in range(len(ref)):
if level == 0:
flow_up = flow
else:
flow_up = F.interpolate(input=flow,
scale_factor=2,
mode='bilinear',
align_corners=True) * 2.0
# add the residue to the upsampled flow
flow = flow_up + self.basic_module[level](torch.cat([
ref[level],
flow_warp(supp[level],
flow_up.permute(0, 2, 3, 1).contiguous(),
padding_mode='border'), flow_up
], 1))
return flow
def forward(self, ref, supp):
"""Forward function of SPyNet.
This function computes the optical flow from ref to supp.
Args:
ref (Tensor): Reference image with shape of (n, 3, h, w).
supp (Tensor): Supporting image with shape of (n, 3, h, w).
Returns:
Tensor: Estimated optical flow: (n, 2, h, w).
"""
# upsize to a multiple of 32
h, w = ref.shape[2:4]
w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
ref = F.interpolate(input=ref,
size=(h_up, w_up),
mode='bilinear',
align_corners=False)
supp = F.interpolate(input=supp,
size=(h_up, w_up),
mode='bilinear',
align_corners=False)
# compute flow, and resize back to the original resolution
flow = F.interpolate(input=self.compute_flow(ref, supp),
size=(h, w),
mode='bilinear',
align_corners=False)
# adjust the flow values
flow[:, 0, :, :] *= float(w) / float(w_up)
flow[:, 1, :, :] *= float(h) / float(h_up)
return flow
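    # Illustrative usage (ranges inferred from the ImageNet normalisation
    # buffers above): ref and supp are (n, 3, h, w) tensors in [0, 1], and
    #   flow = spynet(ref, supp)
    # returns the estimated motion from ref to supp as an (n, 2, h, w) tensor.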
class SPyNetBasicModule(nn.Module):
"""Basic Module for SPyNet.
Paper:
Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
"""
def __init__(self):
super().__init__()
self.basic_module = nn.Sequential(
ConvModule(in_channels=8,
out_channels=32,
kernel_size=7,
stride=1,
padding=3,
norm_cfg=None,
act_cfg=dict(type='ReLU')),
ConvModule(in_channels=32,
out_channels=64,
kernel_size=7,
stride=1,
padding=3,
norm_cfg=None,
act_cfg=dict(type='ReLU')),
ConvModule(in_channels=64,
out_channels=32,
kernel_size=7,
stride=1,
padding=3,
norm_cfg=None,
act_cfg=dict(type='ReLU')),
ConvModule(in_channels=32,
out_channels=16,
kernel_size=7,
stride=1,
padding=3,
norm_cfg=None,
act_cfg=dict(type='ReLU')),
ConvModule(in_channels=16,
out_channels=2,
kernel_size=7,
stride=1,
padding=3,
norm_cfg=None,
act_cfg=None))
def forward(self, tensor_input):
"""
Args:
tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
8 channels contain:
[reference image (3), neighbor image (3), initial flow (2)].
Returns:
Tensor: Refined flow with shape (b, 2, h, w)
"""
return self.basic_module(tensor_input)
# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
def make_colorwheel():
"""
Generates a color wheel for optical flow visualization as presented in:
Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.
Returns:
np.ndarray: Color wheel
"""
RY = 15
YG = 6
GC = 4
CB = 11
BM = 13
MR = 6
ncols = RY + YG + GC + CB + BM + MR
colorwheel = np.zeros((ncols, 3))
col = 0
# RY
colorwheel[0:RY, 0] = 255
colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
col = col + RY
# YG
colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
colorwheel[col:col + YG, 1] = 255
col = col + YG
# GC
colorwheel[col:col + GC, 1] = 255
colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
col = col + GC
# CB
colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
colorwheel[col:col + CB, 2] = 255
col = col + CB
# BM
colorwheel[col:col + BM, 2] = 255
colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
col = col + BM
# MR
colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
colorwheel[col:col + MR, 0] = 255
return colorwheel
def flow_uv_to_colors(u, v, convert_to_bgr=False):
"""
Applies the flow color wheel to (possibly clipped) flow components u and v.
According to the C++ source code of Daniel Scharstein
According to the Matlab source code of Deqing Sun
Args:
u (np.ndarray): Input horizontal flow of shape [H,W]
v (np.ndarray): Input vertical flow of shape [H,W]
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
colorwheel = make_colorwheel() # shape [55x3]
ncols = colorwheel.shape[0]
rad = np.sqrt(np.square(u) + np.square(v))
a = np.arctan2(-v, -u) / np.pi
fk = (a + 1) / 2 * (ncols - 1)
k0 = np.floor(fk).astype(np.int32)
k1 = k0 + 1
k1[k1 == ncols] = 0
f = fk - k0
for i in range(colorwheel.shape[1]):
tmp = colorwheel[:, i]
col0 = tmp[k0] / 255.0
col1 = tmp[k1] / 255.0
col = (1 - f) * col0 + f * col1
idx = (rad <= 1)
col[idx] = 1 - rad[idx] * (1 - col[idx])
col[~idx] = col[~idx] * 0.75 # out of range
# Note the 2-i => BGR instead of RGB
ch_idx = 2 - i if convert_to_bgr else i
flow_image[:, :, ch_idx] = np.floor(255 * col)
return flow_image
def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
"""
    Expects a two-dimensional flow image of shape [H,W,2].
Args:
flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
Returns:
np.ndarray: Flow visualization image of shape [H,W,3]
"""
assert flow_uv.ndim == 3, 'input flow must have three dimensions'
assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
if clip_flow is not None:
flow_uv = np.clip(flow_uv, 0, clip_flow)
u = flow_uv[:, :, 0]
v = flow_uv[:, :, 1]
rad = np.sqrt(np.square(u) + np.square(v))
rad_max = np.max(rad)
epsilon = 1e-5
u = u / (rad_max + epsilon)
v = v / (rad_max + epsilon)
return flow_uv_to_colors(u, v, convert_to_bgr)
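# Illustrative usage (hypothetical variable names): for a predicted flow tensor
# of shape (n, 2, h, w),
#   vis = flow_to_image(flow[0].permute(1, 2, 0).cpu().numpy())
# yields a uint8 (h, w, 3) visualisation image.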
def flow_warp(x,
flow,
interpolation='bilinear',
padding_mode='zeros',
align_corners=True):
"""Warp an image or a feature map with optical flow.
Args:
x (Tensor): Tensor with size (n, c, h, w).
flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
            a two-channel map denoting the width and height relative offsets.
Note that the values are not normalized to [-1, 1].
interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
Default: 'bilinear'.
padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
Default: 'zeros'.
align_corners (bool): Whether align corners. Default: True.
Returns:
Tensor: Warped image or feature map.
"""
if x.size()[-2:] != flow.size()[1:3]:
raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
f'flow ({flow.size()[1:3]}) are not the same.')
_, _, h, w = x.size()
# create mesh grid
grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2)
grid.requires_grad = False
grid_flow = grid + flow
# scale grid_flow to [-1,1]
grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
output = F.grid_sample(x,
grid_flow,
mode=interpolation,
padding_mode=padding_mode,
align_corners=align_corners)
return output
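# Example (as used in feat_prop.BidirectionalPropagation): flows are stored as
# (n, 2, h, w), so they are permuted to (n, h, w, 2) before warping, e.g.
#   cond = flow_warp(feat, flow.permute(0, 2, 3, 1))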
def initial_mask_flow(mask):
"""
    mask: 1 indicates a valid pixel, 0 indicates an unknown pixel
"""
B, T, C, H, W = mask.shape
# calculate relative position
grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))
grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])
abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])
# calculate the nearest indices
pos_up = mask.unsqueeze(3).repeat(
1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
relative_pos_y <= H)[None, None, None]
nearest_indice_up = pos_up.max(dim=4)[1]
pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
None, None, None] * (relative_pos_y <= H)[None, None, None]
nearest_indice_down = (pos_down).max(dim=4)[1]
pos_left = mask.unsqueeze(4).repeat(
1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
relative_pos_x <= W)[None, None, None]
nearest_indice_left = (pos_left).max(dim=5)[1]
pos_right = mask.unsqueeze(4).repeat(
1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
relative_pos_x <= W)[None, None, None]
nearest_indice_right = (pos_right).max(dim=5)[1]
# NOTE: IMPORTANT !!! depending on how to use this offset
initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
initial_offset_down = nearest_indice_down - grid_y[None, None, None]
initial_offset_left = -(nearest_indice_left -
grid_x[None, None, None]).flip(4)
initial_offset_right = nearest_indice_right - grid_x[None, None, None]
# nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
# initial_offset_x = nearest_indice_x - grid_x
# handle the boundary cases
final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
initial_offset_down > 0) * initial_offset_down
final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
initial_offset_up < 0) * initial_offset_up
final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
initial_offset_right > 0) * initial_offset_right
final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
initial_offset_left < 0) * initial_offset_left
zero_offset = torch.zeros_like(final_offset_down)
# out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
out = torch.cat([
zero_offset, final_offset_left, zero_offset, final_offset_right,
final_offset_up, zero_offset, final_offset_down, zero_offset
],
dim=2)
return out

288
inpainter/model/modules/spectral_norm.py Normal file

@@ -0,0 +1,288 @@
"""
Spectral Normalization from https://arxiv.org/abs/1802.05957
"""
import torch
from torch.nn.functional import normalize
class SpectralNorm(object):
# Invariant before and after each forward call:
# u = normalize(W @ v)
# NB: At initialization, this invariant is not enforced
_version = 1
# At version 1:
# made `W` not a buffer,
# added `v` as a buffer, and
# made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
self.name = name
self.dim = dim
if n_power_iterations <= 0:
raise ValueError(
'Expected n_power_iterations to be positive, but '
'got n_power_iterations={}'.format(n_power_iterations))
self.n_power_iterations = n_power_iterations
self.eps = eps
def reshape_weight_to_matrix(self, weight):
weight_mat = weight
if self.dim != 0:
# permute dim to front
weight_mat = weight_mat.permute(
self.dim,
*[d for d in range(weight_mat.dim()) if d != self.dim])
height = weight_mat.size(0)
return weight_mat.reshape(height, -1)
def compute_weight(self, module, do_power_iteration):
# NB: If `do_power_iteration` is set, the `u` and `v` vectors are
# updated in power iteration **in-place**. This is very important
# because in `DataParallel` forward, the vectors (being buffers) are
# broadcast from the parallelized module to each module replica,
# which is a new module object created on the fly. And each replica
# runs its own spectral norm power iteration. So simply assigning
# the updated vectors to the module this function runs on will cause
# the update to be lost forever. And the next time the parallelized
# module is replicated, the same randomly initialized vectors are
# broadcast and used!
#
# Therefore, to make the change propagate back, we rely on two
# important behaviors (also enforced via tests):
# 1. `DataParallel` doesn't clone storage if the broadcast tensor
# is already on correct device; and it makes sure that the
# parallelized module is already on `device[0]`.
# 2. If the out tensor in `out=` kwarg has correct shape, it will
# just fill in the values.
# Therefore, since the same power iteration is performed on all
# devices, simply updating the tensors in-place will make sure that
# the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
#
# However, after we update `u` and `v` in-place, we need to **clone**
# them before using them to normalize the weight. This is to support
# backproping through two forward passes, e.g., the common pattern in
# GAN training: loss = D(real) - D(fake). Otherwise, engine will
# complain that variables needed to do backward for the first forward
# (i.e., the `u` and `v` vectors) are changed in the second forward.
weight = getattr(module, self.name + '_orig')
u = getattr(module, self.name + '_u')
v = getattr(module, self.name + '_v')
weight_mat = self.reshape_weight_to_matrix(weight)
if do_power_iteration:
with torch.no_grad():
for _ in range(self.n_power_iterations):
# Spectral norm of weight equals to `u^T W v`, where `u` and `v`
# are the first left and right singular vectors.
# This power iteration produces approximations of `u` and `v`.
v = normalize(torch.mv(weight_mat.t(), u),
dim=0,
eps=self.eps,
out=v)
u = normalize(torch.mv(weight_mat, v),
dim=0,
eps=self.eps,
out=u)
if self.n_power_iterations > 0:
# See above on why we need to clone
u = u.clone()
v = v.clone()
sigma = torch.dot(u, torch.mv(weight_mat, v))
weight = weight / sigma
return weight
def remove(self, module):
with torch.no_grad():
weight = self.compute_weight(module, do_power_iteration=False)
delattr(module, self.name)
delattr(module, self.name + '_u')
delattr(module, self.name + '_v')
delattr(module, self.name + '_orig')
module.register_parameter(self.name,
torch.nn.Parameter(weight.detach()))
def __call__(self, module, inputs):
setattr(
module, self.name,
self.compute_weight(module, do_power_iteration=module.training))
def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
# (the invariant at top of this class) and `u @ W @ v = sigma`.
# This uses pinverse in case W^T W is not invertible.
v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
weight_mat.t(), u.unsqueeze(1)).squeeze(1)
return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
@staticmethod
def apply(module, name, n_power_iterations, dim, eps):
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
raise RuntimeError(
"Cannot register two spectral_norm hooks on "
"the same parameter {}".format(name))
fn = SpectralNorm(name, n_power_iterations, dim, eps)
weight = module._parameters[name]
with torch.no_grad():
weight_mat = fn.reshape_weight_to_matrix(weight)
h, w = weight_mat.size()
# randomly initialize `u` and `v`
u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
delattr(module, fn.name)
module.register_parameter(fn.name + "_orig", weight)
# We still need to assign weight back as fn.name because all sorts of
# things may assume that it exists, e.g., when initializing weights.
# However, we can't directly assign as it could be an nn.Parameter and
# gets added as a parameter. Instead, we register weight.data as a plain
# attribute.
setattr(module, fn.name, weight.data)
module.register_buffer(fn.name + "_u", u)
module.register_buffer(fn.name + "_v", v)
module.register_forward_pre_hook(fn)
module._register_state_dict_hook(SpectralNormStateDictHook(fn))
module._register_load_state_dict_pre_hook(
SpectralNormLoadStateDictPreHook(fn))
return fn
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
# See docstring of SpectralNorm._version on the changes to spectral_norm.
def __init__(self, fn):
self.fn = fn
# For state_dict with version None, (assuming that it has gone through at
# least one training forward), we have
#
# u = normalize(W_orig @ v)
# W = W_orig / sigma, where sigma = u @ W_orig @ v
#
# To compute `v`, we solve `W_orig @ x = u`, and let
# v = x / (u @ W_orig @ x) * (W / W_orig).
def __call__(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
fn = self.fn
version = local_metadata.get('spectral_norm',
{}).get(fn.name + '.version', None)
if version is None or version < 1:
with torch.no_grad():
weight_orig = state_dict[prefix + fn.name + '_orig']
# weight = state_dict.pop(prefix + fn.name)
# sigma = (weight_orig / weight).mean()
weight_mat = fn.reshape_weight_to_matrix(weight_orig)
u = state_dict[prefix + fn.name + '_u']
# v = fn._solve_v_and_rescale(weight_mat, u, sigma)
# state_dict[prefix + fn.name + '_v'] = v
# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
# See docstring of SpectralNorm._version on the changes to spectral_norm.
def __init__(self, fn):
self.fn = fn
def __call__(self, module, state_dict, prefix, local_metadata):
if 'spectral_norm' not in local_metadata:
local_metadata['spectral_norm'] = {}
key = self.fn.name + '.version'
if key in local_metadata['spectral_norm']:
raise RuntimeError(
"Unexpected key in metadata['spectral_norm']: {}".format(key))
local_metadata['spectral_norm'][key] = self.fn._version
def spectral_norm(module,
name='weight',
n_power_iterations=1,
eps=1e-12,
dim=None):
r"""Applies spectral normalization to a parameter in the given module.
.. math::
\mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
\sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
Spectral normalization stabilizes the training of discriminators (critics)
in Generative Adversarial Networks (GANs) by rescaling the weight tensor
with spectral norm :math:`\sigma` of the weight matrix calculated using
power iteration method. If the dimension of the weight tensor is greater
than 2, it is reshaped to 2D in power iteration method to get spectral
norm. This is implemented via a hook that calculates spectral norm and
rescales weight before every :meth:`~Module.forward` call.
See `Spectral Normalization for Generative Adversarial Networks`_ .
.. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
Args:
module (nn.Module): containing module
name (str, optional): name of weight parameter
n_power_iterations (int, optional): number of power iterations to
calculate spectral norm
eps (float, optional): epsilon for numerical stability in
calculating norms
dim (int, optional): dimension corresponding to number of outputs,
the default is ``0``, except for modules that are instances of
ConvTranspose{1,2,3}d, when it is ``1``
Returns:
The original module with the spectral norm hook
Example::
>>> m = spectral_norm(nn.Linear(20, 40))
>>> m
Linear(in_features=20, out_features=40, bias=True)
>>> m.weight_u.size()
torch.Size([40])
"""
if dim is None:
if isinstance(module,
(torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
torch.nn.ConvTranspose3d)):
dim = 1
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
return module
def remove_spectral_norm(module, name='weight'):
r"""Removes the spectral normalization reparameterization from a module.
Args:
module (Module): containing module
name (str, optional): name of weight parameter
Example:
>>> m = spectral_norm(nn.Linear(40, 10))
>>> remove_spectral_norm(m)
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, SpectralNorm) and hook.name == name:
hook.remove(module)
del module._forward_pre_hooks[k]
return module
raise ValueError("spectral_norm of '{}' not found in {}".format(
name, module))
def use_spectral_norm(module, use_sn=False):
if use_sn:
return spectral_norm(module)
return module

536
inpainter/model/modules/tfocal_transformer.py Normal file

@@ -0,0 +1,536 @@
"""
This code is based on:
[1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
https://github.com/ruiliu-ai/FuseFormer
[2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
https://github.com/yitu-opensource/T2T-ViT
[3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
https://github.com/microsoft/Focal-Transformer
"""
import math
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftSplit(nn.Module):
def __init__(self, channel, hidden, kernel_size, stride, padding,
t2t_param):
super(SoftSplit, self).__init__()
self.kernel_size = kernel_size
self.t2t = nn.Unfold(kernel_size=kernel_size,
stride=stride,
padding=padding)
c_in = reduce((lambda x, y: x * y), kernel_size) * channel
self.embedding = nn.Linear(c_in, hidden)
self.f_h = int(
(t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
(t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
1)
self.f_w = int(
(t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
(t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
1)
def forward(self, x, b):
feat = self.t2t(x)
feat = feat.permute(0, 2, 1)
# feat shape [b*t, num_vec, ks*ks*c]
feat = self.embedding(feat)
# feat shape after embedding [b, t*num_vec, hidden]
feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
return feat
class SoftComp(nn.Module):
def __init__(self, channel, hidden, output_size, kernel_size, stride,
padding):
super(SoftComp, self).__init__()
self.relu = nn.LeakyReLU(0.2, inplace=True)
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
self.embedding = nn.Linear(hidden, c_out)
self.t2t = torch.nn.Fold(output_size=output_size,
kernel_size=kernel_size,
stride=stride,
padding=padding)
h, w = output_size
self.bias = nn.Parameter(torch.zeros((channel, h, w),
dtype=torch.float32),
requires_grad=True)
def forward(self, x, t):
b_, _, _, _, c_ = x.shape
x = x.view(b_, -1, c_)
feat = self.embedding(x)
b, _, c = feat.size()
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
feat = self.t2t(feat) + self.bias[None]
return feat
class FusionFeedForward(nn.Module):
def __init__(self, d_model, n_vecs=None, t2t_params=None):
super(FusionFeedForward, self).__init__()
        # We set d_ff to 1960 by default
hd = 1960
self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
assert t2t_params is not None and n_vecs is not None
tp = t2t_params.copy()
self.fold = nn.Fold(**tp)
del tp['output_size']
self.unfold = nn.Unfold(**tp)
self.n_vecs = n_vecs
def forward(self, x):
x = self.conv1(x)
b, n, c = x.size()
normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
49).permute(0, 2, 1)
x = self.unfold(
self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
x = self.conv2(x)
return x
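    # Note: 49 = 7 * 7 is the soft-split kernel area; dividing the folded
    # features by the folded all-ones normalizer averages the contributions of
    # overlapping patches before unfolding back to token space.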
def window_partition(x, window_size):
"""
Args:
x: shape is (B, T, H, W, C)
window_size (tuple[int]): window size
Returns:
windows: (B*num_windows, T*window_size*window_size, C)
"""
B, T, H, W, C = x.shape
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
window_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
-1, T * window_size[0] * window_size[1], C)
return windows
def window_partition_noreshape(x, window_size):
"""
Args:
x: shape is (B, T, H, W, C)
window_size (tuple[int]): window size
Returns:
windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
"""
B, T, H, W, C = x.shape
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
window_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
return windows
def window_reverse(windows, window_size, T, H, W):
"""
Args:
windows: shape is (num_windows*B, T, window_size, window_size, C)
window_size (tuple[int]): Window size
T (int): Temporal length of video
H (int): Height of image
W (int): Width of image
Returns:
x: (B, T, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], T,
window_size[0], window_size[1], -1)
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
return x
class WindowAttention(nn.Module):
"""Temporal focal window attention
"""
def __init__(self, dim, expand_size, window_size, focal_window,
focal_level, num_heads, qkv_bias, pool_method):
super().__init__()
self.dim = dim
self.expand_size = expand_size
self.window_size = window_size # Wh, Ww
self.pool_method = pool_method
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.focal_level = focal_level
self.focal_window = focal_window
if any(i > 0 for i in self.expand_size) and focal_level > 0:
# get mask for rolled k and rolled v
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
mask_br = torch.ones(self.window_size[0], self.window_size[1])
mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
0).flatten(0)
self.register_buffer("valid_ind_rolled",
mask_rolled.nonzero(as_tuple=False).view(-1))
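# each mask is 1 only at positions whose content comes from outside the current
# window after rolling towards tl/tr/bl/br; valid_ind_rolled caches their
# flattened indices so duplicated in-window tokens can be dropped later.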
if pool_method != "none" and focal_level > 1:
self.unfolds = nn.ModuleList()
# build relative position bias between local patch and pooled windows
for k in range(focal_level - 1):
stride = 2**k
kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
for i in self.focal_window)
# define unfolding operations
self.unfolds += [
nn.Unfold(kernel_size=kernel_size,
stride=stride,
padding=tuple(i // 2 for i in kernel_size))
]
# define unfolding index for focal_level > 0
if k > 0:
mask = torch.zeros(kernel_size)
mask[(2**k) - 1:, (2**k) - 1:] = 1
self.register_buffer(
"valid_ind_unfold_{}".format(k),
mask.flatten(0).nonzero(as_tuple=False).view(-1))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x_all, mask_all=None):
"""
Args:
x: input features with shape of (B, T, Wh, Ww, C)
mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
output: (nW*B, T*Wh*Ww, C)
"""
x = x_all[0]
B, T, nH, nW, C = x.shape
qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
C).permute(4, 0, 1, 2, 3, 5).contiguous()
q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
# partition q map
(q_windows, k_windows, v_windows) = map(
lambda t: window_partition(t, self.window_size).view(
-1, T, self.window_size[0] * self.window_size[1], self.
num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
contiguous().view(-1, self.num_heads, T * self.window_size[
0] * self.window_size[1], C // self.num_heads), (q, k, v))
# q(k/v)_windows shape : [16, 4, 225, 128]
if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
(k_tl, v_tl) = map(
lambda t: torch.roll(t,
shifts=(-self.expand_size[0], -self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_tr, v_tr) = map(
lambda t: torch.roll(t,
shifts=(-self.expand_size[0], self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_bl, v_bl) = map(
lambda t: torch.roll(t,
shifts=(self.expand_size[0], -self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_br, v_br) = map(
lambda t: torch.roll(t,
shifts=(self.expand_size[0], self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
lambda t: window_partition(t, self.window_size).view(
-1, T, self.window_size[0] * self.window_size[1], self.
num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
lambda t: window_partition(t, self.window_size).view(
-1, T, self.window_size[0] * self.window_size[1], self.
num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
k_rolled = torch.cat(
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
2).permute(0, 3, 1, 2, 4).contiguous()
v_rolled = torch.cat(
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
2).permute(0, 3, 1, 2, 4).contiguous()
# mask out tokens in current window
k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
temp_N = k_rolled.shape[3]
k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
C // self.num_heads)
v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
C // self.num_heads)
k_rolled = torch.cat((k_windows, k_rolled), 2)
v_rolled = torch.cat((v_windows, v_rolled), 2)
else:
k_rolled = k_windows
v_rolled = v_windows
# q(k/v)_windows shape : [16, 4, 225, 128]
# k_rolled.shape : [16, 4, 5, 165, 128]
# ideal expanded window size 153 ((5+2*2)*(9+2*4))
# k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
if self.pool_method != "none" and self.focal_level > 1:
k_pooled = []
v_pooled = []
for k in range(self.focal_level - 1):
stride = 2**k
x_window_pooled = x_all[k + 1].permute(
0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C
nWh, nWw = x_window_pooled.shape[2:4]
# generate mask for pooled windows
mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
# unfold mask: [nWh*nWw//s//s, k*k, 1]
unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
view(nWh*nWw // stride // stride, -1, 1)
if k > 0:
valid_ind_unfold_k = getattr(
self, "valid_ind_unfold_{}".format(k))
unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
x_window_masks = x_window_masks.masked_fill(
x_window_masks == 0,
float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
mask_all[k + 1] = x_window_masks
# generate k and v for pooled windows
qkv_pooled = self.qkv(x_window_pooled).reshape(
B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
3).view(3, -1, C, nWh,
nWw).contiguous()
k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
2] # B*T, C, nWh, nWw
# k_pooled_k shape: [5, 512, 4, 4]
# self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
(k_pooled_k, v_pooled_k) = map(
lambda t: self.unfolds[k](t).view(
B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
(k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
)
# k_pooled_k shape : [16, 4, 5, 45, 128]
# select valid unfolding index
if k > 0:
(k_pooled_k, v_pooled_k) = map(
lambda t: t[:, :, :, valid_ind_unfold_k],
(k_pooled_k, v_pooled_k))
k_pooled_k = k_pooled_k.view(
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
self.unfolds[k].kernel_size[1], C // self.num_heads)
v_pooled_k = v_pooled_k.view(
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
self.unfolds[k].kernel_size[1], C // self.num_heads)
k_pooled += [k_pooled_k]
v_pooled += [v_pooled_k]
# k_all (v_all) shape : [16, 4, 5 * 210, 128]
k_all = torch.cat([k_rolled] + k_pooled, 2)
v_all = torch.cat([v_rolled] + v_pooled, 2)
else:
k_all = k_rolled
v_all = v_rolled
N = k_all.shape[-2]
q_windows = q_windows * self.scale
attn = (
q_windows @ k_all.transpose(-2, -1)
) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
# T * 45
window_area = T * self.window_size[0] * self.window_size[1]
# T * 165
window_area_rolled = k_rolled.shape[2]
if self.pool_method != "none" and self.focal_level > 1:
offset = window_area_rolled
for k in range(self.focal_level - 1):
# add attentional mask
# mask_all[1] shape [1, 16, T * 45]
bias = tuple((i + 2**k - 1) for i in self.focal_window)
if mask_all[k + 1] is not None:
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
offset += T * bias[0] * bias[1]
if mask_all[0] is not None:
nW = mask_all[0].shape[0]
attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
window_area, N)
attn[:, :, :, :, :
window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
None, :, None, :, :]
attn = attn.view(-1, self.num_heads, window_area, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
C)
x = self.proj(x)
return x
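# In short: window-local queries attend over keys/values gathered from
# (1) the local window, (2) rolled/expanded neighbouring tokens, and
# (3) unfolded pooled windows from each coarser focal level.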
class TemporalFocalTransformerBlock(nn.Module):
r""" Temporal Focal Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
focal_level (int): Number of focal levels.
focal_window (int): Window size of each focal window.
n_vecs (int): Number of soft-split token vectors, required for F3N.
t2t_params (dict): T2T (soft split) parameters, required for F3N.
"""
def __init__(self,
dim,
num_heads,
window_size=(5, 9),
mlp_ratio=4.,
qkv_bias=True,
pool_method="fc",
focal_level=2,
focal_window=(5, 9),
norm_layer=nn.LayerNorm,
n_vecs=None,
t2t_params=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.expand_size = tuple(i // 2 for i in window_size) # TODO
self.mlp_ratio = mlp_ratio
self.pool_method = pool_method
self.focal_level = focal_level
self.focal_window = focal_window
self.window_size_glo = self.window_size
self.pool_layers = nn.ModuleList()
if self.pool_method != "none":
for k in range(self.focal_level - 1):
window_size_glo = tuple(
math.floor(i / (2**k)) for i in self.window_size_glo)
self.pool_layers.append(
nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
self.pool_layers[-1].weight.data.fill_(
1. / (window_size_glo[0] * window_size_glo[1]))
self.pool_layers[-1].bias.data.fill_(0)
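# each pool layer is a Linear over the window area initialised to uniform
# averaging (weights 1/area, bias 0), i.e. it starts out as average pooling.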
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(dim,
expand_size=self.expand_size,
window_size=self.window_size,
focal_window=focal_window,
focal_level=focal_level,
num_heads=num_heads,
qkv_bias=qkv_bias,
pool_method=pool_method)
self.norm2 = norm_layer(dim)
self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
def forward(self, x):
B, T, H, W, C = x.shape
shortcut = x
x = self.norm1(x)
shifted_x = x
x_windows_all = [shifted_x]
x_window_masks_all = [None]
# partition windows (plus pooled windows for the coarser focal levels)
if self.focal_level > 1 and self.pool_method != "none":
# if we add coarser granularity and the pool method is not none
for k in range(self.focal_level - 1):
window_size_glo = tuple(
math.floor(i / (2**k)) for i in self.window_size_glo)
pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
H_pool = pooled_h * window_size_glo[0]
W_pool = pooled_w * window_size_glo[1]
x_level_k = shifted_x
# trim or pad shifted_x depending on the required size
if H > H_pool:
trim_t = (H - H_pool) // 2
trim_b = H - H_pool - trim_t
x_level_k = x_level_k[:, :, trim_t:-trim_b]
elif H < H_pool:
pad_t = (H_pool - H) // 2
pad_b = H_pool - H - pad_t
x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
if W > W_pool:
trim_l = (W - W_pool) // 2
trim_r = W - W_pool - trim_l
x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
elif W < W_pool:
pad_l = (W_pool - W) // 2
pad_r = W_pool - W - pad_l
x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
x_windows_noreshape = window_partition_noreshape(
x_level_k.contiguous(), window_size_glo
) # B, nw, nw, T, window_size, window_size, C
nWh, nWw = x_windows_noreshape.shape[1:3]
x_windows_noreshape = x_windows_noreshape.view(
B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
x_windows_pooled = self.pool_layers[k](
x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
x_windows_all += [x_windows_pooled]
x_window_masks_all += [None]
attn_windows = self.attn(
x_windows_all,
mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C
# merge windows
attn_windows = attn_windows.view(-1, T, self.window_size[0],
self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, T, H,
W) # B T H' W' C
# FFN
x = shortcut + shifted_x
y = self.norm2(x)
x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)
return x

View File

@@ -0,0 +1,565 @@
"""
This code is based on:
[1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
https://github.com/ruiliu-ai/FuseFormer
[2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
https://github.com/yitu-opensource/T2T-ViT
[3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
https://github.com/microsoft/Focal-Transformer
"""
import math
from functools import reduce
import torch
import torch.nn as nn
import torch.nn.functional as F
class SoftSplit(nn.Module):
def __init__(self, channel, hidden, kernel_size, stride, padding,
t2t_param):
super(SoftSplit, self).__init__()
self.kernel_size = kernel_size
self.t2t = nn.Unfold(kernel_size=kernel_size,
stride=stride,
padding=padding)
c_in = reduce((lambda x, y: x * y), kernel_size) * channel
self.embedding = nn.Linear(c_in, hidden)
self.t2t_param = t2t_param
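# Unlike the fixed-resolution SoftSplit above, this variant keeps the t2t
# parameters and derives f_h / f_w from output_size at forward time, which is
# what enables the hq model's multi-resolution support.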
def forward(self, x, b, output_size):
f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
(self.t2t_param['kernel_size'][0] - 1) - 1) /
self.t2t_param['stride'][0] + 1)
f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
(self.t2t_param['kernel_size'][1] - 1) - 1) /
self.t2t_param['stride'][1] + 1)
feat = self.t2t(x)
feat = feat.permute(0, 2, 1)
# feat shape [b*t, num_vec, ks*ks*c]
feat = self.embedding(feat)
# feat shape after embedding: [b*t, num_vec, hidden] (reshaped below to [b, t, f_h, f_w, hidden])
feat = feat.view(b, -1, f_h, f_w, feat.size(2))
return feat
class SoftComp(nn.Module):
def __init__(self, channel, hidden, kernel_size, stride, padding):
super(SoftComp, self).__init__()
self.relu = nn.LeakyReLU(0.2, inplace=True)
c_out = reduce((lambda x, y: x * y), kernel_size) * channel
self.embedding = nn.Linear(hidden, c_out)
self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.bias_conv = nn.Conv2d(channel,
channel,
kernel_size=3,
stride=1,
padding=1)
# TODO upsample conv
# self.bias_conv = nn.Conv2d()
# self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)
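# here F.fold is called with the runtime output_size instead of a fixed
# nn.Fold, and the learned full-resolution bias map is replaced by a 3x3
# bias_conv, so the composition works at arbitrary resolutions.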
def forward(self, x, t, output_size):
b_, _, _, _, c_ = x.shape
x = x.view(b_, -1, c_)
feat = self.embedding(x)
b, _, c = feat.size()
feat = feat.view(b * t, -1, c).permute(0, 2, 1)
feat = F.fold(feat,
output_size=output_size,
kernel_size=self.kernel_size,
stride=self.stride,
padding=self.padding)
feat = self.bias_conv(feat)
return feat
class FusionFeedForward(nn.Module):
def __init__(self, d_model, n_vecs=None, t2t_params=None):
super(FusionFeedForward, self).__init__()
# d_ff (hidden width of the feed-forward) defaults to 1960
hd = 1960
self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
assert t2t_params is not None and n_vecs is not None
self.t2t_params = t2t_params
def forward(self, x, output_size):
n_vecs = 1
for i, d in enumerate(self.t2t_params['kernel_size']):
n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
(d - 1) - 1) / self.t2t_params['stride'][i] + 1)
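# n_vecs (number of soft-split tokens per frame) is recomputed from
# output_size with the same unfold output-size formula, rather than being
# fixed at construction time.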
x = self.conv1(x)
b, n, c = x.size()
normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
normalizer = F.fold(normalizer,
output_size=output_size,
kernel_size=self.t2t_params['kernel_size'],
padding=self.t2t_params['padding'],
stride=self.t2t_params['stride'])
x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
output_size=output_size,
kernel_size=self.t2t_params['kernel_size'],
padding=self.t2t_params['padding'],
stride=self.t2t_params['stride'])
x = F.unfold(x / normalizer,
kernel_size=self.t2t_params['kernel_size'],
padding=self.t2t_params['padding'],
stride=self.t2t_params['stride']).permute(
0, 2, 1).contiguous().view(b, n, c)
x = self.conv2(x)
return x
def window_partition(x, window_size):
"""
Args:
x: shape is (B, T, H, W, C)
window_size (tuple[int]): window size
Returns:
windows: (B*num_windows, T*window_size*window_size, C)
"""
B, T, H, W, C = x.shape
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
window_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
-1, T * window_size[0] * window_size[1], C)
return windows
def window_partition_noreshape(x, window_size):
"""
Args:
x: shape is (B, T, H, W, C)
window_size (tuple[int]): window size
Returns:
windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
"""
B, T, H, W, C = x.shape
x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
window_size[1], C)
windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
return windows
def window_reverse(windows, window_size, T, H, W):
"""
Args:
windows: shape is (num_windows*B, T, window_size, window_size, C)
window_size (tuple[int]): Window size
T (int): Temporal length of video
H (int): Height of image
W (int): Width of image
Returns:
x: (B, T, H, W, C)
"""
B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
x = windows.view(B, H // window_size[0], W // window_size[1], T,
window_size[0], window_size[1], -1)
x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
return x
class WindowAttention(nn.Module):
"""Temporal focal window attention
"""
def __init__(self, dim, expand_size, window_size, focal_window,
focal_level, num_heads, qkv_bias, pool_method):
super().__init__()
self.dim = dim
self.expand_size = expand_size
self.window_size = window_size # Wh, Ww
self.pool_method = pool_method
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.focal_level = focal_level
self.focal_window = focal_window
if any(i > 0 for i in self.expand_size) and focal_level > 0:
# get mask for rolled k and rolled v
mask_tl = torch.ones(self.window_size[0], self.window_size[1])
mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
mask_tr = torch.ones(self.window_size[0], self.window_size[1])
mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
mask_bl = torch.ones(self.window_size[0], self.window_size[1])
mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
mask_br = torch.ones(self.window_size[0], self.window_size[1])
mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
0).flatten(0)
self.register_buffer("valid_ind_rolled",
mask_rolled.nonzero(as_tuple=False).view(-1))
if pool_method != "none" and focal_level > 1:
self.unfolds = nn.ModuleList()
# build relative position bias between local patch and pooled windows
for k in range(focal_level - 1):
stride = 2**k
kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
for i in self.focal_window)
# define unfolding operations
self.unfolds += [
nn.Unfold(kernel_size=kernel_size,
stride=stride,
padding=tuple(i // 2 for i in kernel_size))
]
# define unfolding index for focal_level > 0
if k > 0:
mask = torch.zeros(kernel_size)
mask[(2**k) - 1:, (2**k) - 1:] = 1
self.register_buffer(
"valid_ind_unfold_{}".format(k),
mask.flatten(0).nonzero(as_tuple=False).view(-1))
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.proj = nn.Linear(dim, dim)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x_all, mask_all=None):
"""
Args:
x: input features with shape of (B, T, Wh, Ww, C)
mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None
output: (nW*B, T*Wh*Ww, C)
"""
x = x_all[0]
B, T, nH, nW, C = x.shape
qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
C).permute(4, 0, 1, 2, 3, 5).contiguous()
q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C
# partition q map
(q_windows, k_windows, v_windows) = map(
lambda t: window_partition(t, self.window_size).view(
-1, T, self.window_size[0] * self.window_size[1], self.
num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
contiguous().view(-1, self.num_heads, T * self.window_size[
0] * self.window_size[1], C // self.num_heads), (q, k, v))
# q(k/v)_windows shape : [16, 4, 225, 128]
if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
(k_tl, v_tl) = map(
lambda t: torch.roll(t,
shifts=(-self.expand_size[0], -self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_tr, v_tr) = map(
lambda t: torch.roll(t,
shifts=(-self.expand_size[0], self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_bl, v_bl) = map(
lambda t: torch.roll(t,
shifts=(self.expand_size[0], -self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_br, v_br) = map(
lambda t: torch.roll(t,
shifts=(self.expand_size[0], self.
expand_size[1]),
dims=(2, 3)), (k, v))
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
lambda t: window_partition(t, self.window_size).view(
-1, T, self.window_size[0] * self.window_size[1], self.
num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
lambda t: window_partition(t, self.window_size).view(
-1, T, self.window_size[0] * self.window_size[1], self.
num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
k_rolled = torch.cat(
(k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
2).permute(0, 3, 1, 2, 4).contiguous()
v_rolled = torch.cat(
(v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
2).permute(0, 3, 1, 2, 4).contiguous()
# mask out tokens in current window
k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
temp_N = k_rolled.shape[3]
k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
C // self.num_heads)
v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
C // self.num_heads)
k_rolled = torch.cat((k_windows, k_rolled), 2)
v_rolled = torch.cat((v_windows, v_rolled), 2)
else:
k_rolled = k_windows
v_rolled = v_windows
# q(k/v)_windows shape : [16, 4, 225, 128]
# k_rolled.shape : [16, 4, 5, 165, 128]
# ideal expanded window size 153 ((5+2*2)*(9+2*4))
# k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)
if self.pool_method != "none" and self.focal_level > 1:
k_pooled = []
v_pooled = []
for k in range(self.focal_level - 1):
stride = 2**k
# B, T, nWh, nWw, C
x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
4).contiguous()
nWh, nWw = x_window_pooled.shape[2:4]
# generate mask for pooled windows
mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
# unfold mask: [nWh*nWw//s//s, k*k, 1]
unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
view(nWh*nWw // stride // stride, -1, 1)
if k > 0:
valid_ind_unfold_k = getattr(
self, "valid_ind_unfold_{}".format(k))
unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
x_window_masks = x_window_masks.masked_fill(
x_window_masks == 0,
float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
mask_all[k + 1] = x_window_masks
# generate k and v for pooled windows
qkv_pooled = self.qkv(x_window_pooled).reshape(
B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
3).view(3, -1, C, nWh,
nWw).contiguous()
# B*T, C, nWh, nWw
k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
# k_pooled_k shape: [5, 512, 4, 4]
# self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16]
(k_pooled_k, v_pooled_k) = map(
lambda t: self.unfolds[k]
(t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
unfolds[k].kernel_size[1], -1)
.permute(0, 5, 1, 3, 4, 2).contiguous().view(
-1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
k].kernel_size[1], self.num_heads, C // self.
num_heads).permute(0, 3, 1, 2, 4).contiguous(),
# (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
(k_pooled_k, v_pooled_k))
# k_pooled_k shape : [16, 4, 5, 45, 128]
# select valid unfolding index
if k > 0:
(k_pooled_k, v_pooled_k) = map(
lambda t: t[:, :, :, valid_ind_unfold_k],
(k_pooled_k, v_pooled_k))
k_pooled_k = k_pooled_k.view(
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
self.unfolds[k].kernel_size[1], C // self.num_heads)
v_pooled_k = v_pooled_k.view(
-1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
self.unfolds[k].kernel_size[1], C // self.num_heads)
k_pooled += [k_pooled_k]
v_pooled += [v_pooled_k]
# k_all (v_all) shape : [16, 4, 5 * 210, 128]
k_all = torch.cat([k_rolled] + k_pooled, 2)
v_all = torch.cat([v_rolled] + v_pooled, 2)
else:
k_all = k_rolled
v_all = v_rolled
N = k_all.shape[-2]
q_windows = q_windows * self.scale
# B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
attn = (q_windows @ k_all.transpose(-2, -1))
# T * 45
window_area = T * self.window_size[0] * self.window_size[1]
# T * 165
window_area_rolled = k_rolled.shape[2]
if self.pool_method != "none" and self.focal_level > 1:
offset = window_area_rolled
for k in range(self.focal_level - 1):
# add attentional mask
# mask_all[1] shape [1, 16, T * 45]
bias = tuple((i + 2**k - 1) for i in self.focal_window)
if mask_all[k + 1] is not None:
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
mask_all[k+1][:, :, None, None, :].repeat(
attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
offset += T * bias[0] * bias[1]
if mask_all[0] is not None:
nW = mask_all[0].shape[0]
attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
window_area, N)
attn[:, :, :, :, :
window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
None, :, None, :, :]
attn = attn.view(-1, self.num_heads, window_area, N)
attn = self.softmax(attn)
else:
attn = self.softmax(attn)
x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
C)
x = self.proj(x)
return x
class TemporalFocalTransformerBlock(nn.Module):
r""" Temporal Focal Transformer Block.
Args:
dim (int): Number of input channels.
num_heads (int): Number of attention heads.
window_size (tuple[int]): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
focal_level (int): Number of focal levels.
focal_window (int): Window size of each focal window.
n_vecs (int): Number of soft-split token vectors, required for F3N.
t2t_params (dict): T2T (soft split) parameters, required for F3N.
"""
def __init__(self,
dim,
num_heads,
window_size=(5, 9),
mlp_ratio=4.,
qkv_bias=True,
pool_method="fc",
focal_level=2,
focal_window=(5, 9),
norm_layer=nn.LayerNorm,
n_vecs=None,
t2t_params=None):
super().__init__()
self.dim = dim
self.num_heads = num_heads
self.window_size = window_size
self.expand_size = tuple(i // 2 for i in window_size) # TODO
self.mlp_ratio = mlp_ratio
self.pool_method = pool_method
self.focal_level = focal_level
self.focal_window = focal_window
self.window_size_glo = self.window_size
self.pool_layers = nn.ModuleList()
if self.pool_method != "none":
for k in range(self.focal_level - 1):
window_size_glo = tuple(
math.floor(i / (2**k)) for i in self.window_size_glo)
self.pool_layers.append(
nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
self.pool_layers[-1].weight.data.fill_(
1. / (window_size_glo[0] * window_size_glo[1]))
self.pool_layers[-1].bias.data.fill_(0)
self.norm1 = norm_layer(dim)
self.attn = WindowAttention(dim,
expand_size=self.expand_size,
window_size=self.window_size,
focal_window=focal_window,
focal_level=focal_level,
num_heads=num_heads,
qkv_bias=qkv_bias,
pool_method=pool_method)
self.norm2 = norm_layer(dim)
self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)
def forward(self, x):
output_size = x[1]
x = x[0]
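# x arrives packed as (features, output_size) and is returned the same way,
# presumably so that blocks can be chained in nn.Sequential while still
# propagating the runtime resolution.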
B, T, H, W, C = x.shape
shortcut = x
x = self.norm1(x)
shifted_x = x
x_windows_all = [shifted_x]
x_window_masks_all = [None]
# partition windows (plus pooled windows for the coarser focal levels)
if self.focal_level > 1 and self.pool_method != "none":
# if we add coarser granularity and the pool method is not none
for k in range(self.focal_level - 1):
window_size_glo = tuple(
math.floor(i / (2**k)) for i in self.window_size_glo)
pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
H_pool = pooled_h * window_size_glo[0]
W_pool = pooled_w * window_size_glo[1]
x_level_k = shifted_x
# trim or pad shifted_x depending on the required size
if H > H_pool:
trim_t = (H - H_pool) // 2
trim_b = H - H_pool - trim_t
x_level_k = x_level_k[:, :, trim_t:-trim_b]
elif H < H_pool:
pad_t = (H_pool - H) // 2
pad_b = H_pool - H - pad_t
x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
if W > W_pool:
trim_l = (W - W_pool) // 2
trim_r = W - W_pool - trim_l
x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
elif W < W_pool:
pad_l = (W_pool - W) // 2
pad_r = W_pool - W - pad_l
x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
x_windows_noreshape = window_partition_noreshape(
x_level_k.contiguous(), window_size_glo
) # B, nw, nw, T, window_size, window_size, C
nWh, nWw = x_windows_noreshape.shape[1:3]
x_windows_noreshape = x_windows_noreshape.view(
B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2
x_windows_pooled = self.pool_layers[k](
x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C
x_windows_all += [x_windows_pooled]
x_window_masks_all += [None]
# nW*B, T*window_size*window_size, C
attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)
# merge windows
attn_windows = attn_windows.view(-1, T, self.window_size[0],
self.window_size[1], C)
shifted_x = window_reverse(attn_windows, self.window_size, T, H,
W) # B T H' W' C
# FFN
x = shortcut + shifted_x
y = self.norm2(x)
x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
B, T, H, W, C)
return x, output_size

View File

@@ -0,0 +1,24 @@
import cv2
import numpy as np
# resize frames
def resize_frames(frames, size=None):
"""
size: (w, h)
"""
if size is not None:
frames = [cv2.resize(f, size) for f in frames]
frames = np.stack(frames, 0)
return frames
# resize masks
def resize_masks(masks, size=None):
"""
size: (w, h)
"""
if size is not None:
masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
masks = np.stack(masks, 0)
return masks
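# --- Usage sketch (illustrative, not part of the original module) ---
# Both helpers take OpenCV-style (w, h) sizes and return stacked numpy arrays.
if __name__ == "__main__":
    dummy_frames = np.random.randint(0, 255, (4, 240, 432, 3), dtype=np.uint8)
    dummy_masks = np.random.randint(0, 2, (4, 240, 432, 1), dtype=np.uint8)
    print(resize_frames(dummy_frames, size=(216, 120)).shape)  # (4, 120, 216, 3)
    print(resize_masks(dummy_masks, size=(216, 120)).shape)    # (4, 120, 216, 1)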