Mirror of https://github.com/gaomingqi/Track-Anything.git (synced 2025-12-15 16:07:51 +01:00)
add inpainting
160  inpainter/base_inpainter.py  Normal file
@@ -0,0 +1,160 @@
import os
import glob
from PIL import Image

import torch
import yaml
import cv2
import importlib
import numpy as np
from util.tensor_util import resize_frames, resize_masks


class BaseInpainter:
    def __init__(self, E2FGVI_checkpoint, device) -> None:
        """
        E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support)
        """
        net = importlib.import_module('model.e2fgvi_hq')
        self.model = net.InpaintGenerator().to(device)
        self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device))
        self.model.eval()
        self.device = device
        # load configurations
        with open("inpainter/config/config.yaml", 'r') as stream:
            config = yaml.safe_load(stream)
        self.neighbor_stride = config['neighbor_stride']
        self.num_ref = config['num_ref']
        self.step = config['step']

    # sample reference frames from the whole video
    def get_ref_index(self, f, neighbor_ids, length):
        ref_index = []
        if self.num_ref == -1:
            for i in range(0, length, self.step):
                if i not in neighbor_ids:
                    ref_index.append(i)
        else:
            start_idx = max(0, f - self.step * (self.num_ref // 2))
            end_idx = min(length, f + self.step * (self.num_ref // 2))
            for i in range(start_idx, end_idx + 1, self.step):
                if i not in neighbor_ids:
                    if len(ref_index) > self.num_ref:
                        break
                    ref_index.append(i)
        return ref_index

    def inpaint(self, frames, masks, dilate_radius=15, ratio=1):
        """
        frames: numpy array, T, H, W, 3
        masks: numpy array, T, H, W
        dilate_radius: radius when applying dilation on masks
        ratio: down-sample ratio

        Output:
        inpainted_frames: numpy array, T, H, W, 3
        """
        assert frames.shape[:3] == masks.shape, 'different size between frames and masks'
        assert ratio > 0 and ratio <= 1, 'ratio must be in (0, 1]'
        masks = masks.copy()
        masks = np.clip(masks, 0, 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (dilate_radius, dilate_radius))
        masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0)

        T, H, W = masks.shape
        # size: (w, h)
        if ratio == 1:
            size = None
        else:
            size = (int(W * ratio), int(H * ratio))

        masks = np.expand_dims(masks, axis=3)  # expand to T, H, W, 1
        binary_masks = resize_masks(masks, size)
        frames = resize_frames(frames, size)  # T, H, W, 3
        # frames and binary_masks are numpy arrays

        h, w = frames.shape[1:3]
        video_length = T

        # convert to tensor
        imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1
        masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0)

        imgs, masks = imgs.to(self.device), masks.to(self.device)
        comp_frames = [None] * video_length

        for f in range(0, video_length, self.neighbor_stride):
            neighbor_ids = [
                i for i in range(max(0, f - self.neighbor_stride),
                                 min(video_length, f + self.neighbor_stride + 1))
            ]
            ref_ids = self.get_ref_index(f, neighbor_ids, video_length)
            selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :]
            selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :]
            with torch.no_grad():
                masked_imgs = selected_imgs * (1 - selected_masks)
                mod_size_h = 60
                mod_size_w = 108
                h_pad = (mod_size_h - h % mod_size_h) % mod_size_h
                w_pad = (mod_size_w - w % mod_size_w) % mod_size_w
                masked_imgs = torch.cat(
                    [masked_imgs, torch.flip(masked_imgs, [3])],
                    3)[:, :, :, :h + h_pad, :]
                masked_imgs = torch.cat(
                    [masked_imgs, torch.flip(masked_imgs, [4])],
                    4)[:, :, :, :, :w + w_pad]
                pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids))
                pred_imgs = pred_imgs[:, :, :h, :w]
                pred_imgs = (pred_imgs + 1) / 2
                pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255
                for i in range(len(neighbor_ids)):
                    idx = neighbor_ids[i]
                    img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * (
                        1 - binary_masks[idx])
                    if comp_frames[idx] is None:
                        comp_frames[idx] = img
                    else:
                        comp_frames[idx] = comp_frames[idx].astype(
                            np.float32) * 0.5 + img.astype(np.float32) * 0.5

        inpainted_frames = np.stack(comp_frames, 0)
        return inpainted_frames.astype(np.uint8)


if __name__ == '__main__':

    frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg'))
    frame_path.sort()
    mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png"))
    mask_path.sort()
    save_path = '/ssd1/gaomingqi/results/inpainting/parkour'

    if not os.path.exists(save_path):
        os.mkdir(save_path)

    frames = []
    masks = []
    for fid, mid in zip(frame_path, mask_path):
        frames.append(Image.open(fid).convert('RGB'))
        masks.append(Image.open(mid).convert('P'))

    frames = np.stack(frames, 0)
    masks = np.stack(masks, 0)

    # ----------------------------------------------
    # how to use
    # ----------------------------------------------
    # 1/3: set checkpoint and device
    checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth'
    device = 'cuda:2'
    # 2/3: initialise inpainter
    base_inpainter = BaseInpainter(checkpoint, device)
    # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W)
    # ratio: (0, 1], ratio for down sample, default value is 1
    inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.5)  # numpy array, T, H, W, 3
    # ----------------------------------------------
    # end
    # ----------------------------------------------
    # save
    for ti, inpainted_frame in enumerate(inpainted_frames):
        frame = Image.fromarray(inpainted_frame).convert('RGB')
        frame.save(os.path.join(save_path, f'{ti:05d}.jpg'))
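E2FGVI-HQ only accepts spatial sizes that are multiples of (60, 108): inpaint above mirror-pads each window up to the next multiple before the forward pass and crops the prediction back afterwards. A standalone sketch of that padding arithmetic (dummy shapes, illustration only, not part of the commit):

import torch

mod_size_h, mod_size_w = 60, 108            # the HQ model works on multiples of these
x = torch.randn(1, 5, 3, 250, 440)          # dummy clip: B, T, C, H, W
h, w = x.shape[-2:]
h_pad = (mod_size_h - h % mod_size_h) % mod_size_h   # 50: pads 250 up to 300
w_pad = (mod_size_w - w % mod_size_w) % mod_size_w   # 100: pads 440 up to 540
# mirror padding: append a flipped copy along each axis, then crop to the target
x = torch.cat([x, torch.flip(x, [3])], 3)[:, :, :, :h + h_pad, :]
x = torch.cat([x, torch.flip(x, [4])], 4)[:, :, :, :, :w + w_pad]
assert x.shape[-2:] == (300, 540)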
4  inpainter/config/config.yaml  Normal file
@@ -0,0 +1,4 @@
# config info for E2FGVI
neighbor_stride: 5
num_ref: -1
step: 10
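With these defaults, every inference window spans neighbor_stride frames on each side of the current position, and num_ref: -1 means every step-th frame outside that window is added as a non-local reference. A small standalone trace of the sampling rule from get_ref_index (hypothetical 50-frame clip, not part of the commit):

neighbor_stride, num_ref, step = 5, -1, 10
length, f = 50, 20                      # 50-frame clip, window centred at frame 20
neighbor_ids = list(range(max(0, f - neighbor_stride),
                          min(length, f + neighbor_stride + 1)))
# num_ref == -1: every step-th frame that is not a local neighbor becomes a reference
ref_ids = [i for i in range(0, length, step) if i not in neighbor_ids]
print(neighbor_ids)   # [15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]
print(ref_ids)        # [0, 10, 30, 40] -- frame 20 is excluded, it is a neighbor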
350  inpainter/model/e2fgvi.py  Normal file
@@ -0,0 +1,350 @@
''' Towards An End-to-End Framework for Video Inpainting
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.modules.flow_comp import SPyNet
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp
from model.modules.spectral_norm import spectral_norm as _spectral_norm


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            'Network [%s] was created. Total number of parameters: %.1f million. '
            'To see the architecture, do print(network).' %
            (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1
                                           or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented' %
                        init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.group = [1, 2, 4, 8, 1]
        self.layers = nn.ModuleList([
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])

    def forward(self, x):
        bt, c, h, w = x.size()
        h, w = h // 4, w // 4
        out = x
        for i, layer in enumerate(self.layers):
            if i == 8:
                x0 = out
            if i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                x = x0.view(bt, g, -1, h, w)
                o = out.view(bt, g, -1, h, w)
                out = torch.cat([x, o], 2).view(bt, -1, h, w)
            out = layer(out)
        return out


class deconv(nn.Module):
    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(input_channel,
                              output_channel,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding)

    def forward(self, x):
        x = F.interpolate(x,
                          scale_factor=2,
                          mode='bilinear',
                          align_corners=True)
        return self.conv(x)


class InpaintGenerator(BaseNetwork):
    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        hidden = 512

        # encoder
        self.encoder = Encoder()

        # decoder
        self.decoder = nn.Sequential(
            deconv(channel // 2, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))

        # feature propagation module
        self.feat_prop_module = BidirectionalPropagation(channel // 2)

        # soft split and soft composition
        kernel_size = (7, 7)
        padding = (3, 3)
        stride = (3, 3)
        output_size = (60, 108)
        t2t_params = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding,
            'output_size': output_size
        }
        self.ss = SoftSplit(channel // 2,
                            hidden,
                            kernel_size,
                            stride,
                            padding,
                            t2t_param=t2t_params)
        self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size,
                           stride, padding)

        n_vecs = 1
        for i, d in enumerate(kernel_size):
            n_vecs *= int((output_size[i] + 2 * padding[i] -
                           (d - 1) - 1) / stride[i] + 1)

        blocks = []
        depths = 8
        num_heads = [4] * depths
        window_size = [(5, 9)] * depths
        focal_windows = [(5, 9)] * depths
        focal_levels = [2] * depths
        pool_method = "fc"

        for i in range(depths):
            blocks.append(
                TemporalFocalTransformerBlock(dim=hidden,
                                              num_heads=num_heads[i],
                                              window_size=window_size[i],
                                              focal_level=focal_levels[i],
                                              focal_window=focal_windows[i],
                                              n_vecs=n_vecs,
                                              t2t_params=t2t_params,
                                              pool_method=pool_method))
        self.transformer = nn.Sequential(*blocks)

        if init_weights:
            self.init_weights()
            # Need to initialize the weights of MSDeformAttn specifically
            for m in self.modules():
                if isinstance(m, SecondOrderDeformableAlignment):
                    m.init_offset()

        # flow completion network
        self.update_spynet = SPyNet()

    def forward_bidirect_flow(self, masked_local_frames):
        b, l_t, c, h, w = masked_local_frames.size()

        # compute forward and backward flows of masked frames
        masked_local_frames = F.interpolate(masked_local_frames.view(
            -1, c, h, w),
                                            scale_factor=1 / 4,
                                            mode='bilinear',
                                            align_corners=True,
                                            recompute_scale_factor=True)
        masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
                                                       w // 4)
        mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
        pred_flows_backward = self.update_spynet(mlf_2, mlf_1)

        pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
                                                     w // 4)
        pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
                                                       w // 4)

        return pred_flows_forward, pred_flows_backward

    def forward(self, masked_frames, num_local_frames):
        l_t = num_local_frames
        b, t, ori_c, ori_h, ori_w = masked_frames.size()

        # normalization before feeding into the flow completion module
        masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
        pred_flows = self.forward_bidirect_flow(masked_local_frames)

        # extracting features and performing the feature propagation on local features
        enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
        _, c, h, w = enc_feat.size()
        local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
        ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
        local_feat = self.feat_prop_module(local_feat, pred_flows[0],
                                           pred_flows[1])
        enc_feat = torch.cat((local_feat, ref_feat), dim=1)

        # content hallucination through stacking multiple temporal focal transformer blocks
        trans_feat = self.ss(enc_feat.view(-1, c, h, w), b)
        trans_feat = self.transformer(trans_feat)
        trans_feat = self.sc(trans_feat, t)
        trans_feat = trans_feat.view(b, t, -1, h, w)
        enc_feat = enc_feat + trans_feat

        # decode frames from features
        output = self.decoder(enc_feat.view(b * t, c, h, w))
        output = torch.tanh(output)
        return output, pred_flows


# ######################################################################
#  Discriminator for Temporal Patch GAN
# ######################################################################


class Discriminator(BaseNetwork):
    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32

        self.conv = nn.Sequential(
            spectral_norm(
                nn.Conv3d(in_channels=in_channels,
                          out_channels=nf * 1,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=1,
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(64, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 1,
                          nf * 2,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(128, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 2,
                          nf * 4,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 4,
                          nf * 4,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 4,
                          nf * 4,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(nf * 4,
                      nf * 4,
                      kernel_size=(3, 5, 5),
                      stride=(1, 2, 2),
                      padding=(1, 2, 2)))

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        # T, C, H, W = xs.shape (old)
        # B, T, C, H, W (new)
        xs_t = torch.transpose(xs, 1, 2)
        feat = self.conv(xs_t)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        out = torch.transpose(feat, 1, 2)  # B, T, C, H, W
        return out


def spectral_norm(module, mode=True):
    if mode:
        return _spectral_norm(module)
    return module
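For orientation, a minimal smoke test of the generator interface (not part of the commit; it assumes inpainter/ is on PYTHONPATH so the model.* imports resolve, that mmcv's deformable-conv op is available, and that the pretrained SPyNet weights can be downloaded). The hard-coded output_size of (60, 108) ties this variant to 240x432 inputs, since the encoder downsamples by 4:

import torch
from model.e2fgvi import InpaintGenerator

model = InpaintGenerator().eval()
# 3 local frames + 2 reference frames, values in [-1, 1], fixed 240x432 resolution
masked_frames = torch.randn(1, 5, 3, 240, 432)
with torch.no_grad():
    output, pred_flows = model(masked_frames, num_local_frames=3)
print(output.shape)         # (5, 3, 240, 432): b*t inpainted frames in [-1, 1]
print(pred_flows[0].shape)  # (1, 2, 2, 60, 108): forward flows between local frames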
350  inpainter/model/e2fgvi_hq.py  Normal file
@@ -0,0 +1,350 @@
''' Towards An End-to-End Framework for Video Inpainting
'''

import torch
import torch.nn as nn
import torch.nn.functional as F

from model.modules.flow_comp import SPyNet
from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment
from model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp
from model.modules.spectral_norm import spectral_norm as _spectral_norm


class BaseNetwork(nn.Module):
    def __init__(self):
        super(BaseNetwork, self).__init__()

    def print_network(self):
        if isinstance(self, list):
            self = self[0]
        num_params = 0
        for param in self.parameters():
            num_params += param.numel()
        print(
            'Network [%s] was created. Total number of parameters: %.1f million. '
            'To see the architecture, do print(network).' %
            (type(self).__name__, num_params / 1000000))

    def init_weights(self, init_type='normal', gain=0.02):
        '''
        initialize network's weights
        init_type: normal | xavier | kaiming | orthogonal
        https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
        '''
        def init_func(m):
            classname = m.__class__.__name__
            if classname.find('InstanceNorm2d') != -1:
                if hasattr(m, 'weight') and m.weight is not None:
                    nn.init.constant_(m.weight.data, 1.0)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)
            elif hasattr(m, 'weight') and (classname.find('Conv') != -1
                                           or classname.find('Linear') != -1):
                if init_type == 'normal':
                    nn.init.normal_(m.weight.data, 0.0, gain)
                elif init_type == 'xavier':
                    nn.init.xavier_normal_(m.weight.data, gain=gain)
                elif init_type == 'xavier_uniform':
                    nn.init.xavier_uniform_(m.weight.data, gain=1.0)
                elif init_type == 'kaiming':
                    nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    nn.init.orthogonal_(m.weight.data, gain=gain)
                elif init_type == 'none':  # uses pytorch's default init method
                    m.reset_parameters()
                else:
                    raise NotImplementedError(
                        'initialization method [%s] is not implemented' %
                        init_type)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias.data, 0.0)

        self.apply(init_func)

        # propagate to children
        for m in self.children():
            if hasattr(m, 'init_weights'):
                m.init_weights(init_type, gain)


class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        self.group = [1, 2, 4, 8, 1]
        self.layers = nn.ModuleList([
            nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
            nn.LeakyReLU(0.2, inplace=True)
        ])

    def forward(self, x):
        bt, c, _, _ = x.size()
        # h, w = h//4, w//4
        out = x
        for i, layer in enumerate(self.layers):
            if i == 8:
                x0 = out
                _, _, h, w = x0.size()
            if i > 8 and i % 2 == 0:
                g = self.group[(i - 8) // 2]
                x = x0.view(bt, g, -1, h, w)
                o = out.view(bt, g, -1, h, w)
                out = torch.cat([x, o], 2).view(bt, -1, h, w)
            out = layer(out)
        return out


class deconv(nn.Module):
    def __init__(self,
                 input_channel,
                 output_channel,
                 kernel_size=3,
                 padding=0):
        super().__init__()
        self.conv = nn.Conv2d(input_channel,
                              output_channel,
                              kernel_size=kernel_size,
                              stride=1,
                              padding=padding)

    def forward(self, x):
        x = F.interpolate(x,
                          scale_factor=2,
                          mode='bilinear',
                          align_corners=True)
        return self.conv(x)


class InpaintGenerator(BaseNetwork):
    def __init__(self, init_weights=True):
        super(InpaintGenerator, self).__init__()
        channel = 256
        hidden = 512

        # encoder
        self.encoder = Encoder()

        # decoder
        self.decoder = nn.Sequential(
            deconv(channel // 2, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            deconv(64, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))

        # feature propagation module
        self.feat_prop_module = BidirectionalPropagation(channel // 2)

        # soft split and soft composition
        kernel_size = (7, 7)
        padding = (3, 3)
        stride = (3, 3)
        output_size = (60, 108)
        t2t_params = {
            'kernel_size': kernel_size,
            'stride': stride,
            'padding': padding
        }
        self.ss = SoftSplit(channel // 2,
                            hidden,
                            kernel_size,
                            stride,
                            padding,
                            t2t_param=t2t_params)
        self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding)

        n_vecs = 1
        for i, d in enumerate(kernel_size):
            n_vecs *= int((output_size[i] + 2 * padding[i] -
                           (d - 1) - 1) / stride[i] + 1)

        blocks = []
        depths = 8
        num_heads = [4] * depths
        window_size = [(5, 9)] * depths
        focal_windows = [(5, 9)] * depths
        focal_levels = [2] * depths
        pool_method = "fc"

        for i in range(depths):
            blocks.append(
                TemporalFocalTransformerBlock(dim=hidden,
                                              num_heads=num_heads[i],
                                              window_size=window_size[i],
                                              focal_level=focal_levels[i],
                                              focal_window=focal_windows[i],
                                              n_vecs=n_vecs,
                                              t2t_params=t2t_params,
                                              pool_method=pool_method))
        self.transformer = nn.Sequential(*blocks)

        if init_weights:
            self.init_weights()
            # Need to initialize the weights of MSDeformAttn specifically
            for m in self.modules():
                if isinstance(m, SecondOrderDeformableAlignment):
                    m.init_offset()

        # flow completion network
        self.update_spynet = SPyNet()

    def forward_bidirect_flow(self, masked_local_frames):
        b, l_t, c, h, w = masked_local_frames.size()

        # compute forward and backward flows of masked frames
        masked_local_frames = F.interpolate(masked_local_frames.view(
            -1, c, h, w),
                                            scale_factor=1 / 4,
                                            mode='bilinear',
                                            align_corners=True,
                                            recompute_scale_factor=True)
        masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4,
                                                       w // 4)
        mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape(
            -1, c, h // 4, w // 4)
        pred_flows_forward = self.update_spynet(mlf_1, mlf_2)
        pred_flows_backward = self.update_spynet(mlf_2, mlf_1)

        pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4,
                                                     w // 4)
        pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4,
                                                       w // 4)

        return pred_flows_forward, pred_flows_backward

    def forward(self, masked_frames, num_local_frames):
        l_t = num_local_frames
        b, t, ori_c, ori_h, ori_w = masked_frames.size()

        # normalization before feeding into the flow completion module
        masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2
        pred_flows = self.forward_bidirect_flow(masked_local_frames)

        # extracting features and performing the feature propagation on local features
        enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w))
        _, c, h, w = enc_feat.size()
        fold_output_size = (h, w)
        local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
        ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
        local_feat = self.feat_prop_module(local_feat, pred_flows[0],
                                           pred_flows[1])
        enc_feat = torch.cat((local_feat, ref_feat), dim=1)

        # content hallucination through stacking multiple temporal focal transformer blocks
        trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size)
        trans_feat = self.transformer([trans_feat, fold_output_size])
        trans_feat = self.sc(trans_feat[0], t, fold_output_size)
        trans_feat = trans_feat.view(b, t, -1, h, w)
        enc_feat = enc_feat + trans_feat

        # decode frames from features
        output = self.decoder(enc_feat.view(b * t, c, h, w))
        output = torch.tanh(output)
        return output, pred_flows


# ######################################################################
#  Discriminator for Temporal Patch GAN
# ######################################################################


class Discriminator(BaseNetwork):
    def __init__(self,
                 in_channels=3,
                 use_sigmoid=False,
                 use_spectral_norm=True,
                 init_weights=True):
        super(Discriminator, self).__init__()
        self.use_sigmoid = use_sigmoid
        nf = 32

        self.conv = nn.Sequential(
            spectral_norm(
                nn.Conv3d(in_channels=in_channels,
                          out_channels=nf * 1,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=1,
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(64, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 1,
                          nf * 2,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(128, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 2,
                          nf * 4,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 4,
                          nf * 4,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(
                nn.Conv3d(nf * 4,
                          nf * 4,
                          kernel_size=(3, 5, 5),
                          stride=(1, 2, 2),
                          padding=(1, 2, 2),
                          bias=not use_spectral_norm), use_spectral_norm),
            # nn.InstanceNorm2d(256, track_running_stats=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv3d(nf * 4,
                      nf * 4,
                      kernel_size=(3, 5, 5),
                      stride=(1, 2, 2),
                      padding=(1, 2, 2)))

        if init_weights:
            self.init_weights()

    def forward(self, xs):
        # T, C, H, W = xs.shape (old)
        # B, T, C, H, W (new)
        xs_t = torch.transpose(xs, 1, 2)
        feat = self.conv(xs_t)
        if self.use_sigmoid:
            feat = torch.sigmoid(feat)
        out = torch.transpose(feat, 1, 2)  # B, T, C, H, W
        return out


def spectral_norm(module, mode=True):
    if mode:
        return _spectral_norm(module)
    return module
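The HQ variant differs from model/e2fgvi.py mainly in that the soft split/composition size is read from the actual feature map at runtime (fold_output_size) instead of being hard-coded to (60, 108), which is what lets base_inpainter.py feed it other resolutions, provided they are multiples of (60, 108) after padding. The same smoke test, under the same assumptions as above, at a non-default size:

import torch
from model.e2fgvi_hq import InpaintGenerator

model = InpaintGenerator().eval()
masked_frames = torch.randn(1, 5, 3, 300, 540)   # dummy clip at a non-default size
with torch.no_grad():
    output, pred_flows = model(masked_frames, num_local_frames=3)
print(output.shape)  # (5, 3, 300, 540)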
149  inpainter/model/modules/feat_prop.py  Normal file
@@ -0,0 +1,149 @@
"""
BasicVSR++: Improving Video Super-Resolution with Enhanced Propagation and Alignment, CVPR 2022
"""
import torch
import torch.nn as nn

from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d
from mmengine.model import constant_init

from model.modules.flow_comp import flow_warp


class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
    """Second-order deformable alignment module."""
    def __init__(self, *args, **kwargs):
        self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)

        super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)

        self.conv_offset = nn.Sequential(
            nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
            nn.LeakyReLU(negative_slope=0.1, inplace=True),
            nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
        )

        self.init_offset()

    def init_offset(self):
        constant_init(self.conv_offset[-1], val=0, bias=0)

    def forward(self, x, extra_feat, flow_1, flow_2):
        extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1)
        out = self.conv_offset(extra_feat)
        o1, o2, mask = torch.chunk(out, 3, dim=1)

        # offset
        offset = self.max_residue_magnitude * torch.tanh(
            torch.cat((o1, o2), dim=1))
        offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
        offset_1 = offset_1 + flow_1.flip(1).repeat(1,
                                                    offset_1.size(1) // 2, 1,
                                                    1)
        offset_2 = offset_2 + flow_2.flip(1).repeat(1,
                                                    offset_2.size(1) // 2, 1,
                                                    1)
        offset = torch.cat([offset_1, offset_2], dim=1)

        # mask
        mask = torch.sigmoid(mask)

        return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias,
                                       self.stride, self.padding,
                                       self.dilation, self.groups,
                                       self.deform_groups)


class BidirectionalPropagation(nn.Module):
    def __init__(self, channel):
        super(BidirectionalPropagation, self).__init__()
        modules = ['backward_', 'forward_']
        self.deform_align = nn.ModuleDict()
        self.backbone = nn.ModuleDict()
        self.channel = channel

        for i, module in enumerate(modules):
            self.deform_align[module] = SecondOrderDeformableAlignment(
                2 * channel, channel, 3, padding=1, deform_groups=16)

            self.backbone[module] = nn.Sequential(
                nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
                nn.LeakyReLU(negative_slope=0.1, inplace=True),
                nn.Conv2d(channel, channel, 3, 1, 1),
            )

        self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)

    def forward(self, x, flows_backward, flows_forward):
        """
        x shape : [b, t, c, h, w]
        return [b, t, c, h, w]
        """
        b, t, c, h, w = x.shape
        feats = {}
        feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]

        for module_name in ['backward_', 'forward_']:

            feats[module_name] = []

            frame_idx = range(0, t)
            flow_idx = range(-1, t - 1)
            mapping_idx = list(range(0, len(feats['spatial'])))
            mapping_idx += mapping_idx[::-1]

            if 'backward' in module_name:
                frame_idx = frame_idx[::-1]
                flows = flows_backward
            else:
                flows = flows_forward

            feat_prop = x.new_zeros(b, self.channel, h, w)
            for i, idx in enumerate(frame_idx):
                feat_current = feats['spatial'][mapping_idx[idx]]

                if i > 0:
                    flow_n1 = flows[:, flow_idx[i], :, :, :]
                    cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1))

                    # initialize second-order features
                    feat_n2 = torch.zeros_like(feat_prop)
                    flow_n2 = torch.zeros_like(flow_n1)
                    cond_n2 = torch.zeros_like(cond_n1)
                    if i > 1:
                        feat_n2 = feats[module_name][-2]
                        flow_n2 = flows[:, flow_idx[i - 1], :, :, :]
                        flow_n2 = flow_n1 + flow_warp(
                            flow_n2, flow_n1.permute(0, 2, 3, 1))
                        cond_n2 = flow_warp(feat_n2,
                                            flow_n2.permute(0, 2, 3, 1))

                    cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1)
                    feat_prop = torch.cat([feat_prop, feat_n2], dim=1)
                    feat_prop = self.deform_align[module_name](feat_prop, cond,
                                                               flow_n1,
                                                               flow_n2)

                feat = [feat_current] + [
                    feats[k][idx]
                    for k in feats if k not in ['spatial', module_name]
                ] + [feat_prop]

                feat = torch.cat(feat, dim=1)
                feat_prop = feat_prop + self.backbone[module_name](feat)
                feats[module_name].append(feat_prop)

            if 'backward' in module_name:
                feats[module_name] = feats[module_name][::-1]

        outputs = []
        for i in range(0, t):
            align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
            align_feats = torch.cat(align_feats, dim=1)
            outputs.append(self.fusion(align_feats))

        return torch.stack(outputs, dim=1) + x
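The propagation loop above sweeps the clip twice, backward then forward, and the only subtle part is the index bookkeeping. A pure-Python trace of those definitions for a 5-frame clip (illustration only, not part of the commit):

t = 5
frame_idx = list(range(0, t))        # [0, 1, 2, 3, 4]
flow_idx = list(range(-1, t - 1))    # [-1, 0, 1, 2, 3]; the i == 0 entry is never read
for module_name in ['backward_', 'forward_']:
    order = frame_idx[::-1] if 'backward' in module_name else frame_idx
    print(module_name, order)
# backward_ [4, 3, 2, 1, 0]
# forward_  [0, 1, 2, 3, 4]
# At step i > 0 the running feature is warped with flows[:, flow_idx[i]];
# for i > 1 a second-order term from two steps back is added via flow composition.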
450  inpainter/model/modules/flow_comp.py  Normal file
@@ -0,0 +1,450 @@
import numpy as np

import torch.nn as nn
import torch.nn.functional as F
import torch

from mmcv.cnn import ConvModule
from mmengine.runner import load_checkpoint


class FlowCompletionLoss(nn.Module):
    """Flow completion loss"""
    def __init__(self):
        super().__init__()
        self.fix_spynet = SPyNet()
        for p in self.fix_spynet.parameters():
            p.requires_grad = False

        self.l1_criterion = nn.L1Loss()

    def forward(self, pred_flows, gt_local_frames):
        b, l_t, c, h, w = gt_local_frames.size()

        with torch.no_grad():
            # compute gt forward and backward flows
            gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w),
                                            scale_factor=1 / 4,
                                            mode='bilinear',
                                            align_corners=True,
                                            recompute_scale_factor=True)
            gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4)
            gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(
                -1, c, h // 4, w // 4)
            gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(
                -1, c, h // 4, w // 4)
            gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2)
            gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1)

        # calculate loss for flow completion
        forward_flow_loss = self.l1_criterion(
            pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward)
        backward_flow_loss = self.l1_criterion(
            pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward)
        flow_loss = forward_flow_loss + backward_flow_loss

        return flow_loss


class SPyNet(nn.Module):
    """SPyNet network structure.
    The difference to the SPyNet in [tof.py] is that
        1. more SPyNetBasicModules are used in this version, and
        2. no batch normalization is used in this version.
    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    Args:
        pretrained (str): path for pre-trained SPyNet. Default: None.
    """
    def __init__(
        self,
        use_pretrain=True,
        pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth'
    ):
        super().__init__()

        self.basic_module = nn.ModuleList(
            [SPyNetBasicModule() for _ in range(6)])

        if use_pretrain:
            if isinstance(pretrained, str):
                print("load pretrained SPyNet...")
                load_checkpoint(self, pretrained, strict=True)
            elif pretrained is not None:
                raise TypeError('[pretrained] should be str or None, '
                                f'but got {type(pretrained)}.')

        self.register_buffer(
            'mean',
            torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer(
            'std',
            torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def compute_flow(self, ref, supp):
        """Compute flow from ref to supp.
        Note that in this function, the images are already resized to a
        multiple of 32.
        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).
        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """
        n, _, h, w = ref.size()

        # normalize the input images
        ref = [(ref - self.mean) / self.std]
        supp = [(supp - self.mean) / self.std]

        # generate downsampled frames
        for level in range(5):
            ref.append(
                F.avg_pool2d(input=ref[-1],
                             kernel_size=2,
                             stride=2,
                             count_include_pad=False))
            supp.append(
                F.avg_pool2d(input=supp[-1],
                             kernel_size=2,
                             stride=2,
                             count_include_pad=False))
        ref = ref[::-1]
        supp = supp[::-1]

        # flow computation
        flow = ref[0].new_zeros(n, 2, h // 32, w // 32)
        for level in range(len(ref)):
            if level == 0:
                flow_up = flow
            else:
                flow_up = F.interpolate(input=flow,
                                        scale_factor=2,
                                        mode='bilinear',
                                        align_corners=True) * 2.0

            # add the residue to the upsampled flow
            flow = flow_up + self.basic_module[level](torch.cat([
                ref[level],
                flow_warp(supp[level],
                          flow_up.permute(0, 2, 3, 1).contiguous(),
                          padding_mode='border'), flow_up
            ], 1))

        return flow

    def forward(self, ref, supp):
        """Forward function of SPyNet.
        This function computes the optical flow from ref to supp.
        Args:
            ref (Tensor): Reference image with shape of (n, 3, h, w).
            supp (Tensor): Supporting image with shape of (n, 3, h, w).
        Returns:
            Tensor: Estimated optical flow: (n, 2, h, w).
        """

        # upsize to a multiple of 32
        h, w = ref.shape[2:4]
        w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1)
        h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1)
        ref = F.interpolate(input=ref,
                            size=(h_up, w_up),
                            mode='bilinear',
                            align_corners=False)
        supp = F.interpolate(input=supp,
                             size=(h_up, w_up),
                             mode='bilinear',
                             align_corners=False)

        # compute flow, and resize back to the original resolution
        flow = F.interpolate(input=self.compute_flow(ref, supp),
                             size=(h, w),
                             mode='bilinear',
                             align_corners=False)

        # adjust the flow values
        flow[:, 0, :, :] *= float(w) / float(w_up)
        flow[:, 1, :, :] *= float(h) / float(h_up)

        return flow


class SPyNetBasicModule(nn.Module):
    """Basic Module for SPyNet.
    Paper:
        Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017
    """
    def __init__(self):
        super().__init__()

        self.basic_module = nn.Sequential(
            ConvModule(in_channels=8,
                       out_channels=32,
                       kernel_size=7,
                       stride=1,
                       padding=3,
                       norm_cfg=None,
                       act_cfg=dict(type='ReLU')),
            ConvModule(in_channels=32,
                       out_channels=64,
                       kernel_size=7,
                       stride=1,
                       padding=3,
                       norm_cfg=None,
                       act_cfg=dict(type='ReLU')),
            ConvModule(in_channels=64,
                       out_channels=32,
                       kernel_size=7,
                       stride=1,
                       padding=3,
                       norm_cfg=None,
                       act_cfg=dict(type='ReLU')),
            ConvModule(in_channels=32,
                       out_channels=16,
                       kernel_size=7,
                       stride=1,
                       padding=3,
                       norm_cfg=None,
                       act_cfg=dict(type='ReLU')),
            ConvModule(in_channels=16,
                       out_channels=2,
                       kernel_size=7,
                       stride=1,
                       padding=3,
                       norm_cfg=None,
                       act_cfg=None))

    def forward(self, tensor_input):
        """
        Args:
            tensor_input (Tensor): Input tensor with shape (b, 8, h, w).
                8 channels contain:
                [reference image (3), neighbor image (3), initial flow (2)].
        Returns:
            Tensor: Refined flow with shape (b, 2, h, w)
        """
        return self.basic_module(tensor_input)


# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf

    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel
    """

    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
    colorwheel[col:col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col:col + GC, 1] = 255
    colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
    colorwheel[col:col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col:col + BM, 2] = 255
    colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
    colorwheel[col:col + MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.

    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u) / np.pi
    fk = (a + 1) / 2 * (ncols - 1)
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        idx = (rad <= 1)
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        col[~idx] = col[~idx] * 0.75  # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2 - i if convert_to_bgr else i
        flow_image[:, :, ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Expects a two dimensional flow image of shape [H,W,2].

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        flow_uv = np.clip(flow_uv, 0, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)


def flow_warp(x,
              flow,
              interpolation='bilinear',
              padding_mode='zeros',
              align_corners=True):
    """Warp an image or a feature map with optical flow.
    Args:
        x (Tensor): Tensor with size (n, c, h, w).
        flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is
            a two-channel map denoting the width and height relative offsets.
            Note that the values are not normalized to [-1, 1].
        interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
            Default: 'bilinear'.
        padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
            Default: 'zeros'.
        align_corners (bool): Whether align corners. Default: True.
    Returns:
        Tensor: Warped image or feature map.
    """
    if x.size()[-2:] != flow.size()[1:3]:
        raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
                         f'flow ({flow.size()[1:3]}) are not the same.')
    _, _, h, w = x.size()
    # create mesh grid
    grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (h, w, 2)
    grid.requires_grad = False

    grid_flow = grid + flow
    # scale grid_flow to [-1,1]
    grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
    grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
    grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
    output = F.grid_sample(x,
                           grid_flow,
                           mode=interpolation,
                           padding_mode=padding_mode,
                           align_corners=align_corners)
    return output


def initial_mask_flow(mask):
    """
    mask: 1 indicates a valid pixel, 0 indicates an unknown pixel
    """
    B, T, C, H, W = mask.shape

    # calculate relative position
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W))

    grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask)
    abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :])
    relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :])

    abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None])
    relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None])

    # calculate the nearest indices
    pos_up = mask.unsqueeze(3).repeat(
        1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * (
            relative_pos_y <= H)[None, None, None]
    nearest_indice_up = pos_up.max(dim=4)[1]

    pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[
        None, None, None] * (relative_pos_y <= H)[None, None, None]
    nearest_indice_down = (pos_down).max(dim=4)[1]

    pos_left = mask.unsqueeze(4).repeat(
        1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * (
            relative_pos_x <= W)[None, None, None]
    nearest_indice_left = (pos_left).max(dim=5)[1]

    pos_right = mask.unsqueeze(4).repeat(
        1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * (
            relative_pos_x <= W)[None, None, None]
    nearest_indice_right = (pos_right).max(dim=5)[1]

    # NOTE: IMPORTANT !!! depending on how to use this offset
    initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3)
    initial_offset_down = nearest_indice_down - grid_y[None, None, None]

    initial_offset_left = -(nearest_indice_left -
                            grid_x[None, None, None]).flip(4)
    initial_offset_right = nearest_indice_right - grid_x[None, None, None]

    # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1]
    # initial_offset_x = nearest_indice_x - grid_x

    # handle the boundary cases
    final_offset_down = (initial_offset_down < 0) * initial_offset_up + (
        initial_offset_down > 0) * initial_offset_down
    final_offset_up = (initial_offset_up > 0) * initial_offset_down + (
        initial_offset_up < 0) * initial_offset_up
    final_offset_right = (initial_offset_right < 0) * initial_offset_left + (
        initial_offset_right > 0) * initial_offset_right
    final_offset_left = (initial_offset_left > 0) * initial_offset_right + (
        initial_offset_left < 0) * initial_offset_left
    zero_offset = torch.zeros_like(final_offset_down)
    # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2)
    out = torch.cat([
        zero_offset, final_offset_left, zero_offset, final_offset_right,
        final_offset_up, zero_offset, final_offset_down, zero_offset
    ],
                    dim=2)

    return out
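flow_warp is the workhorse shared by SPyNet and feat_prop.py; a zero flow should reproduce its input (bilinear sampling at the exact grid points), which makes for a quick self-contained sanity check (dummy tensors, not part of the commit):

import torch
from model.modules.flow_comp import flow_warp

x = torch.randn(2, 16, 64, 64)       # n, c, h, w feature map
flow = torch.zeros(2, 64, 64, 2)     # n, h, w, 2 pixel offsets (not normalized)
out = flow_warp(x, flow)             # identity warp
print(torch.allclose(out, x, atol=1e-5))   # True

flow[..., 0] = 1.0                   # sample one pixel to the right
shifted = flow_warp(x, flow)         # shifted[..., :, j] ~ x[..., :, j + 1]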
288  inpainter/model/modules/spectral_norm.py  Normal file
@@ -0,0 +1,288 @@
"""
Spectral Normalization from https://arxiv.org/abs/1802.05957
"""
import torch
from torch.nn.functional import normalize


class SpectralNorm(object):
    # Invariant before and after each forward call:
    #   u = normalize(W @ v)
    # NB: At initialization, this invariant is not enforced

    _version = 1

    # At version 1:
    #   made `W` not a buffer,
    #   added `v` as a buffer, and
    #   made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.

    def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
        self.name = name
        self.dim = dim
        if n_power_iterations <= 0:
            raise ValueError(
                'Expected n_power_iterations to be positive, but '
                'got n_power_iterations={}'.format(n_power_iterations))
        self.n_power_iterations = n_power_iterations
        self.eps = eps

    def reshape_weight_to_matrix(self, weight):
        weight_mat = weight
        if self.dim != 0:
            # permute dim to front
            weight_mat = weight_mat.permute(
                self.dim,
                *[d for d in range(weight_mat.dim()) if d != self.dim])
        height = weight_mat.size(0)
        return weight_mat.reshape(height, -1)

    def compute_weight(self, module, do_power_iteration):
        # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
        # updated in power iteration **in-place**. This is very important
        # because in `DataParallel` forward, the vectors (being buffers) are
        # broadcast from the parallelized module to each module replica,
        # which is a new module object created on the fly. And each replica
        # runs its own spectral norm power iteration. So simply assigning
        # the updated vectors to the module this function runs on will cause
        # the update to be lost forever. And the next time the parallelized
        # module is replicated, the same randomly initialized vectors are
        # broadcast and used!
        #
        # Therefore, to make the change propagate back, we rely on two
        # important behaviors (also enforced via tests):
        # 1. `DataParallel` doesn't clone storage if the broadcast tensor
        #    is already on the correct device; and it makes sure that the
        #    parallelized module is already on `device[0]`.
        # 2. If the out tensor in the `out=` kwarg has the correct shape, it
        #    will just fill in the values.
        # Therefore, since the same power iteration is performed on all
        # devices, simply updating the tensors in-place will make sure that
        # the module replica on `device[0]` will update the _u vector on the
        # parallelized module (by shared storage).
        #
        # However, after we update `u` and `v` in-place, we need to **clone**
        # them before using them to normalize the weight. This is to support
        # backpropagating through two forward passes, e.g., the common pattern
        # in GAN training: loss = D(real) - D(fake). Otherwise, the engine
        # will complain that variables needed for backward of the first
        # forward (i.e., the `u` and `v` vectors) are changed in the second
        # forward.
        weight = getattr(module, self.name + '_orig')
        u = getattr(module, self.name + '_u')
        v = getattr(module, self.name + '_v')
        weight_mat = self.reshape_weight_to_matrix(weight)

        if do_power_iteration:
            with torch.no_grad():
                for _ in range(self.n_power_iterations):
                    # The spectral norm of the weight equals `u^T W v`, where
                    # `u` and `v` are the first left and right singular
                    # vectors. This power iteration produces approximations
                    # of `u` and `v`.
                    v = normalize(torch.mv(weight_mat.t(), u),
                                  dim=0,
                                  eps=self.eps,
                                  out=v)
                    u = normalize(torch.mv(weight_mat, v),
                                  dim=0,
                                  eps=self.eps,
                                  out=u)
                if self.n_power_iterations > 0:
                    # See above on why we need to clone
                    u = u.clone()
                    v = v.clone()

        sigma = torch.dot(u, torch.mv(weight_mat, v))
        weight = weight / sigma
        return weight

    def remove(self, module):
        with torch.no_grad():
            weight = self.compute_weight(module, do_power_iteration=False)
        delattr(module, self.name)
        delattr(module, self.name + '_u')
        delattr(module, self.name + '_v')
        delattr(module, self.name + '_orig')
        module.register_parameter(self.name,
                                  torch.nn.Parameter(weight.detach()))

    def __call__(self, module, inputs):
        setattr(
            module, self.name,
            self.compute_weight(module, do_power_iteration=module.training))

    def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
        # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
        # (the invariant at the top of this class) and `u @ W @ v = sigma`.
        # This uses pinverse in case W^T W is not invertible.
        v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
                               weight_mat.t(), u.unsqueeze(1)).squeeze(1)
        return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))

    @staticmethod
    def apply(module, name, n_power_iterations, dim, eps):
        for k, hook in module._forward_pre_hooks.items():
            if isinstance(hook, SpectralNorm) and hook.name == name:
                raise RuntimeError(
                    "Cannot register two spectral_norm hooks on "
                    "the same parameter {}".format(name))

        fn = SpectralNorm(name, n_power_iterations, dim, eps)
        weight = module._parameters[name]

        with torch.no_grad():
            weight_mat = fn.reshape_weight_to_matrix(weight)

            h, w = weight_mat.size()
            # randomly initialize `u` and `v`
            u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
            v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)

        delattr(module, fn.name)
        module.register_parameter(fn.name + "_orig", weight)
        # We still need to assign weight back as fn.name because all sorts of
        # things may assume that it exists, e.g., when initializing weights.
        # However, we can't directly assign as it could be an nn.Parameter and
        # gets added as a parameter. Instead, we register weight.data as a
        # plain attribute.
        setattr(module, fn.name, weight.data)
        module.register_buffer(fn.name + "_u", u)
        module.register_buffer(fn.name + "_v", v)

        module.register_forward_pre_hook(fn)

        module._register_state_dict_hook(SpectralNormStateDictHook(fn))
        module._register_load_state_dict_pre_hook(
            SpectralNormLoadStateDictPreHook(fn))
        return fn


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormLoadStateDictPreHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    # For state_dict with version None, (assuming that it has gone through at
    # least one training forward), we have
    #
    #    u = normalize(W_orig @ v)
    #    W = W_orig / sigma, where sigma = u @ W_orig @ v
    #
    # To compute `v`, we solve `W_orig @ x = u`, and let
    #    v = x / (u @ W_orig @ x) * (W / W_orig).
    def __call__(self, state_dict, prefix, local_metadata, strict,
                 missing_keys, unexpected_keys, error_msgs):
        fn = self.fn
        version = local_metadata.get('spectral_norm',
                                     {}).get(fn.name + '.version', None)
        if version is None or version < 1:
            with torch.no_grad():
                weight_orig = state_dict[prefix + fn.name + '_orig']
                # weight = state_dict.pop(prefix + fn.name)
                # sigma = (weight_orig / weight).mean()
                weight_mat = fn.reshape_weight_to_matrix(weight_orig)
                u = state_dict[prefix + fn.name + '_u']
                # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
                # state_dict[prefix + fn.name + '_v'] = v


# This is a top level class because Py2 pickle doesn't like inner class nor an
# instancemethod.
class SpectralNormStateDictHook(object):
    # See docstring of SpectralNorm._version on the changes to spectral_norm.
    def __init__(self, fn):
        self.fn = fn

    def __call__(self, module, state_dict, prefix, local_metadata):
        if 'spectral_norm' not in local_metadata:
            local_metadata['spectral_norm'] = {}
        key = self.fn.name + '.version'
        if key in local_metadata['spectral_norm']:
            raise RuntimeError(
                "Unexpected key in metadata['spectral_norm']: {}".format(key))
        local_metadata['spectral_norm'][key] = self.fn._version


def spectral_norm(module,
                  name='weight',
                  n_power_iterations=1,
                  eps=1e-12,
                  dim=None):
    r"""Applies spectral normalization to a parameter in the given module.

    .. math::
        \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
        \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}

    Spectral normalization stabilizes the training of discriminators (critics)
    in Generative Adversarial Networks (GANs) by rescaling the weight tensor
    with the spectral norm :math:`\sigma` of the weight matrix, calculated
    using the power iteration method. If the dimension of the weight tensor is
    greater than 2, it is reshaped to 2D in the power iteration method to get
    the spectral norm. This is implemented via a hook that calculates the
    spectral norm and rescales the weight before every
    :meth:`~Module.forward` call.

    See `Spectral Normalization for Generative Adversarial Networks`_ .

    .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957

    Args:
        module (nn.Module): containing module
        name (str, optional): name of weight parameter
        n_power_iterations (int, optional): number of power iterations to
            calculate spectral norm
        eps (float, optional): epsilon for numerical stability in
            calculating norms
        dim (int, optional): dimension corresponding to number of outputs,
            the default is ``0``, except for modules that are instances of
            ConvTranspose{1,2,3}d, when it is ``1``

    Returns:
        The original module with the spectral norm hook

    Example::

        >>> m = spectral_norm(nn.Linear(20, 40))
        >>> m
        Linear(in_features=20, out_features=40, bias=True)
        >>> m.weight_u.size()
        torch.Size([40])

    """
    if dim is None:
        if isinstance(module,
                      (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
                       torch.nn.ConvTranspose3d)):
            dim = 1
        else:
            dim = 0
    SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
    return module


def remove_spectral_norm(module, name='weight'):
    r"""Removes the spectral normalization reparameterization from a module.

    Args:
        module (Module): containing module
        name (str, optional): name of weight parameter

    Example:
        >>> m = spectral_norm(nn.Linear(40, 10))
        >>> remove_spectral_norm(m)
    """
    for k, hook in module._forward_pre_hooks.items():
        if isinstance(hook, SpectralNorm) and hook.name == name:
            hook.remove(module)
            del module._forward_pre_hooks[k]
            return module

    raise ValueError("spectral_norm of '{}' not found in {}".format(
        name, module))


def use_spectral_norm(module, use_sn=False):
    if use_sn:
        return spectral_norm(module)
    return module
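Note: `use_spectral_norm` above is the thin gate the rest of the inpainter is expected to call. A minimal sketch of the intended usage, assuming a discriminator-style layer (the `nn.Conv2d` below is illustrative, not taken from this commit):

import torch.nn as nn

# With use_sn=True the layer gains the reparameterization set up by
# SpectralNorm.apply: a `weight_orig` parameter plus `weight_u`/`weight_v`
# power-iteration buffers, refreshed by a forward pre-hook on every call.
conv = use_spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2),
                         use_sn=True)
assert hasattr(conv, 'weight_orig') and hasattr(conv, 'weight_u')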
536
inpainter/model/modules/tfocal_transformer.py
Normal file
@@ -0,0 +1,536 @@
"""
This code is based on:
[1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
    https://github.com/ruiliu-ai/FuseFormer
[2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
    https://github.com/yitu-opensource/T2T-ViT
[3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
    https://github.com/microsoft/Focal-Transformer
"""

import math
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftSplit(nn.Module):
    def __init__(self, channel, hidden, kernel_size, stride, padding,
                 t2t_param):
        super(SoftSplit, self).__init__()
        self.kernel_size = kernel_size
        self.t2t = nn.Unfold(kernel_size=kernel_size,
                             stride=stride,
                             padding=padding)
        c_in = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(c_in, hidden)

        self.f_h = int(
            (t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] -
             (t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] +
            1)
        self.f_w = int(
            (t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] -
             (t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] +
            1)

    def forward(self, x, b):
        feat = self.t2t(x)
        feat = feat.permute(0, 2, 1)
        # feat shape [b*t, num_vec, ks*ks*c]
        feat = self.embedding(feat)
        # feat shape after embedding [b, t*num_vec, hidden]
        feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2))
        return feat


class SoftComp(nn.Module):
    def __init__(self, channel, hidden, output_size, kernel_size, stride,
                 padding):
        super(SoftComp, self).__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        c_out = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(hidden, c_out)
        self.t2t = torch.nn.Fold(output_size=output_size,
                                 kernel_size=kernel_size,
                                 stride=stride,
                                 padding=padding)
        h, w = output_size
        self.bias = nn.Parameter(torch.zeros((channel, h, w),
                                             dtype=torch.float32),
                                 requires_grad=True)

    def forward(self, x, t):
        b_, _, _, _, c_ = x.shape
        x = x.view(b_, -1, c_)
        feat = self.embedding(x)
        b, _, c = feat.size()
        feat = feat.view(b * t, -1, c).permute(0, 2, 1)
        feat = self.t2t(feat) + self.bias[None]
        return feat


class FusionFeedForward(nn.Module):
    def __init__(self, d_model, n_vecs=None, t2t_params=None):
        super(FusionFeedForward, self).__init__()
        # We set d_ff to a default of 1960
        hd = 1960
        self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
        self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
        assert t2t_params is not None and n_vecs is not None
        tp = t2t_params.copy()
        self.fold = nn.Fold(**tp)
        del tp['output_size']
        self.unfold = nn.Unfold(**tp)
        self.n_vecs = n_vecs

    def forward(self, x):
        x = self.conv1(x)
        b, n, c = x.size()
        normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs,
                                               49).permute(0, 2, 1)
        x = self.unfold(
            self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) /
            self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c)
        x = self.conv2(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: shape is (B, T, H, W, C)
        window_size (tuple[int]): window size
    Returns:
        windows: (B*num_windows, T*window_size*window_size, C)
    """
    B, T, H, W, C = x.shape
    x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
               window_size[1], C)
    windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
        -1, T * window_size[0] * window_size[1], C)
    return windows


def window_partition_noreshape(x, window_size):
    """
    Args:
        x: shape is (B, T, H, W, C)
        window_size (tuple[int]): window size
    Returns:
        windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
    """
    B, T, H, W, C = x.shape
    x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
               window_size[1], C)
    windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
    return windows


def window_reverse(windows, window_size, T, H, W):
    """
    Args:
        windows: shape is (num_windows*B, T, window_size, window_size, C)
        window_size (tuple[int]): Window size
        T (int): Temporal length of video
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, T, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
    x = windows.view(B, H // window_size[0], W // window_size[1], T,
                     window_size[0], window_size[1], -1)
    x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """Temporal focal window attention
    """
    def __init__(self, dim, expand_size, window_size, focal_window,
                 focal_level, num_heads, qkv_bias, pool_method):

        super().__init__()
        self.dim = dim
        self.expand_size = expand_size
        self.window_size = window_size  # Wh, Ww
        self.pool_method = pool_method
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.focal_level = focal_level
        self.focal_window = focal_window

        if any(i > 0 for i in self.expand_size) and focal_level > 0:
            # get mask for rolled k and rolled v
            mask_tl = torch.ones(self.window_size[0], self.window_size[1])
            mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
            mask_tr = torch.ones(self.window_size[0], self.window_size[1])
            mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
            mask_bl = torch.ones(self.window_size[0], self.window_size[1])
            mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
            mask_br = torch.ones(self.window_size[0], self.window_size[1])
            mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
            mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
                                      0).flatten(0)
            self.register_buffer("valid_ind_rolled",
                                 mask_rolled.nonzero(as_tuple=False).view(-1))

        if pool_method != "none" and focal_level > 1:
            self.unfolds = nn.ModuleList()

            # build relative position bias between local patch and pooled windows
            for k in range(focal_level - 1):
                stride = 2**k
                kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
                                    for i in self.focal_window)
                # define unfolding operations
                self.unfolds += [
                    nn.Unfold(kernel_size=kernel_size,
                              stride=stride,
                              padding=tuple(i // 2 for i in kernel_size))
                ]

                # define unfolding index for focal_level > 0
                if k > 0:
                    mask = torch.zeros(kernel_size)
                    mask[(2**k) - 1:, (2**k) - 1:] = 1
                    self.register_buffer(
                        "valid_ind_unfold_{}".format(k),
                        mask.flatten(0).nonzero(as_tuple=False).view(-1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x_all, mask_all=None):
        """
        Args:
            x_all: input features with shape of (B, T, Wh, Ww, C)
            mask_all: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None

        output: (nW*B, Wh*Ww, C)
        """
        x = x_all[0]

        B, T, nH, nW, C = x.shape
        qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
                                  C).permute(4, 0, 1, 2, 3, 5).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, T, nH, nW, C

        # partition q map
        (q_windows, k_windows, v_windows) = map(
            lambda t: window_partition(t, self.window_size).view(
                -1, T, self.window_size[0] * self.window_size[1], self.
                num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
            contiguous().view(-1, self.num_heads, T * self.window_size[
                0] * self.window_size[1], C // self.num_heads), (q, k, v))
        # q(k/v)_windows shape : [16, 4, 225, 128]

        if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
            (k_tl, v_tl) = map(
                lambda t: torch.roll(t,
                                     shifts=(-self.expand_size[0], -self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))
            (k_tr, v_tr) = map(
                lambda t: torch.roll(t,
                                     shifts=(-self.expand_size[0], self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))
            (k_bl, v_bl) = map(
                lambda t: torch.roll(t,
                                     shifts=(self.expand_size[0], -self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))
            (k_br, v_br) = map(
                lambda t: torch.roll(t,
                                     shifts=(self.expand_size[0], self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))

            (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
                lambda t: window_partition(t, self.window_size).view(
                    -1, T, self.window_size[0] * self.window_size[1], self.
                    num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
            (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
                lambda t: window_partition(t, self.window_size).view(
                    -1, T, self.window_size[0] * self.window_size[1], self.
                    num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
            k_rolled = torch.cat(
                (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
                2).permute(0, 3, 1, 2, 4).contiguous()
            v_rolled = torch.cat(
                (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
                2).permute(0, 3, 1, 2, 4).contiguous()

            # mask out tokens in current window
            k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
            v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
            temp_N = k_rolled.shape[3]
            k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
                                     C // self.num_heads)
            v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
                                     C // self.num_heads)
            k_rolled = torch.cat((k_windows, k_rolled), 2)
            v_rolled = torch.cat((v_windows, v_rolled), 2)
        else:
            k_rolled = k_windows
            v_rolled = v_windows

        # q(k/v)_windows shape : [16, 4, 225, 128]
        # k_rolled.shape : [16, 4, 5, 165, 128]
        # ideal expanded window size 153 ((5+2*2)*(9+2*4))
        # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)

        if self.pool_method != "none" and self.focal_level > 1:
            k_pooled = []
            v_pooled = []
            for k in range(self.focal_level - 1):
                stride = 2**k
                x_window_pooled = x_all[k + 1].permute(
                    0, 3, 1, 2, 4).contiguous()  # B, T, nWh, nWw, C

                nWh, nWw = x_window_pooled.shape[2:4]

                # generate mask for pooled windows
                mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
                # unfold mask: [nWh*nWw//s//s, k*k, 1]
                unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
                    1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
                    view(nWh*nWw // stride // stride, -1, 1)

                if k > 0:
                    valid_ind_unfold_k = getattr(
                        self, "valid_ind_unfold_{}".format(k))
                    unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]

                x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
                x_window_masks = x_window_masks.masked_fill(
                    x_window_masks == 0,
                    float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
                mask_all[k + 1] = x_window_masks

                # generate k and v for pooled windows
                qkv_pooled = self.qkv(x_window_pooled).reshape(
                    B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
                                                  3).view(3, -1, C, nWh,
                                                          nWw).contiguous()
                k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[
                    2]  # B*T, C, nWh, nWw
                # k_pooled_k shape: [5, 512, 4, 4]
                # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9), 16]

                (k_pooled_k, v_pooled_k) = map(
                    lambda t: self.unfolds[k](t).view(
                        B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\
                        view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(),
                    (k_pooled_k, v_pooled_k)  # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
                )
                # k_pooled_k shape : [16, 4, 5, 45, 128]

                # select valid unfolding index
                if k > 0:
                    (k_pooled_k, v_pooled_k) = map(
                        lambda t: t[:, :, :, valid_ind_unfold_k],
                        (k_pooled_k, v_pooled_k))

                k_pooled_k = k_pooled_k.view(
                    -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
                    self.unfolds[k].kernel_size[1], C // self.num_heads)
                v_pooled_k = v_pooled_k.view(
                    -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
                    self.unfolds[k].kernel_size[1], C // self.num_heads)

                k_pooled += [k_pooled_k]
                v_pooled += [v_pooled_k]

            # k_all (v_all) shape : [16, 4, 5 * 210, 128]
            k_all = torch.cat([k_rolled] + k_pooled, 2)
            v_all = torch.cat([v_rolled] + v_pooled, 2)
        else:
            k_all = k_rolled
            v_all = v_rolled

        N = k_all.shape[-2]
        q_windows = q_windows * self.scale
        attn = (
            q_windows @ k_all.transpose(-2, -1)
        )  # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
        # T * 45
        window_area = T * self.window_size[0] * self.window_size[1]
        # T * 165
        window_area_rolled = k_rolled.shape[2]

        if self.pool_method != "none" and self.focal_level > 1:
            offset = window_area_rolled
            for k in range(self.focal_level - 1):
                # add attentional mask
                # mask_all[1] shape [1, 16, T * 45]

                bias = tuple((i + 2**k - 1) for i in self.focal_window)

                if mask_all[k + 1] is not None:
                    attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
                        attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
                        mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])

                offset += T * bias[0] * bias[1]

        if mask_all[0] is not None:
            nW = mask_all[0].shape[0]
            attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
                             window_area, N)
            attn[:, :, :, :, :
                 window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
                     None, :, None, :, :]
            attn = attn.view(-1, self.num_heads, window_area, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
                                                   C)
        x = self.proj(x)
        return x


class TemporalFocalTransformerBlock(nn.Module):
    r""" Temporal Focal Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): Number of focal levels.
        focal_window (tuple[int]): Window size of each focal window.
        n_vecs (int): Required for F3N.
        t2t_params (dict): T2T parameters for F3N.
    """
    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(5, 9),
                 mlp_ratio=4.,
                 qkv_bias=True,
                 pool_method="fc",
                 focal_level=2,
                 focal_window=(5, 9),
                 norm_layer=nn.LayerNorm,
                 n_vecs=None,
                 t2t_params=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.expand_size = tuple(i // 2 for i in window_size)  # TODO
        self.mlp_ratio = mlp_ratio
        self.pool_method = pool_method
        self.focal_level = focal_level
        self.focal_window = focal_window

        self.window_size_glo = self.window_size

        self.pool_layers = nn.ModuleList()
        if self.pool_method != "none":
            for k in range(self.focal_level - 1):
                window_size_glo = tuple(
                    math.floor(i / (2**k)) for i in self.window_size_glo)
                self.pool_layers.append(
                    nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
                self.pool_layers[-1].weight.data.fill_(
                    1. / (window_size_glo[0] * window_size_glo[1]))
                self.pool_layers[-1].bias.data.fill_(0)

        self.norm1 = norm_layer(dim)

        self.attn = WindowAttention(dim,
                                    expand_size=self.expand_size,
                                    window_size=self.window_size,
                                    focal_window=focal_window,
                                    focal_level=focal_level,
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    pool_method=pool_method)

        self.norm2 = norm_layer(dim)
        self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)

    def forward(self, x):
        B, T, H, W, C = x.shape

        shortcut = x
        x = self.norm1(x)

        shifted_x = x

        x_windows_all = [shifted_x]
        x_window_masks_all = [None]

        # partition windows; expand_size is tuple(i // 2 for i in window_size)
        if self.focal_level > 1 and self.pool_method != "none":
            # if we add coarser granularity and the pool method is not none
            for k in range(self.focal_level - 1):
                window_size_glo = tuple(
                    math.floor(i / (2**k)) for i in self.window_size_glo)
                pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
                pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
                H_pool = pooled_h * window_size_glo[0]
                W_pool = pooled_w * window_size_glo[1]

                x_level_k = shifted_x
                # trim or pad shifted_x depending on the required size
                if H > H_pool:
                    trim_t = (H - H_pool) // 2
                    trim_b = H - H_pool - trim_t
                    x_level_k = x_level_k[:, :, trim_t:-trim_b]
                elif H < H_pool:
                    pad_t = (H_pool - H) // 2
                    pad_b = H_pool - H - pad_t
                    x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))

                if W > W_pool:
                    trim_l = (W - W_pool) // 2
                    trim_r = W - W_pool - trim_l
                    x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
                elif W < W_pool:
                    pad_l = (W_pool - W) // 2
                    pad_r = W_pool - W - pad_l
                    x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))

                x_windows_noreshape = window_partition_noreshape(
                    x_level_k.contiguous(), window_size_glo
                )  # B, nw, nw, T, window_size, window_size, C
                nWh, nWw = x_windows_noreshape.shape[1:3]
                x_windows_noreshape = x_windows_noreshape.view(
                    B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
                    C).transpose(4, 5)  # B, nWh, nWw, T, C, wsize**2
                x_windows_pooled = self.pool_layers[k](
                    x_windows_noreshape).flatten(-2)  # B, nWh, nWw, T, C

                x_windows_all += [x_windows_pooled]
                x_window_masks_all += [None]

        attn_windows = self.attn(
            x_windows_all,
            mask_all=x_window_masks_all)  # nW*B, T*window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, T, self.window_size[0],
                                         self.window_size[1], C)
        shifted_x = window_reverse(attn_windows, self.window_size, T, H,
                                   W)  # B T H' W' C

        # FFN
        x = shortcut + shifted_x
        y = self.norm2(x)
        x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C)

        return x
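Note: `window_partition` and `window_reverse` above are exact inverses whenever H and W are divisible by the window size, which the block's forward relies on. A quick self-contained check, with illustrative shapes:

import torch

x = torch.randn(2, 5, 15, 27, 8)       # B, T, H, W, C; (15, 27) divisible by (5, 9)
win = window_partition(x, (5, 9))      # (B*nW, T*5*9, C) flattened token windows
win = win.view(-1, 5, 5, 9, 8)         # (nW*B, T, 5, 9, C), as done in forward()
assert torch.equal(window_reverse(win, (5, 9), 5, 15, 27), x)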
565
inpainter/model/modules/tfocal_transformer_hq.py
Normal file
@@ -0,0 +1,565 @@
"""
This code is based on:
[1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021
    https://github.com/ruiliu-ai/FuseFormer
[2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021
    https://github.com/yitu-opensource/T2T-ViT
[3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021
    https://github.com/microsoft/Focal-Transformer
"""

import math
from functools import reduce

import torch
import torch.nn as nn
import torch.nn.functional as F


class SoftSplit(nn.Module):
    def __init__(self, channel, hidden, kernel_size, stride, padding,
                 t2t_param):
        super(SoftSplit, self).__init__()
        self.kernel_size = kernel_size
        self.t2t = nn.Unfold(kernel_size=kernel_size,
                             stride=stride,
                             padding=padding)
        c_in = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(c_in, hidden)

        self.t2t_param = t2t_param

    def forward(self, x, b, output_size):
        f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] -
                   (self.t2t_param['kernel_size'][0] - 1) - 1) /
                  self.t2t_param['stride'][0] + 1)
        f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] -
                   (self.t2t_param['kernel_size'][1] - 1) - 1) /
                  self.t2t_param['stride'][1] + 1)

        feat = self.t2t(x)
        feat = feat.permute(0, 2, 1)
        # feat shape [b*t, num_vec, ks*ks*c]
        feat = self.embedding(feat)
        # feat shape after embedding [b, t*num_vec, hidden]
        feat = feat.view(b, -1, f_h, f_w, feat.size(2))
        return feat


class SoftComp(nn.Module):
    def __init__(self, channel, hidden, kernel_size, stride, padding):
        super(SoftComp, self).__init__()
        self.relu = nn.LeakyReLU(0.2, inplace=True)
        c_out = reduce((lambda x, y: x * y), kernel_size) * channel
        self.embedding = nn.Linear(hidden, c_out)
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.bias_conv = nn.Conv2d(channel,
                                   channel,
                                   kernel_size=3,
                                   stride=1,
                                   padding=1)
        # TODO upsample conv
        # self.bias_conv = nn.Conv2d()
        # self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True)

    def forward(self, x, t, output_size):
        b_, _, _, _, c_ = x.shape
        x = x.view(b_, -1, c_)
        feat = self.embedding(x)
        b, _, c = feat.size()
        feat = feat.view(b * t, -1, c).permute(0, 2, 1)
        feat = F.fold(feat,
                      output_size=output_size,
                      kernel_size=self.kernel_size,
                      stride=self.stride,
                      padding=self.padding)
        feat = self.bias_conv(feat)
        return feat


class FusionFeedForward(nn.Module):
    def __init__(self, d_model, n_vecs=None, t2t_params=None):
        super(FusionFeedForward, self).__init__()
        # We set d_ff to a default of 1960
        hd = 1960
        self.conv1 = nn.Sequential(nn.Linear(d_model, hd))
        self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model))
        assert t2t_params is not None and n_vecs is not None
        self.t2t_params = t2t_params

    def forward(self, x, output_size):
        n_vecs = 1
        for i, d in enumerate(self.t2t_params['kernel_size']):
            n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
                           (d - 1) - 1) / self.t2t_params['stride'][i] + 1)

        x = self.conv1(x)
        b, n, c = x.size()
        normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1)
        normalizer = F.fold(normalizer,
                            output_size=output_size,
                            kernel_size=self.t2t_params['kernel_size'],
                            padding=self.t2t_params['padding'],
                            stride=self.t2t_params['stride'])

        x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
                   output_size=output_size,
                   kernel_size=self.t2t_params['kernel_size'],
                   padding=self.t2t_params['padding'],
                   stride=self.t2t_params['stride'])

        x = F.unfold(x / normalizer,
                     kernel_size=self.t2t_params['kernel_size'],
                     padding=self.t2t_params['padding'],
                     stride=self.t2t_params['stride']).permute(
                         0, 2, 1).contiguous().view(b, n, c)
        x = self.conv2(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: shape is (B, T, H, W, C)
        window_size (tuple[int]): window size
    Returns:
        windows: (B*num_windows, T*window_size*window_size, C)
    """
    B, T, H, W, C = x.shape
    x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
               window_size[1], C)
    windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view(
        -1, T * window_size[0] * window_size[1], C)
    return windows


def window_partition_noreshape(x, window_size):
    """
    Args:
        x: shape is (B, T, H, W, C)
        window_size (tuple[int]): window size
    Returns:
        windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C)
    """
    B, T, H, W, C = x.shape
    x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1],
               window_size[1], C)
    windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous()
    return windows


def window_reverse(windows, window_size, T, H, W):
    """
    Args:
        windows: shape is (num_windows*B, T, window_size, window_size, C)
        window_size (tuple[int]): Window size
        T (int): Temporal length of video
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, T, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
    x = windows.view(B, H // window_size[0], W // window_size[1], T,
                     window_size[0], window_size[1], -1)
    x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1)
    return x


class WindowAttention(nn.Module):
    """Temporal focal window attention
    """
    def __init__(self, dim, expand_size, window_size, focal_window,
                 focal_level, num_heads, qkv_bias, pool_method):

        super().__init__()
        self.dim = dim
        self.expand_size = expand_size
        self.window_size = window_size  # Wh, Ww
        self.pool_method = pool_method
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.focal_level = focal_level
        self.focal_window = focal_window

        if any(i > 0 for i in self.expand_size) and focal_level > 0:
            # get mask for rolled k and rolled v
            mask_tl = torch.ones(self.window_size[0], self.window_size[1])
            mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
            mask_tr = torch.ones(self.window_size[0], self.window_size[1])
            mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
            mask_bl = torch.ones(self.window_size[0], self.window_size[1])
            mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
            mask_br = torch.ones(self.window_size[0], self.window_size[1])
            mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
            mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br),
                                      0).flatten(0)
            self.register_buffer("valid_ind_rolled",
                                 mask_rolled.nonzero(as_tuple=False).view(-1))

        if pool_method != "none" and focal_level > 1:
            self.unfolds = nn.ModuleList()

            # build relative position bias between local patch and pooled windows
            for k in range(focal_level - 1):
                stride = 2**k
                kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1)
                                    for i in self.focal_window)
                # define unfolding operations
                self.unfolds += [
                    nn.Unfold(kernel_size=kernel_size,
                              stride=stride,
                              padding=tuple(i // 2 for i in kernel_size))
                ]

                # define unfolding index for focal_level > 0
                if k > 0:
                    mask = torch.zeros(kernel_size)
                    mask[(2**k) - 1:, (2**k) - 1:] = 1
                    self.register_buffer(
                        "valid_ind_unfold_{}".format(k),
                        mask.flatten(0).nonzero(as_tuple=False).view(-1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.proj = nn.Linear(dim, dim)

        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x_all, mask_all=None):
        """
        Args:
            x_all: input features with shape of (B, T, Wh, Ww, C)
            mask_all: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None

        output: (nW*B, Wh*Ww, C)
        """
        x = x_all[0]

        B, T, nH, nW, C = x.shape
        qkv = self.qkv(x).reshape(B, T, nH, nW, 3,
                                  C).permute(4, 0, 1, 2, 3, 5).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]  # B, T, nH, nW, C

        # partition q map
        (q_windows, k_windows, v_windows) = map(
            lambda t: window_partition(t, self.window_size).view(
                -1, T, self.window_size[0] * self.window_size[1], self.
                num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).
            contiguous().view(-1, self.num_heads, T * self.window_size[
                0] * self.window_size[1], C // self.num_heads), (q, k, v))
        # q(k/v)_windows shape : [16, 4, 225, 128]

        if any(i > 0 for i in self.expand_size) and self.focal_level > 0:
            (k_tl, v_tl) = map(
                lambda t: torch.roll(t,
                                     shifts=(-self.expand_size[0], -self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))
            (k_tr, v_tr) = map(
                lambda t: torch.roll(t,
                                     shifts=(-self.expand_size[0], self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))
            (k_bl, v_bl) = map(
                lambda t: torch.roll(t,
                                     shifts=(self.expand_size[0], -self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))
            (k_br, v_br) = map(
                lambda t: torch.roll(t,
                                     shifts=(self.expand_size[0], self.
                                             expand_size[1]),
                                     dims=(2, 3)), (k, v))

            (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
                lambda t: window_partition(t, self.window_size).view(
                    -1, T, self.window_size[0] * self.window_size[1], self.
                    num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br))
            (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
                lambda t: window_partition(t, self.window_size).view(
                    -1, T, self.window_size[0] * self.window_size[1], self.
                    num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br))
            k_rolled = torch.cat(
                (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows),
                2).permute(0, 3, 1, 2, 4).contiguous()
            v_rolled = torch.cat(
                (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows),
                2).permute(0, 3, 1, 2, 4).contiguous()

            # mask out tokens in current window
            k_rolled = k_rolled[:, :, :, self.valid_ind_rolled]
            v_rolled = v_rolled[:, :, :, self.valid_ind_rolled]
            temp_N = k_rolled.shape[3]
            k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N,
                                     C // self.num_heads)
            v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N,
                                     C // self.num_heads)
            k_rolled = torch.cat((k_windows, k_rolled), 2)
            v_rolled = torch.cat((v_windows, v_rolled), 2)
        else:
            k_rolled = k_windows
            v_rolled = v_windows

        # q(k/v)_windows shape : [16, 4, 225, 128]
        # k_rolled.shape : [16, 4, 5, 165, 128]
        # ideal expanded window size 153 ((5+2*2)*(9+2*4))
        # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2)

        if self.pool_method != "none" and self.focal_level > 1:
            k_pooled = []
            v_pooled = []
            for k in range(self.focal_level - 1):
                stride = 2**k
                # B, T, nWh, nWw, C
                x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2,
                                                       4).contiguous()

                nWh, nWw = x_window_pooled.shape[2:4]

                # generate mask for pooled windows
                mask = x_window_pooled.new(T, nWh, nWw).fill_(1)
                # unfold mask: [nWh*nWw//s//s, k*k, 1]
                unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view(
                    1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\
                    view(nWh*nWw // stride // stride, -1, 1)

                if k > 0:
                    valid_ind_unfold_k = getattr(
                        self, "valid_ind_unfold_{}".format(k))
                    unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]

                x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
                x_window_masks = x_window_masks.masked_fill(
                    x_window_masks == 0,
                    float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
                mask_all[k + 1] = x_window_masks

                # generate k and v for pooled windows
                qkv_pooled = self.qkv(x_window_pooled).reshape(
                    B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2,
                                                  3).view(3, -1, C, nWh,
                                                          nWw).contiguous()
                # B*T, C, nWh, nWw
                k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]
                # k_pooled_k shape: [5, 512, 4, 4]
                # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9), 16]

                (k_pooled_k, v_pooled_k) = map(
                    lambda t: self.unfolds[k]
                    (t).view(B, T, C, self.unfolds[k].kernel_size[0], self.
                             unfolds[k].kernel_size[1], -1)
                    .permute(0, 5, 1, 3, 4, 2).contiguous().view(
                        -1, T, self.unfolds[k].kernel_size[0] * self.unfolds[
                            k].kernel_size[1], self.num_heads, C // self.
                        num_heads).permute(0, 3, 1, 2, 4).contiguous(),
                    # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim
                    (k_pooled_k, v_pooled_k))
                # k_pooled_k shape : [16, 4, 5, 45, 128]

                # select valid unfolding index
                if k > 0:
                    (k_pooled_k, v_pooled_k) = map(
                        lambda t: t[:, :, :, valid_ind_unfold_k],
                        (k_pooled_k, v_pooled_k))

                k_pooled_k = k_pooled_k.view(
                    -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
                    self.unfolds[k].kernel_size[1], C // self.num_heads)
                v_pooled_k = v_pooled_k.view(
                    -1, self.num_heads, T * self.unfolds[k].kernel_size[0] *
                    self.unfolds[k].kernel_size[1], C // self.num_heads)

                k_pooled += [k_pooled_k]
                v_pooled += [v_pooled_k]

            # k_all (v_all) shape : [16, 4, 5 * 210, 128]
            k_all = torch.cat([k_rolled] + k_pooled, 2)
            v_all = torch.cat([v_rolled] + v_pooled, 2)
        else:
            k_all = k_rolled
            v_all = v_rolled

        N = k_all.shape[-2]
        q_windows = q_windows * self.scale
        # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size
        attn = (q_windows @ k_all.transpose(-2, -1))
        # T * 45
        window_area = T * self.window_size[0] * self.window_size[1]
        # T * 165
        window_area_rolled = k_rolled.shape[2]

        if self.pool_method != "none" and self.focal_level > 1:
            offset = window_area_rolled
            for k in range(self.focal_level - 1):
                # add attentional mask
                # mask_all[1] shape [1, 16, T * 45]

                bias = tuple((i + 2**k - 1) for i in self.focal_window)

                if mask_all[k + 1] is not None:
                    attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \
                        attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \
                        mask_all[k+1][:, :, None, None, :].repeat(
                            attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])

                offset += T * bias[0] * bias[1]

        if mask_all[0] is not None:
            nW = mask_all[0].shape[0]
            attn = attn.view(attn.shape[0] // nW, nW, self.num_heads,
                             window_area, N)
            attn[:, :, :, :, :
                 window_area] = attn[:, :, :, :, :window_area] + mask_all[0][
                     None, :, None, :, :]
            attn = attn.view(-1, self.num_heads, window_area, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area,
                                                   C)
        x = self.proj(x)
        return x


class TemporalFocalTransformerBlock(nn.Module):
    r""" Temporal Focal Transformer Block.
    Args:
        dim (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): Window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        focal_level (int): Number of focal levels.
        focal_window (tuple[int]): Window size of each focal window.
        n_vecs (int): Required for F3N.
        t2t_params (dict): T2T parameters for F3N.
    """
    def __init__(self,
                 dim,
                 num_heads,
                 window_size=(5, 9),
                 mlp_ratio=4.,
                 qkv_bias=True,
                 pool_method="fc",
                 focal_level=2,
                 focal_window=(5, 9),
                 norm_layer=nn.LayerNorm,
                 n_vecs=None,
                 t2t_params=None):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.expand_size = tuple(i // 2 for i in window_size)  # TODO
        self.mlp_ratio = mlp_ratio
        self.pool_method = pool_method
        self.focal_level = focal_level
        self.focal_window = focal_window

        self.window_size_glo = self.window_size

        self.pool_layers = nn.ModuleList()
        if self.pool_method != "none":
            for k in range(self.focal_level - 1):
                window_size_glo = tuple(
                    math.floor(i / (2**k)) for i in self.window_size_glo)
                self.pool_layers.append(
                    nn.Linear(window_size_glo[0] * window_size_glo[1], 1))
                self.pool_layers[-1].weight.data.fill_(
                    1. / (window_size_glo[0] * window_size_glo[1]))
                self.pool_layers[-1].bias.data.fill_(0)

        self.norm1 = norm_layer(dim)

        self.attn = WindowAttention(dim,
                                    expand_size=self.expand_size,
                                    window_size=self.window_size,
                                    focal_window=focal_window,
                                    focal_level=focal_level,
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    pool_method=pool_method)

        self.norm2 = norm_layer(dim)
        self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params)

    def forward(self, x):
        output_size = x[1]
        x = x[0]

        B, T, H, W, C = x.shape

        shortcut = x
        x = self.norm1(x)

        shifted_x = x

        x_windows_all = [shifted_x]
        x_window_masks_all = [None]

        # partition windows; expand_size is tuple(i // 2 for i in window_size)
        if self.focal_level > 1 and self.pool_method != "none":
            # if we add coarser granularity and the pool method is not none
            for k in range(self.focal_level - 1):
                window_size_glo = tuple(
                    math.floor(i / (2**k)) for i in self.window_size_glo)
                pooled_h = math.ceil(H / window_size_glo[0]) * (2**k)
                pooled_w = math.ceil(W / window_size_glo[1]) * (2**k)
                H_pool = pooled_h * window_size_glo[0]
                W_pool = pooled_w * window_size_glo[1]

                x_level_k = shifted_x
                # trim or pad shifted_x depending on the required size
                if H > H_pool:
                    trim_t = (H - H_pool) // 2
                    trim_b = H - H_pool - trim_t
                    x_level_k = x_level_k[:, :, trim_t:-trim_b]
                elif H < H_pool:
                    pad_t = (H_pool - H) // 2
                    pad_b = H_pool - H - pad_t
                    x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))

                if W > W_pool:
                    trim_l = (W - W_pool) // 2
                    trim_r = W - W_pool - trim_l
                    x_level_k = x_level_k[:, :, :, trim_l:-trim_r]
                elif W < W_pool:
                    pad_l = (W_pool - W) // 2
                    pad_r = W_pool - W - pad_l
                    x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))

                x_windows_noreshape = window_partition_noreshape(
                    x_level_k.contiguous(), window_size_glo
                )  # B, nw, nw, T, window_size, window_size, C
                nWh, nWw = x_windows_noreshape.shape[1:3]
                x_windows_noreshape = x_windows_noreshape.view(
                    B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1],
                    C).transpose(4, 5)  # B, nWh, nWw, T, C, wsize**2
                x_windows_pooled = self.pool_layers[k](
                    x_windows_noreshape).flatten(-2)  # B, nWh, nWw, T, C

                x_windows_all += [x_windows_pooled]
                x_window_masks_all += [None]

        # nW*B, T*window_size*window_size, C
        attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all)

        # merge windows
        attn_windows = attn_windows.view(-1, T, self.window_size[0],
                                         self.window_size[1], C)
        shifted_x = window_reverse(attn_windows, self.window_size, T, H,
                                   W)  # B T H' W' C

        # FFN
        x = shortcut + shifted_x
        y = self.norm2(x)
        x = x + self.mlp(y.view(B, T * H * W, C), output_size).view(
            B, T, H, W, C)

        return x, output_size
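Note: this `_hq` variant differs from `tfocal_transformer.py` mainly in that the token-grid size is recomputed per input resolution (in `SoftSplit.forward` and `FusionFeedForward.forward`) instead of being fixed at construction, which is what gives the hq checkpoint its multi-resolution support. The recomputed grid matches `nn.Unfold`'s output length; a small check with illustrative numbers (not values taken from this commit's config):

import torch
import torch.nn as nn

output_size, ks, st, pd = (60, 108), (7, 7), (3, 3), (3, 3)  # illustrative
f_h = (output_size[0] + 2 * pd[0] - (ks[0] - 1) - 1) // st[0] + 1  # 20
f_w = (output_size[1] + 2 * pd[1] - (ks[1] - 1) - 1) // st[1] + 1  # 34
x = torch.randn(2, 16, *output_size)   # b*t frames with 16 channels
n_vec = nn.Unfold(ks, stride=st, padding=pd)(x).shape[-1]
assert n_vec == f_h * f_w              # 680 tokens per frame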
0
inpainter/util/__init__.py
Normal file
24
inpainter/util/tensor_util.py
Normal file
@@ -0,0 +1,24 @@
import cv2
import numpy as np


# resize frames
def resize_frames(frames, size=None):
    """
    size: (w, h)
    """
    if size is not None:
        frames = [cv2.resize(f, size) for f in frames]
        frames = np.stack(frames, 0)

    return frames


# resize masks
def resize_masks(masks, size=None):
    """
    size: (w, h)
    """
    if size is not None:
        masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks]
        masks = np.stack(masks, 0)

    return masks
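Note: both helpers are no-ops when `size` is None, and `size` follows OpenCV's (w, h) order while the arrays are laid out (T, H, W, ...). A quick shape check with illustrative inputs:

import numpy as np

frames = np.zeros((5, 100, 200, 3), dtype=np.uint8)  # T, H, W, 3
masks = np.zeros((5, 100, 200), dtype=np.uint8)      # T, H, W
assert resize_frames(frames, size=(100, 50)).shape == (5, 50, 100, 3)
assert resize_masks(masks, size=(100, 50)).shape == (5, 50, 100, 1)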