Add privileged training arXiv v6

This commit is contained in:
hzwer
2021-08-13 16:34:41 +08:00
parent 618ddb30e6
commit 144207f213
12 changed files with 297 additions and 1208 deletions

View File

@@ -2,6 +2,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
from model.refine import *
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
@@ -9,12 +10,6 @@ def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
@@ -23,9 +18,8 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
)
class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.scale = scale
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
@@ -40,37 +34,72 @@ class IFBlock(nn.Module):
conv(c, c),
conv(c, c),
)
self.conv1 = nn.ConvTranspose2d(c, 4, 4, 2, 1)
self.lastconv = nn.ConvTranspose2d(c, 5, 4, 2, 1)
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
def forward(self, x, flow, scale):
if scale != 1:
x = F.interpolate(x, scale_factor = 1. / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor = 1. / scale, mode="bilinear", align_corners=False) * 1. / scale
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock(x) + x
x = self.conv1(x)
flow = x
if self.scale != 1:
flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
return flow
tmp = self.lastconv(x)
tmp = F.interpolate(tmp, scale_factor = scale * 2, mode="bilinear", align_corners=False)
flow = tmp[:, :4] * scale * 2
mask = tmp[:, 4:5]
return flow, mask
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=240)
self.block1 = IFBlock(10, scale=2, c=150)
self.block2 = IFBlock(10, scale=1, c=90)
self.block0 = IFBlock(6, c=240)
self.block1 = IFBlock(13+4, c=150)
self.block2 = IFBlock(13+4, c=90)
self.block_tea = IFBlock(16+4, c=90)
self.contextnet = Contextnet()
self.unet = Unet()
def forward(self, x):
flow0 = self.block0(x)
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3]
def forward(self, x, scale=[4,2,1]):
img0 = x[:, :3]
img1 = x[:, 3:6]
gt = x[:, 6:] # In inference time, gt is None
flow_list = []
merged = []
mask_list = []
warped_img0 = img0
warped_img1 = img1
flow = None
loss_distill = 0
stu = [self.block0, self.block1, self.block2]
for i in range(3):
if flow != None:
flow_d, mask_d = stu[i](torch.cat((img0, img1, warped_img0, warped_img1, mask), 1), flow, scale=scale[i])
flow = flow + flow_d
mask = mask + mask_d
else:
flow, mask = stu[i](torch.cat((img0, img1), 1), None, scale=scale[i])
mask_list.append(torch.sigmoid(mask))
flow_list.append(flow)
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
merged_student = (warped_img0, warped_img1)
merged.append(merged_student)
if gt.shape[1] == 3:
flow_d, mask_d = self.block_tea(torch.cat((img0, img1, warped_img0, warped_img1, mask, gt), 1), flow, scale=1)
flow_teacher = flow + flow_d
warped_img0_teacher = warp(img0, flow_teacher[:, :2])
warped_img1_teacher = warp(img1, flow_teacher[:, 2:4])
mask_teacher = torch.sigmoid(mask + mask_d)
merged_teacher = warped_img0_teacher * mask_teacher + warped_img1_teacher * (1 - mask_teacher)
for i in range(3):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach()
loss_distill += ((flow_teacher.detach() - flow_list[i]).abs() * loss_mask).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
res = tmp[:, :3] * 2 - 1
merged[2] = torch.clamp(merged[2] + res, 0, 1)
return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill

View File

@@ -1,76 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__()
self.scale = scale
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.ConvTranspose2d(c, 4, 4, 2, 1)
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
x = self.conv0(x)
x = self.convblock(x) + x
x = self.conv1(x)
flow = x
if self.scale != 1:
flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
return flow
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=320)
self.block1 = IFBlock(10, scale=2, c=225)
self.block2 = IFBlock(10, scale=1, c=135)
def forward(self, x):
flow0 = self.block0(x)
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3]

View File

