"""Model that predicts an image from a pair of input frames.

Combines a flow network (``IFNet``), a feature encoder (``ContextNet``) and an
encoder/decoder refinement network (``FusionNet``), together with the
optimizer and loss bookkeeping needed to train them jointly.
"""
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from IFNet import *
import torch.nn.functional as F
from loss import *

# Every sub-network is moved to this device; CPU is used when CUDA is absent.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased ``Conv2d`` followed by a per-channel ``PReLU``."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                  stride=stride, padding=padding, dilation=dilation,
                  bias=True),
        nn.PReLU(out_planes),
    )


def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return a biased ``Conv2d`` wrapped in ``nn.Sequential``, no activation."""
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
                  stride=stride, padding=padding, dilation=dilation,
                  bias=True),
    )


def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
    """Return a biased ``ConvTranspose2d`` followed by a per-channel ``PReLU``.

    ``kernel_size``, ``stride`` and ``padding`` are forwarded to the
    transposed convolution. The defaults (4, 2, 1) are the values that were
    previously hard-coded, so callers relying on the defaults are unaffected.
    """
    return nn.Sequential(
        torch.nn.ConvTranspose2d(in_channels=in_planes,
                                 out_channels=out_planes,
                                 kernel_size=kernel_size, stride=stride,
                                 padding=padding, bias=True),
        nn.PReLU(out_planes),
    )


def _resize_flow(flow, scale):
    """Bilinearly resize ``flow`` by ``scale`` and scale its values alike.

    The values are multiplied by the same factor as the spatial size so that
    displacements stay expressed in pixels of the resized grid.
    """
    resized = F.interpolate(flow, scale_factor=scale, mode="bilinear",
                            align_corners=False,
                            recompute_scale_factor=False)
    return resized * scale


class ResBlock(nn.Module):
    """Residual block with a learned per-channel gate on the main branch.

    The main branch is ``conv1 -> conv2``. Its output is globally averaged
    over the two spatial axes, passed through two 1x1 convolutions and a
    sigmoid, and the result rescales the branch channel-wise before the
    shortcut is added.
    """

    def __init__(self, in_planes, out_planes, stride=2):
        super().__init__()
        # The shortcut only needs a projection when the shape changes.
        if in_planes == out_planes and stride == 1:
            self.conv0 = nn.Identity()
        else:
            self.conv0 = nn.Conv2d(in_planes, out_planes, 3, stride, 1,
                                   bias=False)
        self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
        self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1)
        self.relu1 = nn.PReLU(1)
        self.relu2 = nn.PReLU(out_planes)
        # 1x1 bottleneck (out_planes -> 16 -> out_planes) producing the gate.
        self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
        self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)

    def forward(self, x):
        """Return ``relu2(branch(x) * gate + shortcut(x))``."""
        shortcut = self.conv0(x)
        x = self.conv1(x)
        x = self.conv2(x)
        # Global average over width (dim 3) then height (dim 2), keeping dims.
        w = x.mean(3, True).mean(2, True)
        w = self.relu1(self.fc1(w))
        w = torch.sigmoid(self.fc2(w))
        return self.relu2(x * w + shortcut)


# Base channel width shared by ContextNet and FusionNet.
c = 16


class ContextNet(nn.Module):
    """Four stride-2 ``ResBlock`` stages whose outputs are warped by the flow."""

    def __init__(self):
        super().__init__()
        self.conv1 = ResBlock(3, c)
        self.conv2 = ResBlock(c, 2*c)
        self.conv3 = ResBlock(2*c, 4*c)
        self.conv4 = ResBlock(4*c, 8*c)

    def forward(self, x, flow):
        """Return ``[f1, f2, f3, f4]``, the warped features of each stage.

        ``flow`` is applied as given to the first stage's output and is halved
        (size and magnitude) before each later stage, so it is presumably
        expected at the resolution of the first stage's output -- confirm
        against ``warp`` in ``warplayer``.
        """
        x = self.conv1(x)
        f1 = warp(x, flow)

        x = self.conv2(x)
        flow = _resize_flow(flow, 0.5)
        f2 = warp(x, flow)

        x = self.conv3(x)
        flow = _resize_flow(flow, 0.5)
        f3 = warp(x, flow)

        x = self.conv4(x)
        flow = _resize_flow(flow, 0.5)
        f4 = warp(x, flow)
        return [f1, f2, f3, f4]


class FusionNet(nn.Module):
    """Encoder/decoder that fuses two warped images with context features.

    The final convolution emits 4 channels. ``Model.predict`` reads the first
    three as a residual and the fourth as a blending mask.
    """

    def __init__(self):
        super().__init__()
        # 8 input channels = two warped images plus the flow, concatenated.
        self.down0 = ResBlock(8, 2*c)
        self.down1 = ResBlock(4*c, 4*c)
        self.down2 = ResBlock(8*c, 8*c)
        self.down3 = ResBlock(16*c, 16*c)
        self.up0 = deconv(32*c, 8*c)
        self.up1 = deconv(16*c, 4*c)
        self.up2 = deconv(8*c, 2*c)
        self.up3 = deconv(4*c, c)
        self.conv = nn.Conv2d(c, 4, 3, 1, 1)

    def forward(self, img0, img1, flow, c0, c1, flow_gt):
        """Run the fusion network.

        Args:
            img0, img1: The two input images.
            flow: Flow used to warp ``img0``; its negation warps ``img1``.
            c0, c1: Per-image feature lists as returned by ``ContextNet``.
            flow_gt: Optional reference flow. Channels ``[:2]`` warp ``img0``
                and channels ``[2:4]`` warp ``img1``. May be ``None``.

        Returns:
            ``(x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt)``
            where ``x`` is the 4-channel network output and the two ``*_gt``
            entries are ``None`` when ``flow_gt`` is ``None``.
        """
        warped_img0 = warp(img0, flow)
        warped_img1 = warp(img1, -flow)
        # Identity test: ``== None`` on a tensor goes through tensor equality.
        if flow_gt is None:
            warped_img0_gt, warped_img1_gt = None, None
        else:
            warped_img0_gt = warp(img0, flow_gt[:, :2])
            warped_img1_gt = warp(img1, flow_gt[:, 2:4])

        # Encoder: each stage also receives the matching context features.
        s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1))
        s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
        s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
        s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))

        # Decoder with skip connections from the encoder stages.
        x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
        x = self.up1(torch.cat((x, s2), 1))
        x = self.up2(torch.cat((x, s1), 1))
        x = self.up3(torch.cat((x, s0), 1))
        x = self.conv(x)
        return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt


class Model:
    """Bundle the three networks with their optimizer, scheduler and losses."""

    def __init__(self, local_rank=-1):
        """Build the networks, move them to ``device`` and set up training.

        Args:
            local_rank: When not ``-1``, each network is wrapped in
                ``DistributedDataParallel`` bound to this device index.
        """
        self.flownet = IFNet()
        self.contextnet = ContextNet()
        self.fusionnet = FusionNet()
        self.device()
        self.optimG = AdamW(
            itertools.chain(
                self.flownet.parameters(),
                self.contextnet.parameters(),
                self.fusionnet.parameters()),
            lr=1e-6, weight_decay=1e-5)
        self.schedulerG = optim.lr_scheduler.CyclicLR(
            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000,
            cycle_momentum=False)
        self.epe = EPE()
        self.ter = Ternary()
        self.sobel = SOBEL()
        if local_rank != -1:
            self.flownet = DDP(self.flownet, device_ids=[local_rank],
                               output_device=local_rank)
            self.contextnet = DDP(self.contextnet, device_ids=[local_rank],
                                  output_device=local_rank)
            self.fusionnet = DDP(self.fusionnet, device_ids=[local_rank],
                                 output_device=local_rank)

    def train(self):
        """Put all three networks in training mode."""
        self.flownet.train()
        self.contextnet.train()
        self.fusionnet.train()

    def eval(self):
        """Put all three networks in evaluation mode."""
        self.flownet.eval()
        self.contextnet.eval()
        self.fusionnet.eval()

    def device(self):
        """Move all three networks to the module-level ``device``."""
        self.flownet.to(device)
        self.contextnet.to(device)
        self.fusionnet.to(device)

    def load_model(self, path, rank=0):
        """Load the three state dicts from ``path`` when ``rank == 0``.

        Checkpoints are mapped onto the module-level ``device`` so that files
        saved on a CUDA machine can still be loaded on a CPU-only host.
        """
        # NOTE(review): __init__ uses local_rank=-1 for "not distributed";
        # passing that value here loads nothing. Confirm this is intended.
        # SECURITY: torch.load unpickles the file -- only load trusted
        # checkpoints.
        if rank == 0:
            self.flownet.load_state_dict(
                torch.load('{}/flownet.pkl'.format(path),
                           map_location=device))
            self.contextnet.load_state_dict(
                torch.load('{}/contextnet.pkl'.format(path),
                           map_location=device))
            self.fusionnet.load_state_dict(
                torch.load('{}/unet.pkl'.format(path),
                           map_location=device))

    def save_model(self, path, rank=0):
        """Write the three state dicts under ``path`` when ``rank == 0``."""
        if rank == 0:
            torch.save(self.flownet.state_dict(),
                       '{}/flownet.pkl'.format(path))
            torch.save(self.contextnet.state_dict(),
                       '{}/contextnet.pkl'.format(path))
            torch.save(self.fusionnet.state_dict(),
                       '{}/unet.pkl'.format(path))

    def predict(self, imgs, flow, training=True, flow_gt=None):
        """Synthesize the prediction from ``imgs`` and an estimated ``flow``.

        Args:
            imgs: Both images stacked on the channel axis; channels ``[:3]``
                are the first image and ``[3:]`` the second.
            flow: Flow estimate, used as-is for the context features and
                upsampled 2x (size and magnitude) for the fusion network.
            training: Selects the return form only; it does not switch the
                networks between train and eval mode.
            flow_gt: Optional reference flow forwarded to ``FusionNet``.

        Returns:
            If ``training``: ``(pred, mask, merged_img, warped_img0,
            warped_img1, warped_img0_gt, warped_img1_gt)``. Otherwise just
            ``pred``, which is clamped to ``[0, 1]``.
        """
        img0 = imgs[:, :3]
        img1 = imgs[:, 3:]
        c0 = self.contextnet(img0, flow)
        c1 = self.contextnet(img1, -flow)
        flow = _resize_flow(flow, 2.0)
        (refine_output, warped_img0, warped_img1,
         warped_img0_gt, warped_img1_gt) = self.fusionnet(
            img0, img1, flow, c0, c1, flow_gt)
        # Residual mapped into (-1, 1); mask into (0, 1).
        res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
        mask = torch.sigmoid(refine_output[:, 3:4])
        merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
        pred = torch.clamp(merged_img + res, 0, 1)
        if training:
            return (pred, mask, merged_img, warped_img0, warped_img1,
                    warped_img0_gt, warped_img1_gt)
        return pred

    def inference(self, imgs):
        """Estimate the flow and return the prediction, without gradients."""
        with torch.no_grad():
            flow, _ = self.flownet(imgs)
            return self.predict(imgs, flow, training=False)

    def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
        """Run one forward pass and, when training, one optimizer step.

        Args:
            imgs: Both input images stacked on the channel axis.
            gt: Target image the prediction is compared against.
            learning_rate: Written into every optimizer parameter group
                before the step.
            mul: Unused; kept for interface compatibility.
            training: When true, the networks are put in train mode and the
                optimizer is stepped; otherwise they are put in eval mode
                and only the losses are computed.
            flow_gt: Reference flow; required when ``training`` is true.

        Returns:
            ``(pred, merged_img, flow, loss_l1, loss_flow, loss_cons,
            loss_ter, loss_mask)``.

        Raises:
            ValueError: If ``training`` is true and ``flow_gt`` is ``None``.
        """
        # Fail before the forward pass rather than on ``None - tensor`` later.
        if training and flow_gt is None:
            raise ValueError("flow_gt is required when training=True")

        for param_group in self.optimG.param_groups:
            param_group['lr'] = learning_rate
        if training:
            self.train()
        else:
            self.eval()

        flow, flow_list = self.flownet(imgs)
        (pred, mask, merged_img, warped_img0, warped_img1,
         warped_img0_gt, warped_img1_gt) = self.predict(
            imgs, flow, flow_gt=flow_gt)
        loss_ter = self.ter(pred, gt).mean()

        if training:
            # Monitoring values and the flow target carry no gradient.
            with torch.no_grad():
                loss_flow = torch.abs(warped_img0_gt - gt).mean()
                loss_mask = torch.abs(merged_img - gt).sum(1, True).float()
                loss_mask = F.interpolate(
                    loss_mask, scale_factor=0.5, mode="bilinear",
                    align_corners=False, recompute_scale_factor=False)
                flow_gt = _resize_flow(flow_gt, 0.5)
            # Pull the first three flow estimates towards the reference flow
            # in both directions.
            loss_cons = 0
            for i in range(3):
                loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
                loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1)
            loss_cons = loss_cons.mean() * 0.01
        else:
            loss_cons = torch.tensor([0])
            loss_flow = torch.abs(warped_img0 - gt).mean()
            loss_mask = 1

        # Smooth L1-like penalty; the epsilon keeps the root differentiable
        # at zero.
        loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
        if training:
            self.optimG.zero_grad()
            loss_G = loss_l1 + loss_cons + loss_ter
            loss_G.backward()
            self.optimG.step()
        return (pred, merged_img, flow, loss_l1, loss_flow, loss_cons,
                loss_ter, loss_mask)


if __name__ == '__main__':
    # Smoke test: run inference on a synthetic image pair and print the shape.
    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
    img1 = torch.tensor(np.random.normal(
        0, 1, (3, 3, 256, 256))).float().to(device)
    imgs = torch.cat((img0, img1), 1)
    model = Model()
    model.eval()
    print(model.inference(imgs).shape)