diff --git a/inpainter/base_inpainter.py b/inpainter/base_inpainter.py new file mode 100644 index 0000000..f3a473e --- /dev/null +++ b/inpainter/base_inpainter.py @@ -0,0 +1,160 @@ +import os +import glob +from PIL import Image + +import torch +import yaml +import cv2 +import importlib +import numpy as np +from util.tensor_util import resize_frames, resize_masks + + +class BaseInpainter: + def __init__(self, E2FGVI_checkpoint, device) -> None: + """ + E2FGVI_checkpoint: checkpoint of inpainter (version hq, with multi-resolution support) + """ + net = importlib.import_module('model.e2fgvi_hq') + self.model = net.InpaintGenerator().to(device) + self.model.load_state_dict(torch.load(E2FGVI_checkpoint, map_location=device)) + self.model.eval() + self.device = device + # load configurations + with open("inpainter/config/config.yaml", 'r') as stream: + config = yaml.safe_load(stream) + self.neighbor_stride = config['neighbor_stride'] + self.num_ref = config['num_ref'] + self.step = config['step'] + + # sample reference frames from the whole video + def get_ref_index(self, f, neighbor_ids, length): + ref_index = [] + if self.num_ref == -1: + for i in range(0, length, self.step): + if i not in neighbor_ids: + ref_index.append(i) + else: + start_idx = max(0, f - self.step * (self.num_ref // 2)) + end_idx = min(length, f + self.step * (self.num_ref // 2)) + for i in range(start_idx, end_idx + 1, self.step): + if i not in neighbor_ids: + if len(ref_index) > self.num_ref: + break + ref_index.append(i) + return ref_index + + def inpaint(self, frames, masks, dilate_radius=15, ratio=1): + """ + frames: numpy array, T, H, W, 3 + masks: numpy array, T, H, W + dilate_radius: radius when applying dilation on masks + ratio: down-sample ratio + + Output: + inpainted_frames: numpy array, T, H, W, 3 + """ + assert frames.shape[:3] == masks.shape, 'different size between frames and masks' + assert ratio > 0 and ratio <= 1, 'ratio must in (0, 1]' + masks = masks.copy() + masks = np.clip(masks, 0, 1) + kernel = cv2.getStructuringElement(2, (dilate_radius, dilate_radius)) + masks = np.stack([cv2.dilate(mask, kernel) for mask in masks], 0) + + T, H, W = masks.shape + # size: (w, h) + if ratio == 1: + size = None + else: + size = (int(W*ratio), int(H*ratio)) + + masks = np.expand_dims(masks, axis=3) # expand to T, H, W, 1 + binary_masks = resize_masks(masks, size) + frames = resize_frames(frames, size) # T, H, W, 3 + # frames and binary_masks are numpy arrays + + h, w = frames.shape[1:3] + video_length = T + + # convert to tensor + imgs = (torch.from_numpy(frames).permute(0, 3, 1, 2).contiguous().unsqueeze(0).float().div(255)) * 2 - 1 + masks = torch.from_numpy(binary_masks).permute(0, 3, 1, 2).contiguous().unsqueeze(0) + + imgs, masks = imgs.to(self.device), masks.to(self.device) + comp_frames = [None] * video_length + + for f in range(0, video_length, self.neighbor_stride): + neighbor_ids = [ + i for i in range(max(0, f - self.neighbor_stride), + min(video_length, f + self.neighbor_stride + 1)) + ] + ref_ids = self.get_ref_index(f, neighbor_ids, video_length) + selected_imgs = imgs[:1, neighbor_ids + ref_ids, :, :, :] + selected_masks = masks[:1, neighbor_ids + ref_ids, :, :, :] + with torch.no_grad(): + masked_imgs = selected_imgs * (1 - selected_masks) + mod_size_h = 60 + mod_size_w = 108 + h_pad = (mod_size_h - h % mod_size_h) % mod_size_h + w_pad = (mod_size_w - w % mod_size_w) % mod_size_w + masked_imgs = torch.cat( + [masked_imgs, torch.flip(masked_imgs, [3])], + 3)[:, :, :, :h + h_pad, :] + masked_imgs = 
torch.cat( + [masked_imgs, torch.flip(masked_imgs, [4])], + 4)[:, :, :, :, :w + w_pad] + pred_imgs, _ = self.model(masked_imgs, len(neighbor_ids)) + pred_imgs = pred_imgs[:, :, :h, :w] + pred_imgs = (pred_imgs + 1) / 2 + pred_imgs = pred_imgs.cpu().permute(0, 2, 3, 1).numpy() * 255 + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = pred_imgs[i].astype(np.uint8) * binary_masks[idx] + frames[idx] * ( + 1 - binary_masks[idx]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype( + np.float32) * 0.5 + img.astype(np.float32) * 0.5 + + inpainted_frames = np.stack(comp_frames, 0) + return inpainted_frames.astype(np.uint8) + +if __name__ == '__main__': + + frame_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/JPEGImages/480p/parkour', '*.jpg')) + frame_path.sort() + mask_path = glob.glob(os.path.join('/ssd1/gaomingqi/datasets/davis/Annotations/480p/parkour', "*.png")) + mask_path.sort() + save_path = '/ssd1/gaomingqi/results/inpainting/parkour' + + if not os.path.exists(save_path): + os.mkdir(save_path) + + frames = [] + masks = [] + for fid, mid in zip(frame_path, mask_path): + frames.append(Image.open(fid).convert('RGB')) + masks.append(Image.open(mid).convert('P')) + + frames = np.stack(frames, 0) + masks = np.stack(masks, 0) + + # ---------------------------------------------- + # how to use + # ---------------------------------------------- + # 1/3: set checkpoint and device + checkpoint = '/ssd1/gaomingqi/checkpoints/E2FGVI-HQ-CVPR22.pth' + device = 'cuda:2' + # 2/3: initialise inpainter + base_inpainter = BaseInpainter(checkpoint, device) + # 3/3: inpainting (frames: numpy array, T, H, W, 3; masks: numpy array, T, H, W) + # ratio: (0, 1], ratio for down sample, default value is 1 + inpainted_frames = base_inpainter.inpaint(frames, masks, ratio=0.5) # numpy array, T, H, W, 3 + # ---------------------------------------------- + # end + # ---------------------------------------------- + # save + for ti, inpainted_frame in enumerate(inpainted_frames): + frame = Image.fromarray(inpainted_frame).convert('RGB') + frame.save(os.path.join(save_path, f'{ti:05d}.jpg')) diff --git a/inpainter/config/config.yaml b/inpainter/config/config.yaml new file mode 100644 index 0000000..ef4c180 --- /dev/null +++ b/inpainter/config/config.yaml @@ -0,0 +1,4 @@ +# config info for E2FGVI +neighbor_stride: 5 +num_ref: -1 +step: 10 diff --git a/inpainter/model/e2fgvi.py b/inpainter/model/e2fgvi.py new file mode 100644 index 0000000..ea90b61 --- /dev/null +++ b/inpainter/model/e2fgvi.py @@ -0,0 +1,350 @@ +''' Towards An End-to-End Framework for Video Inpainting +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.modules.flow_comp import SPyNet +from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment +from model.modules.tfocal_transformer import TemporalFocalTransformerBlock, SoftSplit, SoftComp +from model.modules.spectral_norm import spectral_norm as _spectral_norm + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + 'Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' 
% + (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % + init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + self.group = [1, 2, 4, 8, 1] + self.layers = nn.ModuleList([ + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True) + ]) + + def forward(self, x): + bt, c, h, w = x.size() + h, w = h // 4, w // 4 + out = x + for i, layer in enumerate(self.layers): + if i == 8: + x0 = out + if i > 8 and i % 2 == 0: + g = self.group[(i - 8) // 2] + x = x0.view(bt, g, -1, h, w) + o = out.view(bt, g, -1, h, w) + out = torch.cat([x, o], 2).view(bt, -1, h, w) + out = layer(out) + return out + + +class deconv(nn.Module): + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=True) + return self.conv(x) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True): + super(InpaintGenerator, self).__init__() + channel = 256 + hidden = 512 + + # encoder + self.encoder = Encoder() + + # decoder + self.decoder 
= nn.Sequential( + deconv(channel // 2, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) + + # feature propagation module + self.feat_prop_module = BidirectionalPropagation(channel // 2) + + # soft split and soft composition + kernel_size = (7, 7) + padding = (3, 3) + stride = (3, 3) + output_size = (60, 108) + t2t_params = { + 'kernel_size': kernel_size, + 'stride': stride, + 'padding': padding, + 'output_size': output_size + } + self.ss = SoftSplit(channel // 2, + hidden, + kernel_size, + stride, + padding, + t2t_param=t2t_params) + self.sc = SoftComp(channel // 2, hidden, output_size, kernel_size, + stride, padding) + + n_vecs = 1 + for i, d in enumerate(kernel_size): + n_vecs *= int((output_size[i] + 2 * padding[i] - + (d - 1) - 1) / stride[i] + 1) + + blocks = [] + depths = 8 + num_heads = [4] * depths + window_size = [(5, 9)] * depths + focal_windows = [(5, 9)] * depths + focal_levels = [2] * depths + pool_method = "fc" + + for i in range(depths): + blocks.append( + TemporalFocalTransformerBlock(dim=hidden, + num_heads=num_heads[i], + window_size=window_size[i], + focal_level=focal_levels[i], + focal_window=focal_windows[i], + n_vecs=n_vecs, + t2t_params=t2t_params, + pool_method=pool_method)) + self.transformer = nn.Sequential(*blocks) + + if init_weights: + self.init_weights() + # Need to initial the weights of MSDeformAttn specifically + for m in self.modules(): + if isinstance(m, SecondOrderDeformableAlignment): + m.init_offset() + + # flow completion network + self.update_spynet = SPyNet() + + def forward_bidirect_flow(self, masked_local_frames): + b, l_t, c, h, w = masked_local_frames.size() + + # compute forward and backward flows of masked frames + masked_local_frames = F.interpolate(masked_local_frames.view( + -1, c, h, w), + scale_factor=1 / 4, + mode='bilinear', + align_corners=True, + recompute_scale_factor=True) + masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4, + w // 4) + mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape( + -1, c, h // 4, w // 4) + mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape( + -1, c, h // 4, w // 4) + pred_flows_forward = self.update_spynet(mlf_1, mlf_2) + pred_flows_backward = self.update_spynet(mlf_2, mlf_1) + + pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4, + w // 4) + pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4, + w // 4) + + return pred_flows_forward, pred_flows_backward + + def forward(self, masked_frames, num_local_frames): + l_t = num_local_frames + b, t, ori_c, ori_h, ori_w = masked_frames.size() + + # normalization before feeding into the flow completion module + masked_local_frames = (masked_frames[:, :l_t, ...] + 1) / 2 + pred_flows = self.forward_bidirect_flow(masked_local_frames) + + # extracting features and performing the feature propagation on local features + enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w)) + _, c, h, w = enc_feat.size() + local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...] + ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...] 
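+        # note: the flow-guided (deformable) propagation below is applied only to the local features; the reference features bypass it and are concatenated back afterwards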
+ local_feat = self.feat_prop_module(local_feat, pred_flows[0], + pred_flows[1]) + enc_feat = torch.cat((local_feat, ref_feat), dim=1) + + # content hallucination through stacking multiple temporal focal transformer blocks + trans_feat = self.ss(enc_feat.view(-1, c, h, w), b) + trans_feat = self.transformer(trans_feat) + trans_feat = self.sc(trans_feat, t) + trans_feat = trans_feat.view(b, t, -1, h, w) + enc_feat = enc_feat + trans_feat + + # decode frames from features + output = self.decoder(enc_feat.view(b * t, c, h, w)) + output = torch.tanh(output) + return output, pred_flows + + +# ###################################################################### +# Discriminator for Temporal Patch GAN +# ###################################################################### + + +class Discriminator(BaseNetwork): + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=1, + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 1, + nf * 2, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 2, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape (old) + # B, T, C, H, W (new) + xs_t = torch.transpose(xs, 1, 2) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/inpainter/model/e2fgvi_hq.py b/inpainter/model/e2fgvi_hq.py new file mode 100644 index 0000000..b01ba15 --- /dev/null +++ b/inpainter/model/e2fgvi_hq.py @@ -0,0 +1,350 @@ +''' Towards An End-to-End Framework for Video Inpainting +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model.modules.flow_comp import SPyNet +from model.modules.feat_prop import BidirectionalPropagation, SecondOrderDeformableAlignment +from model.modules.tfocal_transformer_hq import TemporalFocalTransformerBlock, SoftSplit, SoftComp +from model.modules.spectral_norm import spectral_norm as _spectral_norm + + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, 
self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + 'Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % + (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % + init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + self.group = [1, 2, 4, 8, 1] + self.layers = nn.ModuleList([ + nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True) + ]) + + def forward(self, x): + bt, c, _, _ = x.size() + # h, w = h//4, w//4 + out = x + for i, layer in enumerate(self.layers): + if i == 8: + x0 = out + _, _, h, w = x0.size() + if i > 8 and i % 2 == 0: + g = self.group[(i - 8) // 2] + x = x0.view(bt, g, -1, h, w) + o = out.view(bt, g, -1, h, w) + out = torch.cat([x, o], 2).view(bt, -1, h, w) + out = layer(out) + return out + + +class deconv(nn.Module): + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x 
= F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=True) + return self.conv(x) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True): + super(InpaintGenerator, self).__init__() + channel = 256 + hidden = 512 + + # encoder + self.encoder = Encoder() + + # decoder + self.decoder = nn.Sequential( + deconv(channel // 2, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) + + # feature propagation module + self.feat_prop_module = BidirectionalPropagation(channel // 2) + + # soft split and soft composition + kernel_size = (7, 7) + padding = (3, 3) + stride = (3, 3) + output_size = (60, 108) + t2t_params = { + 'kernel_size': kernel_size, + 'stride': stride, + 'padding': padding + } + self.ss = SoftSplit(channel // 2, + hidden, + kernel_size, + stride, + padding, + t2t_param=t2t_params) + self.sc = SoftComp(channel // 2, hidden, kernel_size, stride, padding) + + n_vecs = 1 + for i, d in enumerate(kernel_size): + n_vecs *= int((output_size[i] + 2 * padding[i] - + (d - 1) - 1) / stride[i] + 1) + + blocks = [] + depths = 8 + num_heads = [4] * depths + window_size = [(5, 9)] * depths + focal_windows = [(5, 9)] * depths + focal_levels = [2] * depths + pool_method = "fc" + + for i in range(depths): + blocks.append( + TemporalFocalTransformerBlock(dim=hidden, + num_heads=num_heads[i], + window_size=window_size[i], + focal_level=focal_levels[i], + focal_window=focal_windows[i], + n_vecs=n_vecs, + t2t_params=t2t_params, + pool_method=pool_method)) + self.transformer = nn.Sequential(*blocks) + + if init_weights: + self.init_weights() + # Need to initial the weights of MSDeformAttn specifically + for m in self.modules(): + if isinstance(m, SecondOrderDeformableAlignment): + m.init_offset() + + # flow completion network + self.update_spynet = SPyNet() + + def forward_bidirect_flow(self, masked_local_frames): + b, l_t, c, h, w = masked_local_frames.size() + + # compute forward and backward flows of masked frames + masked_local_frames = F.interpolate(masked_local_frames.view( + -1, c, h, w), + scale_factor=1 / 4, + mode='bilinear', + align_corners=True, + recompute_scale_factor=True) + masked_local_frames = masked_local_frames.view(b, l_t, c, h // 4, + w // 4) + mlf_1 = masked_local_frames[:, :-1, :, :, :].reshape( + -1, c, h // 4, w // 4) + mlf_2 = masked_local_frames[:, 1:, :, :, :].reshape( + -1, c, h // 4, w // 4) + pred_flows_forward = self.update_spynet(mlf_1, mlf_2) + pred_flows_backward = self.update_spynet(mlf_2, mlf_1) + + pred_flows_forward = pred_flows_forward.view(b, l_t - 1, 2, h // 4, + w // 4) + pred_flows_backward = pred_flows_backward.view(b, l_t - 1, 2, h // 4, + w // 4) + + return pred_flows_forward, pred_flows_backward + + def forward(self, masked_frames, num_local_frames): + l_t = num_local_frames + b, t, ori_c, ori_h, ori_w = masked_frames.size() + + # normalization before feeding into the flow completion module + masked_local_frames = (masked_frames[:, :l_t, ...] 
+ 1) / 2 + pred_flows = self.forward_bidirect_flow(masked_local_frames) + + # extracting features and performing the feature propagation on local features + enc_feat = self.encoder(masked_frames.view(b * t, ori_c, ori_h, ori_w)) + _, c, h, w = enc_feat.size() + fold_output_size = (h, w) + local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...] + ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...] + local_feat = self.feat_prop_module(local_feat, pred_flows[0], + pred_flows[1]) + enc_feat = torch.cat((local_feat, ref_feat), dim=1) + + # content hallucination through stacking multiple temporal focal transformer blocks + trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_output_size) + trans_feat = self.transformer([trans_feat, fold_output_size]) + trans_feat = self.sc(trans_feat[0], t, fold_output_size) + trans_feat = trans_feat.view(b, t, -1, h, w) + enc_feat = enc_feat + trans_feat + + # decode frames from features + output = self.decoder(enc_feat.view(b * t, c, h, w)) + output = torch.tanh(output) + return output, pred_flows + + +# ###################################################################### +# Discriminator for Temporal Patch GAN +# ###################################################################### + + +class Discriminator(BaseNetwork): + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=1, + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 1, + nf * 2, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 2, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape (old) + # B, T, C, H, W (new) + xs_t = torch.transpose(xs, 1, 2) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/inpainter/model/modules/feat_prop.py b/inpainter/model/modules/feat_prop.py new file mode 100644 index 0000000..9b9144c --- /dev/null +++ b/inpainter/model/modules/feat_prop.py @@ -0,0 +1,149 @@ +""" + BasicVSR++: Improving Video Super-Resolution 
with Enhanced Propagation and Alignment, CVPR 2022 +""" +import torch +import torch.nn as nn + +from mmcv.ops import ModulatedDeformConv2d, modulated_deform_conv2d +from mmengine.model import constant_init + +from model.modules.flow_comp import flow_warp + + +class SecondOrderDeformableAlignment(ModulatedDeformConv2d): + """Second-order deformable alignment module.""" + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels + 4, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), + ) + + self.init_offset() + + def init_offset(self): + constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat, flow_1, flow_2): + extra_feat = torch.cat([extra_feat, flow_1, flow_2], dim=1) + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh( + torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset_1 = offset_1 + flow_1.flip(1).repeat(1, + offset_1.size(1) // 2, 1, + 1) + offset_2 = offset_2 + flow_2.flip(1).repeat(1, + offset_2.size(1) // 2, 1, + 1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, + self.stride, self.padding, + self.dilation, self.groups, + self.deform_groups) + + +class BidirectionalPropagation(nn.Module): + def __init__(self, channel): + super(BidirectionalPropagation, self).__init__() + modules = ['backward_', 'forward_'] + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + self.channel = channel + + for i, module in enumerate(modules): + self.deform_align[module] = SecondOrderDeformableAlignment( + 2 * channel, channel, 3, padding=1, deform_groups=16) + + self.backbone[module] = nn.Sequential( + nn.Conv2d((2 + i) * channel, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0) + + def forward(self, x, flows_backward, flows_forward): + """ + x shape : [b, t, c, h, w] + return [b, t, c, h, w] + """ + b, t, c, h, w = x.shape + feats = {} + feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)] + + for module_name in ['backward_', 'forward_']: + + feats[module_name] = [] + + frame_idx = range(0, t) + flow_idx = range(-1, t - 1) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + if 'backward' in module_name: + frame_idx = frame_idx[::-1] + flows = flows_backward + else: + flows = flows_forward + + feat_prop = x.new_zeros(b, self.channel, h, w) + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + + if i > 0: + flow_n1 = flows[:, flow_idx[i], :, :, :] + cond_n1 = flow_warp(feat_prop, flow_n1.permute(0, 2, 3, 1)) + + # initialize second-order features + feat_n2 = torch.zeros_like(feat_prop) + flow_n2 = torch.zeros_like(flow_n1) + cond_n2 = torch.zeros_like(cond_n1) + if i > 1: + feat_n2 = 
feats[module_name][-2] + flow_n2 = flows[:, flow_idx[i - 1], :, :, :] + flow_n2 = flow_n1 + flow_warp( + flow_n2, flow_n1.permute(0, 2, 3, 1)) + cond_n2 = flow_warp(feat_n2, + flow_n2.permute(0, 2, 3, 1)) + + cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) + feat_prop = torch.cat([feat_prop, feat_n2], dim=1) + feat_prop = self.deform_align[module_name](feat_prop, cond, + flow_n1, + flow_n2) + + feat = [feat_current] + [ + feats[k][idx] + for k in feats if k not in ['spatial', module_name] + ] + [feat_prop] + + feat = torch.cat(feat, dim=1) + feat_prop = feat_prop + self.backbone[module_name](feat) + feats[module_name].append(feat_prop) + + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + + outputs = [] + for i in range(0, t): + align_feats = [feats[k].pop(0) for k in feats if k != 'spatial'] + align_feats = torch.cat(align_feats, dim=1) + outputs.append(self.fusion(align_feats)) + + return torch.stack(outputs, dim=1) + x diff --git a/inpainter/model/modules/flow_comp.py b/inpainter/model/modules/flow_comp.py new file mode 100644 index 0000000..d3abf2f --- /dev/null +++ b/inpainter/model/modules/flow_comp.py @@ -0,0 +1,450 @@ +import numpy as np + +import torch.nn as nn +import torch.nn.functional as F +import torch + +from mmcv.cnn import ConvModule +from mmengine.runner import load_checkpoint + + +class FlowCompletionLoss(nn.Module): + """Flow completion loss""" + def __init__(self): + super().__init__() + self.fix_spynet = SPyNet() + for p in self.fix_spynet.parameters(): + p.requires_grad = False + + self.l1_criterion = nn.L1Loss() + + def forward(self, pred_flows, gt_local_frames): + b, l_t, c, h, w = gt_local_frames.size() + + with torch.no_grad(): + # compute gt forward and backward flows + gt_local_frames = F.interpolate(gt_local_frames.view(-1, c, h, w), + scale_factor=1 / 4, + mode='bilinear', + align_corners=True, + recompute_scale_factor=True) + gt_local_frames = gt_local_frames.view(b, l_t, c, h // 4, w // 4) + gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape( + -1, c, h // 4, w // 4) + gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape( + -1, c, h // 4, w // 4) + gt_flows_forward = self.fix_spynet(gtlf_1, gtlf_2) + gt_flows_backward = self.fix_spynet(gtlf_2, gtlf_1) + + # calculate loss for flow completion + forward_flow_loss = self.l1_criterion( + pred_flows[0].view(-1, 2, h // 4, w // 4), gt_flows_forward) + backward_flow_loss = self.l1_criterion( + pred_flows[1].view(-1, 2, h // 4, w // 4), gt_flows_backward) + flow_loss = forward_flow_loss + backward_flow_loss + + return flow_loss + + +class SPyNet(nn.Module): + """SPyNet network structure. + The difference to the SPyNet in [tof.py] is that + 1. more SPyNetBasicModule is used in this version, and + 2. no batch normalization is used in this version. + Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + Args: + pretrained (str): path for pre-trained SPyNet. Default: None. 
+ """ + def __init__( + self, + use_pretrain=True, + pretrained='https://download.openmmlab.com/mmediting/restorers/basicvsr/spynet_20210409-c6c1bd09.pth' + ): + super().__init__() + + self.basic_module = nn.ModuleList( + [SPyNetBasicModule() for _ in range(6)]) + + if use_pretrain: + if isinstance(pretrained, str): + print("load pretrained SPyNet...") + load_checkpoint(self, pretrained, strict=True) + elif pretrained is not None: + raise TypeError('[pretrained] should be str or None, ' + f'but got {type(pretrained)}.') + + self.register_buffer( + 'mean', + torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + self.register_buffer( + 'std', + torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def compute_flow(self, ref, supp): + """Compute flow from ref to supp. + Note that in this function, the images are already resized to a + multiple of 32. + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). + """ + n, _, h, w = ref.size() + + # normalize the input images + ref = [(ref - self.mean) / self.std] + supp = [(supp - self.mean) / self.std] + + # generate downsampled frames + for level in range(5): + ref.append( + F.avg_pool2d(input=ref[-1], + kernel_size=2, + stride=2, + count_include_pad=False)) + supp.append( + F.avg_pool2d(input=supp[-1], + kernel_size=2, + stride=2, + count_include_pad=False)) + ref = ref[::-1] + supp = supp[::-1] + + # flow computation + flow = ref[0].new_zeros(n, 2, h // 32, w // 32) + for level in range(len(ref)): + if level == 0: + flow_up = flow + else: + flow_up = F.interpolate(input=flow, + scale_factor=2, + mode='bilinear', + align_corners=True) * 2.0 + + # add the residue to the upsampled flow + flow = flow_up + self.basic_module[level](torch.cat([ + ref[level], + flow_warp(supp[level], + flow_up.permute(0, 2, 3, 1).contiguous(), + padding_mode='border'), flow_up + ], 1)) + + return flow + + def forward(self, ref, supp): + """Forward function of SPyNet. + This function computes the optical flow from ref to supp. + Args: + ref (Tensor): Reference image with shape of (n, 3, h, w). + supp (Tensor): Supporting image with shape of (n, 3, h, w). + Returns: + Tensor: Estimated optical flow: (n, 2, h, w). + """ + + # upsize to a multiple of 32 + h, w = ref.shape[2:4] + w_up = w if (w % 32) == 0 else 32 * (w // 32 + 1) + h_up = h if (h % 32) == 0 else 32 * (h // 32 + 1) + ref = F.interpolate(input=ref, + size=(h_up, w_up), + mode='bilinear', + align_corners=False) + supp = F.interpolate(input=supp, + size=(h_up, w_up), + mode='bilinear', + align_corners=False) + + # compute flow, and resize back to the original resolution + flow = F.interpolate(input=self.compute_flow(ref, supp), + size=(h, w), + mode='bilinear', + align_corners=False) + + # adjust the flow values + flow[:, 0, :, :] *= float(w) / float(w_up) + flow[:, 1, :, :] *= float(h) / float(h_up) + + return flow + + +class SPyNetBasicModule(nn.Module): + """Basic Module for SPyNet. 
+ Paper: + Optical Flow Estimation using a Spatial Pyramid Network, CVPR, 2017 + """ + def __init__(self): + super().__init__() + + self.basic_module = nn.Sequential( + ConvModule(in_channels=8, + out_channels=32, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule(in_channels=32, + out_channels=64, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule(in_channels=64, + out_channels=32, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule(in_channels=32, + out_channels=16, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=dict(type='ReLU')), + ConvModule(in_channels=16, + out_channels=2, + kernel_size=7, + stride=1, + padding=3, + norm_cfg=None, + act_cfg=None)) + + def forward(self, tensor_input): + """ + Args: + tensor_input (Tensor): Input tensor with shape (b, 8, h, w). + 8 channels contain: + [reference image (3), neighbor image (3), initial flow (2)]. + Returns: + Tensor: Refined flow with shape (b, 2, h, w) + """ + return self.basic_module(tensor_input) + + +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) + col = col + RY + # YG + colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) + colorwheel[col:col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col:col + GC, 1] = 255 + colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) + col = col + GC + # CB + colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) + colorwheel[col:col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col:col + BM, 2] = 255 + colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) + col = col + BM + # MR + colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) + colorwheel[col:col + MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u) / np.pi + fk = (a + 1) / 2 * (ncols - 1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:, i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1 - col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2 - i if convert_to_bgr else i + flow_image[:, :, ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:, :, 0] + v = flow_uv[:, :, 1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) + + +def flow_warp(x, + flow, + interpolation='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or a feature map with optical flow. + Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is + a two-channel, denoting the width and height relative offsets. + Note that the values are not normalized to [-1, 1]. + interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. + Default: 'bilinear'. + padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Whether align corners. Default: True. + Returns: + Tensor: Warped image or feature map. 
+ """ + if x.size()[-2:] != flow.size()[1:3]: + raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' + f'flow ({flow.size()[1:3]}) are not the same.') + _, _, h, w = x.size() + # create mesh grid + grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) + grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) + grid.requires_grad = False + + grid_flow = grid + flow + # scale grid_flow to [-1,1] + grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 + grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 + grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) + output = F.grid_sample(x, + grid_flow, + mode=interpolation, + padding_mode=padding_mode, + align_corners=align_corners) + return output + + +def initial_mask_flow(mask): + """ + mask 1 indicates valid pixel 0 indicates unknown pixel + """ + B, T, C, H, W = mask.shape + + # calculate relative position + grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) + + grid_y, grid_x = grid_y.type_as(mask), grid_x.type_as(mask) + abs_relative_pos_y = H - torch.abs(grid_y[None, :, :] - grid_y[:, None, :]) + relative_pos_y = H - (grid_y[None, :, :] - grid_y[:, None, :]) + + abs_relative_pos_x = W - torch.abs(grid_x[:, None, :] - grid_x[:, :, None]) + relative_pos_x = W - (grid_x[:, None, :] - grid_x[:, :, None]) + + # calculate the nearest indices + pos_up = mask.unsqueeze(3).repeat( + 1, 1, 1, H, 1, 1).flip(4) * abs_relative_pos_y[None, None, None] * ( + relative_pos_y <= H)[None, None, None] + nearest_indice_up = pos_up.max(dim=4)[1] + + pos_down = mask.unsqueeze(3).repeat(1, 1, 1, H, 1, 1) * abs_relative_pos_y[ + None, None, None] * (relative_pos_y <= H)[None, None, None] + nearest_indice_down = (pos_down).max(dim=4)[1] + + pos_left = mask.unsqueeze(4).repeat( + 1, 1, 1, 1, W, 1).flip(5) * abs_relative_pos_x[None, None, None] * ( + relative_pos_x <= W)[None, None, None] + nearest_indice_left = (pos_left).max(dim=5)[1] + + pos_right = mask.unsqueeze(4).repeat( + 1, 1, 1, 1, W, 1) * abs_relative_pos_x[None, None, None] * ( + relative_pos_x <= W)[None, None, None] + nearest_indice_right = (pos_right).max(dim=5)[1] + + # NOTE: IMPORTANT !!! 
depending on how to use this offset + initial_offset_up = -(nearest_indice_up - grid_y[None, None, None]).flip(3) + initial_offset_down = nearest_indice_down - grid_y[None, None, None] + + initial_offset_left = -(nearest_indice_left - + grid_x[None, None, None]).flip(4) + initial_offset_right = nearest_indice_right - grid_x[None, None, None] + + # nearest_indice_x = (mask.unsqueeze(1).repeat(1, img_width, 1) * relative_pos_x).max(dim=2)[1] + # initial_offset_x = nearest_indice_x - grid_x + + # handle the boundary cases + final_offset_down = (initial_offset_down < 0) * initial_offset_up + ( + initial_offset_down > 0) * initial_offset_down + final_offset_up = (initial_offset_up > 0) * initial_offset_down + ( + initial_offset_up < 0) * initial_offset_up + final_offset_right = (initial_offset_right < 0) * initial_offset_left + ( + initial_offset_right > 0) * initial_offset_right + final_offset_left = (initial_offset_left > 0) * initial_offset_right + ( + initial_offset_left < 0) * initial_offset_left + zero_offset = torch.zeros_like(final_offset_down) + # out = torch.cat([final_offset_left, zero_offset, final_offset_right, zero_offset, zero_offset, final_offset_up, zero_offset, final_offset_down], dim=2) + out = torch.cat([ + zero_offset, final_offset_left, zero_offset, final_offset_right, + final_offset_up, zero_offset, final_offset_down, zero_offset + ], + dim=2) + + return out diff --git a/inpainter/model/modules/spectral_norm.py b/inpainter/model/modules/spectral_norm.py new file mode 100644 index 0000000..f38c34e --- /dev/null +++ b/inpainter/model/modules/spectral_norm.py @@ -0,0 +1,288 @@ +""" +Spectral Normalization from https://arxiv.org/abs/1802.05957 +""" +import torch +from torch.nn.functional import normalize + + +class SpectralNorm(object): + # Invariant before and after each forward call: + # u = normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version = 1 + + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError( + 'Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module, do_power_iteration): + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. 
`DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + v = normalize(torch.mv(weight_mat.t(), u), + dim=0, + eps=self.eps, + out=v) + u = normalize(torch.mv(weight_mat, v), + dim=0, + eps=self.eps, + out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone() + v = v.clone() + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, + torch.nn.Parameter(weight.detach())) + + def __call__(self, module, inputs): + setattr( + module, self.name, + self.compute_weight(module, do_power_iteration=module.training)) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. 
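+        # note: torch.chain_matmul is deprecated in recent PyTorch releases; torch.linalg.multi_dot is the suggested replacement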
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), + weight_mat.t(), u.unsqueeze(1)).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + "Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook( + SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + fn = self.fn + version = local_metadata.get('spectral_norm', + {}).get(fn.name + '.version', None) + if version is None or version < 1: + with torch.no_grad(): + weight_orig = state_dict[prefix + fn.name + '_orig'] + # weight = state_dict.pop(prefix + fn.name) + # sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[prefix + fn.name + '_u'] + # v = fn._solve_v_and_rescale(weight_mat, u, sigma) + # state_dict[prefix + fn.name + '_v'] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. 
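+    # records the spectral_norm version in the saved state_dict metadata so the load-time pre-hook can detect and migrate checkpoints produced by older versions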
+ def __init__(self, fn): + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata): + if 'spectral_norm' not in local_metadata: + local_metadata['spectral_norm'] = {} + key = self.fn.name + '.version' + if key in local_metadata['spectral_norm']: + raise RuntimeError( + "Unexpected key in metadata['spectral_norm']: {}".format(key)) + local_metadata['spectral_norm'][key] = self.fn._version + + +def spectral_norm(module, + name='weight', + n_power_iterations=1, + eps=1e-12, + dim=None): + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance(module, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module, name='weight'): + r"""Removes the spectral normalization reparameterization from a module. 
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("spectral_norm of '{}' not found in {}".format( + name, module)) + + +def use_spectral_norm(module, use_sn=False): + if use_sn: + return spectral_norm(module) + return module \ No newline at end of file diff --git a/inpainter/model/modules/tfocal_transformer.py b/inpainter/model/modules/tfocal_transformer.py new file mode 100644 index 0000000..179508f --- /dev/null +++ b/inpainter/model/modules/tfocal_transformer.py @@ -0,0 +1,536 @@ +""" + This code is based on: + [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021 + https://github.com/ruiliu-ai/FuseFormer + [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021 + https://github.com/yitu-opensource/T2T-ViT + [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021 + https://github.com/microsoft/Focal-Transformer +""" + +import math +from functools import reduce + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SoftSplit(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding, + t2t_param): + super(SoftSplit, self).__init__() + self.kernel_size = kernel_size + self.t2t = nn.Unfold(kernel_size=kernel_size, + stride=stride, + padding=padding) + c_in = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(c_in, hidden) + + self.f_h = int( + (t2t_param['output_size'][0] + 2 * t2t_param['padding'][0] - + (t2t_param['kernel_size'][0] - 1) - 1) / t2t_param['stride'][0] + + 1) + self.f_w = int( + (t2t_param['output_size'][1] + 2 * t2t_param['padding'][1] - + (t2t_param['kernel_size'][1] - 1) - 1) / t2t_param['stride'][1] + + 1) + + def forward(self, x, b): + feat = self.t2t(x) + feat = feat.permute(0, 2, 1) + # feat shape [b*t, num_vec, ks*ks*c] + feat = self.embedding(feat) + # feat shape after embedding [b, t*num_vec, hidden] + feat = feat.view(b, -1, self.f_h, self.f_w, feat.size(2)) + return feat + + +class SoftComp(nn.Module): + def __init__(self, channel, hidden, output_size, kernel_size, stride, + padding): + super(SoftComp, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.t2t = torch.nn.Fold(output_size=output_size, + kernel_size=kernel_size, + stride=stride, + padding=padding) + h, w = output_size + self.bias = nn.Parameter(torch.zeros((channel, h, w), + dtype=torch.float32), + requires_grad=True) + + def forward(self, x, t): + b_, _, _, _, c_ = x.shape + x = x.view(b_, -1, c_) + feat = self.embedding(x) + b, _, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = self.t2t(feat) + self.bias[None] + return feat + + +class FusionFeedForward(nn.Module): + def __init__(self, d_model, n_vecs=None, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set d_ff as a default to 1960 + hd = 1960 + self.conv1 = nn.Sequential(nn.Linear(d_model, hd)) + self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model)) + assert t2t_params is not None and n_vecs is not None + tp = t2t_params.copy() + self.fold = 
nn.Fold(**tp) + del tp['output_size'] + self.unfold = nn.Unfold(**tp) + self.n_vecs = n_vecs + + def forward(self, x): + x = self.conv1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, 49).view(-1, self.n_vecs, + 49).permute(0, 2, 1) + x = self.unfold( + self.fold(x.view(-1, self.n_vecs, c).permute(0, 2, 1)) / + self.fold(normalizer)).permute(0, 2, 1).contiguous().view(b, n, c) + x = self.conv2(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: shape is (B, T, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B*num_windows, T*window_size*window_size, C) + """ + B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], + window_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view( + -1, T * window_size[0] * window_size[1], C) + return windows + + +def window_partition_noreshape(x, window_size): + """ + Args: + x: shape is (B, T, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C) + """ + B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], + window_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous() + return windows + + +def window_reverse(windows, window_size, T, H, W): + """ + Args: + windows: shape is (num_windows*B, T, window_size, window_size, C) + window_size (tuple[int]): Window size + T (int): Temporal length of video + H (int): Height of image + W (int): Width of image + Returns: + x: (B, T, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], T, + window_size[0], window_size[1], -1) + x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Temporal focal window attention + """ + def __init__(self, dim, expand_size, window_size, focal_window, + focal_level, num_heads, qkv_bias, pool_method): + + super().__init__() + self.dim = dim + self.expand_size = expand_size + self.window_size = window_size # Wh, Ww + self.pool_method = pool_method + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.focal_level = focal_level + self.focal_window = focal_window + + if any(i > 0 for i in self.expand_size) and focal_level > 0: + # get mask for rolled k and rolled v + mask_tl = torch.ones(self.window_size[0], self.window_size[1]) + mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0 + mask_tr = torch.ones(self.window_size[0], self.window_size[1]) + mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0 + mask_bl = torch.ones(self.window_size[0], self.window_size[1]) + mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0 + mask_br = torch.ones(self.window_size[0], self.window_size[1]) + mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0 + mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), + 0).flatten(0) + self.register_buffer("valid_ind_rolled", + mask_rolled.nonzero(as_tuple=False).view(-1)) + + if pool_method != "none" and focal_level > 1: + self.unfolds = nn.ModuleList() + + # build relative position bias between local patch and pooled windows + for k in range(focal_level - 1): + stride = 2**k + kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1) + for i in self.focal_window) + # define unfolding operations + self.unfolds += [ + nn.Unfold(kernel_size=kernel_size, + stride=stride, + padding=tuple(i 
// 2 for i in kernel_size)) + ] + + # define unfolding index for focal_level > 0 + if k > 0: + mask = torch.zeros(kernel_size) + mask[(2**k) - 1:, (2**k) - 1:] = 1 + self.register_buffer( + "valid_ind_unfold_{}".format(k), + mask.flatten(0).nonzero(as_tuple=False).view(-1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x_all, mask_all=None): + """ + Args: + x: input features with shape of (B, T, Wh, Ww, C) + mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None + + output: (nW*B, Wh*Ww, C) + """ + x = x_all[0] + + B, T, nH, nW, C = x.shape + qkv = self.qkv(x).reshape(B, T, nH, nW, 3, + C).permute(4, 0, 1, 2, 3, 5).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C + + # partition q map + (q_windows, k_windows, v_windows) = map( + lambda t: window_partition(t, self.window_size).view( + -1, T, self.window_size[0] * self.window_size[1], self. + num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4). + contiguous().view(-1, self.num_heads, T * self.window_size[ + 0] * self.window_size[1], C // self.num_heads), (q, k, v)) + # q(k/v)_windows shape : [16, 4, 225, 128] + + if any(i > 0 for i in self.expand_size) and self.focal_level > 0: + (k_tl, v_tl) = map( + lambda t: torch.roll(t, + shifts=(-self.expand_size[0], -self. + expand_size[1]), + dims=(2, 3)), (k, v)) + (k_tr, v_tr) = map( + lambda t: torch.roll(t, + shifts=(-self.expand_size[0], self. + expand_size[1]), + dims=(2, 3)), (k, v)) + (k_bl, v_bl) = map( + lambda t: torch.roll(t, + shifts=(self.expand_size[0], -self. + expand_size[1]), + dims=(2, 3)), (k, v)) + (k_br, v_br) = map( + lambda t: torch.roll(t, + shifts=(self.expand_size[0], self. + expand_size[1]), + dims=(2, 3)), (k, v)) + + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( + lambda t: window_partition(t, self.window_size).view( + -1, T, self.window_size[0] * self.window_size[1], self. + num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br)) + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( + lambda t: window_partition(t, self.window_size).view( + -1, T, self.window_size[0] * self.window_size[1], self. 
+ num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br)) + k_rolled = torch.cat( + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), + 2).permute(0, 3, 1, 2, 4).contiguous() + v_rolled = torch.cat( + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), + 2).permute(0, 3, 1, 2, 4).contiguous() + + # mask out tokens in current window + k_rolled = k_rolled[:, :, :, self.valid_ind_rolled] + v_rolled = v_rolled[:, :, :, self.valid_ind_rolled] + temp_N = k_rolled.shape[3] + k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N, + C // self.num_heads) + v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N, + C // self.num_heads) + k_rolled = torch.cat((k_windows, k_rolled), 2) + v_rolled = torch.cat((v_windows, v_rolled), 2) + else: + k_rolled = k_windows + v_rolled = v_windows + + # q(k/v)_windows shape : [16, 4, 225, 128] + # k_rolled.shape : [16, 4, 5, 165, 128] + # ideal expanded window size 153 ((5+2*2)*(9+2*4)) + # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2) + + if self.pool_method != "none" and self.focal_level > 1: + k_pooled = [] + v_pooled = [] + for k in range(self.focal_level - 1): + stride = 2**k + x_window_pooled = x_all[k + 1].permute( + 0, 3, 1, 2, 4).contiguous() # B, T, nWh, nWw, C + + nWh, nWw = x_window_pooled.shape[2:4] + + # generate mask for pooled windows + mask = x_window_pooled.new(T, nWh, nWw).fill_(1) + # unfold mask: [nWh*nWw//s//s, k*k, 1] + unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view( + 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\ + view(nWh*nWw // stride // stride, -1, 1) + + if k > 0: + valid_ind_unfold_k = getattr( + self, "valid_ind_unfold_{}".format(k)) + unfolded_mask = unfolded_mask[:, valid_ind_unfold_k] + + x_window_masks = unfolded_mask.flatten(1).unsqueeze(0) + x_window_masks = x_window_masks.masked_fill( + x_window_masks == 0, + float(-100.0)).masked_fill(x_window_masks > 0, float(0.0)) + mask_all[k + 1] = x_window_masks + + # generate k and v for pooled windows + qkv_pooled = self.qkv(x_window_pooled).reshape( + B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2, + 3).view(3, -1, C, nWh, + nWw).contiguous() + k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[ + 2] # B*T, C, nWh, nWw + # k_pooled_k shape: [5, 512, 4, 4] + # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16] + + (k_pooled_k, v_pooled_k) = map( + lambda t: self.unfolds[k](t).view( + B, T, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 5, 1, 3, 4, 2).contiguous().\ + view(-1, T, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4).contiguous(), + (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim + ) + # k_pooled_k shape : [16, 4, 5, 45, 128] + + # select valid unfolding index + if k > 0: + (k_pooled_k, v_pooled_k) = map( + lambda t: t[:, :, :, valid_ind_unfold_k], + (k_pooled_k, v_pooled_k)) + + k_pooled_k = k_pooled_k.view( + -1, self.num_heads, T * self.unfolds[k].kernel_size[0] * + self.unfolds[k].kernel_size[1], C // self.num_heads) + v_pooled_k = v_pooled_k.view( + -1, self.num_heads, T * self.unfolds[k].kernel_size[0] * + self.unfolds[k].kernel_size[1], C // self.num_heads) + + k_pooled += [k_pooled_k] + v_pooled += [v_pooled_k] + + # k_all (v_all) shape : [16, 4, 5 * 210, 128] + k_all = torch.cat([k_rolled] + k_pooled, 2) + v_all = torch.cat([v_rolled] + v_pooled, 2) + else: + k_all = 
k_rolled + v_all = v_rolled + + N = k_all.shape[-2] + q_windows = q_windows * self.scale + attn = ( + q_windows @ k_all.transpose(-2, -1) + ) # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size + # T * 45 + window_area = T * self.window_size[0] * self.window_size[1] + # T * 165 + window_area_rolled = k_rolled.shape[2] + + if self.pool_method != "none" and self.focal_level > 1: + offset = window_area_rolled + for k in range(self.focal_level - 1): + # add attentional mask + # mask_all[1] shape [1, 16, T * 45] + + bias = tuple((i + 2**k - 1) for i in self.focal_window) + + if mask_all[k + 1] is not None: + attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \ + attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \ + mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1]) + + offset += T * bias[0] * bias[1] + + if mask_all[0] is not None: + nW = mask_all[0].shape[0] + attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, + window_area, N) + attn[:, :, :, :, : + window_area] = attn[:, :, :, :, :window_area] + mask_all[0][ + None, :, None, :, :] + attn = attn.view(-1, self.num_heads, window_area, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, + C) + x = self.proj(x) + return x + + +class TemporalFocalTransformerBlock(nn.Module): + r""" Temporal Focal Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + focal_level (int): The number level of focal window. + focal_window (int): Window size of each focal window. + n_vecs (int): Required for F3N. + t2t_params (int): T2T parameters for F3N. + """ + def __init__(self, + dim, + num_heads, + window_size=(5, 9), + mlp_ratio=4., + qkv_bias=True, + pool_method="fc", + focal_level=2, + focal_window=(5, 9), + norm_layer=nn.LayerNorm, + n_vecs=None, + t2t_params=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.expand_size = tuple(i // 2 for i in window_size) # TODO + self.mlp_ratio = mlp_ratio + self.pool_method = pool_method + self.focal_level = focal_level + self.focal_window = focal_window + + self.window_size_glo = self.window_size + + self.pool_layers = nn.ModuleList() + if self.pool_method != "none": + for k in range(self.focal_level - 1): + window_size_glo = tuple( + math.floor(i / (2**k)) for i in self.window_size_glo) + self.pool_layers.append( + nn.Linear(window_size_glo[0] * window_size_glo[1], 1)) + self.pool_layers[-1].weight.data.fill_( + 1. 
/ (window_size_glo[0] * window_size_glo[1])) + self.pool_layers[-1].bias.data.fill_(0) + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention(dim, + expand_size=self.expand_size, + window_size=self.window_size, + focal_window=focal_window, + focal_level=focal_level, + num_heads=num_heads, + qkv_bias=qkv_bias, + pool_method=pool_method) + + self.norm2 = norm_layer(dim) + self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params) + + def forward(self, x): + B, T, H, W, C = x.shape + + shortcut = x + x = self.norm1(x) + + shifted_x = x + + x_windows_all = [shifted_x] + x_window_masks_all = [None] + + # partition windows tuple(i // 2 for i in window_size) + if self.focal_level > 1 and self.pool_method != "none": + # if we add coarser granularity and the pool method is not none + for k in range(self.focal_level - 1): + window_size_glo = tuple( + math.floor(i / (2**k)) for i in self.window_size_glo) + pooled_h = math.ceil(H / window_size_glo[0]) * (2**k) + pooled_w = math.ceil(W / window_size_glo[1]) * (2**k) + H_pool = pooled_h * window_size_glo[0] + W_pool = pooled_w * window_size_glo[1] + + x_level_k = shifted_x + # trim or pad shifted_x depending on the required size + if H > H_pool: + trim_t = (H - H_pool) // 2 + trim_b = H - H_pool - trim_t + x_level_k = x_level_k[:, :, trim_t:-trim_b] + elif H < H_pool: + pad_t = (H_pool - H) // 2 + pad_b = H_pool - H - pad_t + x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b)) + + if W > W_pool: + trim_l = (W - W_pool) // 2 + trim_r = W - W_pool - trim_l + x_level_k = x_level_k[:, :, :, trim_l:-trim_r] + elif W < W_pool: + pad_l = (W_pool - W) // 2 + pad_r = W_pool - W - pad_l + x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r)) + + x_windows_noreshape = window_partition_noreshape( + x_level_k.contiguous(), window_size_glo + ) # B, nw, nw, T, window_size, window_size, C + nWh, nWw = x_windows_noreshape.shape[1:3] + x_windows_noreshape = x_windows_noreshape.view( + B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1], + C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2 + x_windows_pooled = self.pool_layers[k]( + x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C + + x_windows_all += [x_windows_pooled] + x_window_masks_all += [None] + + attn_windows = self.attn( + x_windows_all, + mask_all=x_window_masks_all) # nW*B, T*window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, T, self.window_size[0], + self.window_size[1], C) + shifted_x = window_reverse(attn_windows, self.window_size, T, H, + W) # B T H' W' C + + # FFN + x = shortcut + shifted_x + y = self.norm2(x) + x = x + self.mlp(y.view(B, T * H * W, C)).view(B, T, H, W, C) + + return x diff --git a/inpainter/model/modules/tfocal_transformer_hq.py b/inpainter/model/modules/tfocal_transformer_hq.py new file mode 100644 index 0000000..1a24dfa --- /dev/null +++ b/inpainter/model/modules/tfocal_transformer_hq.py @@ -0,0 +1,565 @@ +""" + This code is based on: + [1] FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting, ICCV 2021 + https://github.com/ruiliu-ai/FuseFormer + [2] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet, ICCV 2021 + https://github.com/yitu-opensource/T2T-ViT + [3] Focal Self-attention for Local-Global Interactions in Vision Transformers, NeurIPS 2021 + https://github.com/microsoft/Focal-Transformer +""" + +import math +from functools import reduce + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SoftSplit(nn.Module): + def __init__(self, 
channel, hidden, kernel_size, stride, padding, + t2t_param): + super(SoftSplit, self).__init__() + self.kernel_size = kernel_size + self.t2t = nn.Unfold(kernel_size=kernel_size, + stride=stride, + padding=padding) + c_in = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(c_in, hidden) + + self.t2t_param = t2t_param + + def forward(self, x, b, output_size): + f_h = int((output_size[0] + 2 * self.t2t_param['padding'][0] - + (self.t2t_param['kernel_size'][0] - 1) - 1) / + self.t2t_param['stride'][0] + 1) + f_w = int((output_size[1] + 2 * self.t2t_param['padding'][1] - + (self.t2t_param['kernel_size'][1] - 1) - 1) / + self.t2t_param['stride'][1] + 1) + + feat = self.t2t(x) + feat = feat.permute(0, 2, 1) + # feat shape [b*t, num_vec, ks*ks*c] + feat = self.embedding(feat) + # feat shape after embedding [b, t*num_vec, hidden] + feat = feat.view(b, -1, f_h, f_w, feat.size(2)) + return feat + + +class SoftComp(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(SoftComp, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias_conv = nn.Conv2d(channel, + channel, + kernel_size=3, + stride=1, + padding=1) + # TODO upsample conv + # self.bias_conv = nn.Conv2d() + # self.bias = nn.Parameter(torch.zeros((channel, h, w), dtype=torch.float32), requires_grad=True) + + def forward(self, x, t, output_size): + b_, _, _, _, c_ = x.shape + x = x.view(b_, -1, c_) + feat = self.embedding(x) + b, _, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = F.fold(feat, + output_size=output_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + feat = self.bias_conv(feat) + return feat + + +class FusionFeedForward(nn.Module): + def __init__(self, d_model, n_vecs=None, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set d_ff as a default to 1960 + hd = 1960 + self.conv1 = nn.Sequential(nn.Linear(d_model, hd)) + self.conv2 = nn.Sequential(nn.GELU(), nn.Linear(hd, d_model)) + assert t2t_params is not None and n_vecs is not None + self.t2t_params = t2t_params + + def forward(self, x, output_size): + n_vecs = 1 + for i, d in enumerate(self.t2t_params['kernel_size']): + n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - + (d - 1) - 1) / self.t2t_params['stride'][i] + 1) + + x = self.conv1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, 49).view(-1, n_vecs, 49).permute(0, 2, 1) + normalizer = F.fold(normalizer, + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.unfold(x / normalizer, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']).permute( + 0, 2, 1).contiguous().view(b, n, c) + x = self.conv2(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: shape is (B, T, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B*num_windows, T*window_size*window_size, C) + """ + B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W 
// window_size[1], + window_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous().view( + -1, T * window_size[0] * window_size[1], C) + return windows + + +def window_partition_noreshape(x, window_size): + """ + Args: + x: shape is (B, T, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B, num_windows_h, num_windows_w, T, window_size, window_size, C) + """ + B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], + window_size[1], C) + windows = x.permute(0, 2, 4, 1, 3, 5, 6).contiguous() + return windows + + +def window_reverse(windows, window_size, T, H, W): + """ + Args: + windows: shape is (num_windows*B, T, window_size, window_size, C) + window_size (tuple[int]): Window size + T (int): Temporal length of video + H (int): Height of image + W (int): Width of image + Returns: + x: (B, T, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view(B, H // window_size[0], W // window_size[1], T, + window_size[0], window_size[1], -1) + x = x.permute(0, 3, 1, 4, 2, 5, 6).contiguous().view(B, T, H, W, -1) + return x + + +class WindowAttention(nn.Module): + """Temporal focal window attention + """ + def __init__(self, dim, expand_size, window_size, focal_window, + focal_level, num_heads, qkv_bias, pool_method): + + super().__init__() + self.dim = dim + self.expand_size = expand_size + self.window_size = window_size # Wh, Ww + self.pool_method = pool_method + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + self.focal_level = focal_level + self.focal_window = focal_window + + if any(i > 0 for i in self.expand_size) and focal_level > 0: + # get mask for rolled k and rolled v + mask_tl = torch.ones(self.window_size[0], self.window_size[1]) + mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0 + mask_tr = torch.ones(self.window_size[0], self.window_size[1]) + mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0 + mask_bl = torch.ones(self.window_size[0], self.window_size[1]) + mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0 + mask_br = torch.ones(self.window_size[0], self.window_size[1]) + mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0 + mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), + 0).flatten(0) + self.register_buffer("valid_ind_rolled", + mask_rolled.nonzero(as_tuple=False).view(-1)) + + if pool_method != "none" and focal_level > 1: + self.unfolds = nn.ModuleList() + + # build relative position bias between local patch and pooled windows + for k in range(focal_level - 1): + stride = 2**k + kernel_size = tuple(2 * (i // 2) + 2**k + (2**k - 1) + for i in self.focal_window) + # define unfolding operations + self.unfolds += [ + nn.Unfold(kernel_size=kernel_size, + stride=stride, + padding=tuple(i // 2 for i in kernel_size)) + ] + + # define unfolding index for focal_level > 0 + if k > 0: + mask = torch.zeros(kernel_size) + mask[(2**k) - 1:, (2**k) - 1:] = 1 + self.register_buffer( + "valid_ind_unfold_{}".format(k), + mask.flatten(0).nonzero(as_tuple=False).view(-1)) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.proj = nn.Linear(dim, dim) + + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x_all, mask_all=None): + """ + Args: + x: input features with shape of (B, T, Wh, Ww, C) + mask: (0/-inf) mask with shape of (num_windows, T*Wh*Ww, T*Wh*Ww) or None + + output: (nW*B, Wh*Ww, C) + """ + x = x_all[0] + + B, T, nH, nW, C = x.shape + qkv = self.qkv(x).reshape(B, T, 
nH, nW, 3, + C).permute(4, 0, 1, 2, 3, 5).contiguous() + q, k, v = qkv[0], qkv[1], qkv[2] # B, T, nH, nW, C + + # partition q map + (q_windows, k_windows, v_windows) = map( + lambda t: window_partition(t, self.window_size).view( + -1, T, self.window_size[0] * self.window_size[1], self. + num_heads, C // self.num_heads).permute(0, 3, 1, 2, 4). + contiguous().view(-1, self.num_heads, T * self.window_size[ + 0] * self.window_size[1], C // self.num_heads), (q, k, v)) + # q(k/v)_windows shape : [16, 4, 225, 128] + + if any(i > 0 for i in self.expand_size) and self.focal_level > 0: + (k_tl, v_tl) = map( + lambda t: torch.roll(t, + shifts=(-self.expand_size[0], -self. + expand_size[1]), + dims=(2, 3)), (k, v)) + (k_tr, v_tr) = map( + lambda t: torch.roll(t, + shifts=(-self.expand_size[0], self. + expand_size[1]), + dims=(2, 3)), (k, v)) + (k_bl, v_bl) = map( + lambda t: torch.roll(t, + shifts=(self.expand_size[0], -self. + expand_size[1]), + dims=(2, 3)), (k, v)) + (k_br, v_br) = map( + lambda t: torch.roll(t, + shifts=(self.expand_size[0], self. + expand_size[1]), + dims=(2, 3)), (k, v)) + + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( + lambda t: window_partition(t, self.window_size).view( + -1, T, self.window_size[0] * self.window_size[1], self. + num_heads, C // self.num_heads), (k_tl, k_tr, k_bl, k_br)) + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( + lambda t: window_partition(t, self.window_size).view( + -1, T, self.window_size[0] * self.window_size[1], self. + num_heads, C // self.num_heads), (v_tl, v_tr, v_bl, v_br)) + k_rolled = torch.cat( + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), + 2).permute(0, 3, 1, 2, 4).contiguous() + v_rolled = torch.cat( + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), + 2).permute(0, 3, 1, 2, 4).contiguous() + + # mask out tokens in current window + k_rolled = k_rolled[:, :, :, self.valid_ind_rolled] + v_rolled = v_rolled[:, :, :, self.valid_ind_rolled] + temp_N = k_rolled.shape[3] + k_rolled = k_rolled.view(-1, self.num_heads, T * temp_N, + C // self.num_heads) + v_rolled = v_rolled.view(-1, self.num_heads, T * temp_N, + C // self.num_heads) + k_rolled = torch.cat((k_windows, k_rolled), 2) + v_rolled = torch.cat((v_windows, v_rolled), 2) + else: + k_rolled = k_windows + v_rolled = v_windows + + # q(k/v)_windows shape : [16, 4, 225, 128] + # k_rolled.shape : [16, 4, 5, 165, 128] + # ideal expanded window size 153 ((5+2*2)*(9+2*4)) + # k_windows=45 expand_window=108 overlap_window=12 (since expand_size < window_size / 2) + + if self.pool_method != "none" and self.focal_level > 1: + k_pooled = [] + v_pooled = [] + for k in range(self.focal_level - 1): + stride = 2**k + # B, T, nWh, nWw, C + x_window_pooled = x_all[k + 1].permute(0, 3, 1, 2, + 4).contiguous() + + nWh, nWw = x_window_pooled.shape[2:4] + + # generate mask for pooled windows + mask = x_window_pooled.new(T, nWh, nWw).fill_(1) + # unfold mask: [nWh*nWw//s//s, k*k, 1] + unfolded_mask = self.unfolds[k](mask.unsqueeze(1)).view( + 1, T, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(4, 1, 2, 3, 0).contiguous().\ + view(nWh*nWw // stride // stride, -1, 1) + + if k > 0: + valid_ind_unfold_k = getattr( + self, "valid_ind_unfold_{}".format(k)) + unfolded_mask = unfolded_mask[:, valid_ind_unfold_k] + + x_window_masks = unfolded_mask.flatten(1).unsqueeze(0) + x_window_masks = x_window_masks.masked_fill( + x_window_masks == 0, + float(-100.0)).masked_fill(x_window_masks > 0, float(0.0)) + mask_all[k + 1] = 
x_window_masks + + # generate k and v for pooled windows + qkv_pooled = self.qkv(x_window_pooled).reshape( + B, T, nWh, nWw, 3, C).permute(4, 0, 1, 5, 2, + 3).view(3, -1, C, nWh, + nWw).contiguous() + # B*T, C, nWh, nWw + k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] + # k_pooled_k shape: [5, 512, 4, 4] + # self.unfolds[k](k_pooled_k) shape: [5, 23040 (512 * 5 * 9 ), 16] + + (k_pooled_k, v_pooled_k) = map( + lambda t: self.unfolds[k] + (t).view(B, T, C, self.unfolds[k].kernel_size[0], self. + unfolds[k].kernel_size[1], -1) + .permute(0, 5, 1, 3, 4, 2).contiguous().view( + -1, T, self.unfolds[k].kernel_size[0] * self.unfolds[ + k].kernel_size[1], self.num_heads, C // self. + num_heads).permute(0, 3, 1, 2, 4).contiguous(), + # (B x (nH*nW)) x nHeads x T x (unfold_wsize x unfold_wsize) x head_dim + (k_pooled_k, v_pooled_k)) + # k_pooled_k shape : [16, 4, 5, 45, 128] + + # select valid unfolding index + if k > 0: + (k_pooled_k, v_pooled_k) = map( + lambda t: t[:, :, :, valid_ind_unfold_k], + (k_pooled_k, v_pooled_k)) + + k_pooled_k = k_pooled_k.view( + -1, self.num_heads, T * self.unfolds[k].kernel_size[0] * + self.unfolds[k].kernel_size[1], C // self.num_heads) + v_pooled_k = v_pooled_k.view( + -1, self.num_heads, T * self.unfolds[k].kernel_size[0] * + self.unfolds[k].kernel_size[1], C // self.num_heads) + + k_pooled += [k_pooled_k] + v_pooled += [v_pooled_k] + + # k_all (v_all) shape : [16, 4, 5 * 210, 128] + k_all = torch.cat([k_rolled] + k_pooled, 2) + v_all = torch.cat([v_rolled] + v_pooled, 2) + else: + k_all = k_rolled + v_all = v_rolled + + N = k_all.shape[-2] + q_windows = q_windows * self.scale + # B*nW, nHead, T*window_size*window_size, T*focal_window_size*focal_window_size + attn = (q_windows @ k_all.transpose(-2, -1)) + # T * 45 + window_area = T * self.window_size[0] * self.window_size[1] + # T * 165 + window_area_rolled = k_rolled.shape[2] + + if self.pool_method != "none" and self.focal_level > 1: + offset = window_area_rolled + for k in range(self.focal_level - 1): + # add attentional mask + # mask_all[1] shape [1, 16, T * 45] + + bias = tuple((i + 2**k - 1) for i in self.focal_window) + + if mask_all[k + 1] is not None: + attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] = \ + attn[:, :, :window_area, offset:(offset + (T*bias[0]*bias[1]))] + \ + mask_all[k+1][:, :, None, None, :].repeat( + attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1]) + + offset += T * bias[0] * bias[1] + + if mask_all[0] is not None: + nW = mask_all[0].shape[0] + attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, + window_area, N) + attn[:, :, :, :, : + window_area] = attn[:, :, :, :, :window_area] + mask_all[0][ + None, :, None, :, :] + attn = attn.view(-1, self.num_heads, window_area, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, + C) + x = self.proj(x) + return x + + +class TemporalFocalTransformerBlock(nn.Module): + r""" Temporal Focal Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (tuple[int]): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + focal_level (int): The number level of focal window. 
+ focal_window (int): Window size of each focal window. + n_vecs (int): Required for F3N. + t2t_params (int): T2T parameters for F3N. + """ + def __init__(self, + dim, + num_heads, + window_size=(5, 9), + mlp_ratio=4., + qkv_bias=True, + pool_method="fc", + focal_level=2, + focal_window=(5, 9), + norm_layer=nn.LayerNorm, + n_vecs=None, + t2t_params=None): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.expand_size = tuple(i // 2 for i in window_size) # TODO + self.mlp_ratio = mlp_ratio + self.pool_method = pool_method + self.focal_level = focal_level + self.focal_window = focal_window + + self.window_size_glo = self.window_size + + self.pool_layers = nn.ModuleList() + if self.pool_method != "none": + for k in range(self.focal_level - 1): + window_size_glo = tuple( + math.floor(i / (2**k)) for i in self.window_size_glo) + self.pool_layers.append( + nn.Linear(window_size_glo[0] * window_size_glo[1], 1)) + self.pool_layers[-1].weight.data.fill_( + 1. / (window_size_glo[0] * window_size_glo[1])) + self.pool_layers[-1].bias.data.fill_(0) + + self.norm1 = norm_layer(dim) + + self.attn = WindowAttention(dim, + expand_size=self.expand_size, + window_size=self.window_size, + focal_window=focal_window, + focal_level=focal_level, + num_heads=num_heads, + qkv_bias=qkv_bias, + pool_method=pool_method) + + self.norm2 = norm_layer(dim) + self.mlp = FusionFeedForward(dim, n_vecs=n_vecs, t2t_params=t2t_params) + + def forward(self, x): + output_size = x[1] + x = x[0] + + B, T, H, W, C = x.shape + + shortcut = x + x = self.norm1(x) + + shifted_x = x + + x_windows_all = [shifted_x] + x_window_masks_all = [None] + + # partition windows tuple(i // 2 for i in window_size) + if self.focal_level > 1 and self.pool_method != "none": + # if we add coarser granularity and the pool method is not none + for k in range(self.focal_level - 1): + window_size_glo = tuple( + math.floor(i / (2**k)) for i in self.window_size_glo) + pooled_h = math.ceil(H / window_size_glo[0]) * (2**k) + pooled_w = math.ceil(W / window_size_glo[1]) * (2**k) + H_pool = pooled_h * window_size_glo[0] + W_pool = pooled_w * window_size_glo[1] + + x_level_k = shifted_x + # trim or pad shifted_x depending on the required size + if H > H_pool: + trim_t = (H - H_pool) // 2 + trim_b = H - H_pool - trim_t + x_level_k = x_level_k[:, :, trim_t:-trim_b] + elif H < H_pool: + pad_t = (H_pool - H) // 2 + pad_b = H_pool - H - pad_t + x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b)) + + if W > W_pool: + trim_l = (W - W_pool) // 2 + trim_r = W - W_pool - trim_l + x_level_k = x_level_k[:, :, :, trim_l:-trim_r] + elif W < W_pool: + pad_l = (W_pool - W) // 2 + pad_r = W_pool - W - pad_l + x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r)) + + x_windows_noreshape = window_partition_noreshape( + x_level_k.contiguous(), window_size_glo + ) # B, nw, nw, T, window_size, window_size, C + nWh, nWw = x_windows_noreshape.shape[1:3] + x_windows_noreshape = x_windows_noreshape.view( + B, nWh, nWw, T, window_size_glo[0] * window_size_glo[1], + C).transpose(4, 5) # B, nWh, nWw, T, C, wsize**2 + x_windows_pooled = self.pool_layers[k]( + x_windows_noreshape).flatten(-2) # B, nWh, nWw, T, C + + x_windows_all += [x_windows_pooled] + x_window_masks_all += [None] + + # nW*B, T*window_size*window_size, C + attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all) + + # merge windows + attn_windows = attn_windows.view(-1, T, self.window_size[0], + self.window_size[1], C) + shifted_x = 
window_reverse(attn_windows, self.window_size, T, H, + W) # B T H' W' C + + # FFN + x = shortcut + shifted_x + y = self.norm2(x) + x = x + self.mlp(y.view(B, T * H * W, C), output_size).view( + B, T, H, W, C) + + return x, output_size diff --git a/inpainter/util/__init__.py b/inpainter/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/inpainter/util/tensor_util.py b/inpainter/util/tensor_util.py new file mode 100644 index 0000000..71a4746 --- /dev/null +++ b/inpainter/util/tensor_util.py @@ -0,0 +1,24 @@ +import cv2 +import numpy as np + +# resize frames +def resize_frames(frames, size=None): + """ + size: (w, h) + """ + if size is not None: + frames = [cv2.resize(f, size) for f in frames] + frames = np.stack(frames, 0) + + return frames + +# resize masks +def resize_masks(masks, size=None): + """ + size: (w, h) + """ + if size is not None: + masks = [np.expand_dims(cv2.resize(m, size), 2) for m in masks] + masks = np.stack(masks, 0) + + return masks
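
A minimal usage sketch for the helpers introduced above: the resize utilities in inpainter/util/tensor_util.py and the window_partition/window_reverse pair in inpainter/model/modules/tfocal_transformer.py. It assumes the inpainter/ directory is on PYTHONPATH so that util.tensor_util and model.modules.tfocal_transformer resolve as importable modules; the shapes below are arbitrary example values, not values used by the model.

import numpy as np
import torch

# Assumption: `inpainter/` is on PYTHONPATH so these resolve as top-level modules.
from util.tensor_util import resize_frames, resize_masks
from model.modules.tfocal_transformer import window_partition, window_reverse

# Resize helpers: `size` follows the OpenCV (width, height) convention;
# resize_masks additionally appends a trailing channel axis.
frames = np.zeros((4, 240, 432, 3), dtype=np.uint8)  # T, H, W, 3
masks = np.zeros((4, 240, 432), dtype=np.uint8)      # T, H, W
print(resize_frames(frames, size=(216, 120)).shape)  # (4, 120, 216, 3)
print(resize_masks(masks, size=(216, 120)).shape)    # (4, 120, 216, 1)

# window_partition / window_reverse are exact inverses when H and W are
# divisible by the window size, per the shapes given in their docstrings.
B, T, H, W, C = 2, 5, 10, 18, 8
window_size = (5, 9)
x = torch.randn(B, T, H, W, C)
windows = window_partition(x, window_size)              # (B*nW, T*5*9, C) == (8, 225, 8)
x_back = window_reverse(windows, window_size, T, H, W)  # (2, 5, 10, 18, 8)
assert torch.equal(x, x_back)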