@@ -1,75 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__()
self.scale = scale
self.conv0 = nn.Sequential(
conv(in_planes, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
x = self.conv0(x)
x = self.convblock(x) + x
x = self.conv1(x)
flow = x
if self.scale != 1:
flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
return flow
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=240)
self.block1 = IFBlock(10, scale=2, c=150)
self.block2 = IFBlock(10, scale=1, c=90)
def forward(self, x):
flow0 = self.block0(x)
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3]

View File

@@ -1,75 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__()
self.scale = scale
self.conv0 = nn.Sequential(
conv(in_planes, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
x = self.conv0(x)
x = self.convblock(x) + x
x = self.conv1(x)
flow = x
if self.scale != 1:
flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
return flow
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=360)
self.block1 = IFBlock(10, scale=2, c=225)
self.block2 = IFBlock(10, scale=1, c=135)
def forward(self, x):
flow0 = self.block0(x)
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3]

View File

@@ -5,233 +5,87 @@ from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torchstat import stat
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet import *
import torch.nn.functional as F
from model.loss import *
from model.laplacian import *
from model.refine import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 16
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv1(x)
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class FusionNet(nn.Module):
def __init__(self):
super(FusionNet, self).__init__()
self.down0 = Conv2(12, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None
else:
warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
device = torch.device("cuda")
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.contextnet = ContextNet()
self.fusionnet = FusionNet()
self.device()
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.optimG = AdamW(self.flownet.parameters(), lr=1e-6, weight_decay=1e-4)
self.epe = EPE()
self.ter = Ternary()
self.lap = LapLoss()
self.sobel = SOBEL()
# self.vgg = VGGPerceptualLoss().to(device)
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
self.contextnet = DDP(self.contextnet, device_ids=[
local_rank], output_device=local_rank)
self.fusionnet = DDP(self.fusionnet, device_ids=[
local_rank], output_device=local_rank)
self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
self.contextnet.train()
self.fusionnet.train()
def eval(self):
self.flownet.eval()
self.contextnet.eval()
self.fusionnet.eval()
def device(self):
self.flownet.to(device)
self.contextnet.to(device)
self.fusionnet.to(device)
def load_model(self, path, rank=-1):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank):
def load_model(self, path, rank=0):
if rank == 0:
torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
self.flownet.load_state_dict(torch.load('{}/flownet.pkl'.format(path)))
def save_model(self, path, rank=0):
if rank == 0:
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None):
'''
def predict(self, imgs, flow, merged, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
refine_output = self.unet(img0, img1, flow, merged, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
mask = torch.sigmoid(refine_output[:, 3:4])
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = merged_img + res
pred = merged + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
if training:
return pred, merged
else:
return pred
def inference(self, img0, img1):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(torch.cat((img0, img1), 1))
return self.predict(imgs, flow, training=False)
'''
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group['lr'] = learning_rate
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
if training:
self.train()
else:
self.eval()
flow, flow_list = self.flownet(imgs)
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
imgs, flow, flow_gt=flow_gt)
loss_ter = self.ter(pred, gt).mean()
if training:
with torch.no_grad():
loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(3):
loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01
else:
loss_cons = torch.tensor([0])
loss_flow = torch.abs(warped_img0 - gt).mean()
loss_mask = 1
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(torch.cat((imgs, gt), 1), scale=[4, 2, 1])
loss_l1 = (self.lap(merged[2], gt)).mean()
loss_tea = (self.lap(merged_teacher, gt)).mean()
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter
# loss_G = self.vgg(pred, gt) + loss_cons + loss_ter
loss_G = loss_l1 + loss_tea + loss_distill * 0.01
loss_G.backward()
self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
model = Model()
model.eval()
print(model.inference(imgs).shape)
else:
flow_teacher = flow[2]
merged_teacher = merged[2]
return merged[2], {
'merged_tea': merged_teacher,
'mask': mask,
'mask_tea': mask,
'flow': flow[2][:, :2],
'flow_tea': flow_teacher,
'loss_l1': loss_l1,
'loss_tea': loss_tea,
'loss_distill': loss_distill,
}

View File

@@ -1,235 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet15C import *
import torch.nn.functional as F
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 24
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv1(x)
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class FusionNet(nn.Module):
def __init__(self):
super(FusionNet, self).__init__()
self.down0 = Conv2(12, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None
else:
warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.contextnet = ContextNet()
self.fusionnet = FusionNet()
self.device()
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
self.ter = Ternary()
self.sobel = SOBEL()
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
self.contextnet = DDP(self.contextnet, device_ids=[
local_rank], output_device=local_rank)
self.fusionnet = DDP(self.fusionnet, device_ids=[
local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
self.contextnet.train()
self.fusionnet.train()
def eval(self):
self.flownet.eval()
self.contextnet.eval()
self.fusionnet.eval()
def device(self):
self.flownet.to(device)
self.contextnet.to(device)
self.fusionnet.to(device)
def load_model(self, path, rank=-1):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank):
if rank == 0:
torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
mask = torch.sigmoid(refine_output[:, 3:4])
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = merged_img + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
else:
return pred
def inference(self, img0, img1):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs)
return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group['lr'] = learning_rate
if training:
self.train()
else:
self.eval()
flow, flow_list = self.flownet(imgs)
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
imgs, flow, flow_gt=flow_gt)
loss_ter = self.ter(pred, gt).mean()
if training:
with torch.no_grad():
loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(3):
loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01
else:
loss_cons = torch.tensor([0])
loss_flow = torch.abs(warped_img0 - gt).mean()
loss_mask = 1
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter
loss_G.backward()
self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
model = Model()
model.eval()
print(model.inference(imgs).shape)

View File

@@ -1,235 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet2F import *
import torch.nn.functional as F
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 16
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
self.conv1 = Conv2(3, c, 1)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv1(x)
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class FusionNet(nn.Module):
def __init__(self):
super(FusionNet, self).__init__()
self.down0 = Conv2(12, 2*c, 1)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, 4, 3, 2, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None
else:
warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.contextnet = ContextNet()
self.fusionnet = FusionNet()
self.device()
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
self.ter = Ternary()
self.sobel = SOBEL()
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
self.contextnet = DDP(self.contextnet, device_ids=[
local_rank], output_device=local_rank)
self.fusionnet = DDP(self.fusionnet, device_ids=[
local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
self.contextnet.train()
self.fusionnet.train()
def eval(self):
self.flownet.eval()
self.contextnet.eval()
self.fusionnet.eval()
def device(self):
self.flownet.to(device)
self.contextnet.to(device)
self.fusionnet.to(device)
def load_model(self, path, rank=-1):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank):
if rank == 0:
torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
mask = torch.sigmoid(refine_output[:, 3:4])
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = merged_img + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
else:
return pred
def inference(self, img0, img1):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs)
return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group['lr'] = learning_rate
if training:
self.train()
else:
self.eval()
flow, flow_list = self.flownet(imgs)
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
imgs, flow, flow_gt=flow_gt)
loss_ter = self.ter(pred, gt).mean()
if training:
with torch.no_grad():
loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(3):
loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01
else:
loss_cons = torch.tensor([0])
loss_flow = torch.abs(warped_img0 - gt).mean()
loss_mask = 1
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter
loss_G.backward()
self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
model = Model()
model.eval()
print(model.inference(imgs).shape)

View File

@@ -1,235 +0,0 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet2F15C import *
import torch.nn.functional as F
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 24
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
self.conv1 = Conv2(3, c, 1)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv1(x)
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class FusionNet(nn.Module):
def __init__(self):
super(FusionNet, self).__init__()
self.down0 = Conv2(12, 2*c, 1)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, 4, 3, 2, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None
else:
warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.contextnet = ContextNet()
self.fusionnet = FusionNet()
self.device()
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
self.ter = Ternary()
self.sobel = SOBEL()
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
self.contextnet = DDP(self.contextnet, device_ids=[
local_rank], output_device=local_rank)
self.fusionnet = DDP(self.fusionnet, device_ids=[
local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
self.contextnet.train()
self.fusionnet.train()
def eval(self):
self.flownet.eval()
self.contextnet.eval()
self.fusionnet.eval()
def device(self):
self.flownet.to(device)
self.contextnet.to(device)
self.fusionnet.to(device)
def load_model(self, path, rank=-1):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank):
if rank == 0:
torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
mask = torch.sigmoid(refine_output[:, 3:4])
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = merged_img + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
else:
return pred
def inference(self, img0, img1):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs)
return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group['lr'] = learning_rate
if training:
self.train()
else:
self.eval()
flow, flow_list = self.flownet(imgs)
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
imgs, flow, flow_gt=flow_gt)
loss_ter = self.ter(pred, gt).mean()
if training:
with torch.no_grad():
loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(3):
loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01
else:
loss_cons = torch.tensor([0])
loss_flow = torch.abs(warped_img0 - gt).mean()
loss_mask = 1
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter
loss_G.backward()
self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
model = Model()
model.eval()
print(model.inference(imgs).shape)

59
model/laplacian.py Normal file
View File

@@ -0,0 +1,59 @@
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import torch
def gauss_kernel(size=5, channels=3):
kernel = torch.tensor([[1., 4., 6., 4., 1],
[4., 16., 24., 16., 4.],
[6., 24., 36., 24., 6.],
[4., 16., 24., 16., 4.],
[1., 4., 6., 4., 1.]])
kernel /= 256.
kernel = kernel.repeat(channels, 1, 1, 1)
kernel = kernel.to(device)
return kernel
def downsample(x):
return x[:, :, ::2, ::2]
def upsample(x):
cc = torch.cat([x, torch.zeros(x.shape[0], x.shape[1], x.shape[2], x.shape[3]).to(device)], dim=3)
cc = cc.view(x.shape[0], x.shape[1], x.shape[2]*2, x.shape[3])
cc = cc.permute(0,1,3,2)
cc = torch.cat([cc, torch.zeros(x.shape[0], x.shape[1], x.shape[3], x.shape[2]*2).to(device)], dim=3)
cc = cc.view(x.shape[0], x.shape[1], x.shape[3]*2, x.shape[2]*2)
x_up = cc.permute(0,1,3,2)
return conv_gauss(x_up, 4*gauss_kernel(channels=x.shape[1]))
def conv_gauss(img, kernel):
img = torch.nn.functional.pad(img, (2, 2, 2, 2), mode='reflect')
out = torch.nn.functional.conv2d(img, kernel, groups=img.shape[1])
return out
def laplacian_pyramid(img, kernel, max_levels=3):
current = img
pyr = []
for level in range(max_levels):
filtered = conv_gauss(current, kernel)
down = downsample(filtered)
up = upsample(down)
diff = current-up
pyr.append(diff)
current = down
return pyr
class LapLoss(torch.nn.Module):
def __init__(self, max_levels=5, channels=3):
super(LapLoss, self).__init__()
self.max_levels = max_levels
self.gauss_kernel = gauss_kernel(channels=channels)
def forward(self, input, target):
pyr_input = laplacian_pyramid(img=input, kernel=self.gauss_kernel, max_levels=self.max_levels)
pyr_target = laplacian_pyramid(img=target, kernel=self.gauss_kernel, max_levels=self.max_levels)
return sum(torch.nn.functional.l1_loss(a, b) for a, b in zip(pyr_input, pyr_target))

82
model/refine.py Normal file
View File

@@ -0,0 +1,82 @@
import torch
import torch.nn as nn
import numpy as np
import torch.optim as optim
import itertools
from model.warplayer import warp
import torch.nn.functional as F
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 16
class Contextnet(nn.Module):
def __init__(self):
super(Contextnet, self).__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv1(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class Unet(nn.Module):
def __init__(self):
super(Unet, self).__init__()
self.down0 = Conv2(17, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, 3, 3, 1, 1)
def forward(self, img0, img1, warped_img0, warped_img1, mask, flow, c0, c1):
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1, mask, flow), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return torch.sigmoid(x)