Updated RIFE CUDA to 3.1

This commit is contained in:
N00MKRAD
2021-05-17 19:42:53 +02:00
parent d993d7c13a
commit 0a3df58e4c
18 changed files with 950 additions and 426 deletions

View File

@@ -1,115 +1,76 @@
import torch import torch
import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from model.warplayer import warp from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False), padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes), )
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False), padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes),
nn.PReLU(out_planes) nn.PReLU(out_planes)
) )
class ResBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride=1):
super(ResBlock, self).__init__()
if in_planes == out_planes and stride == 1:
self.conv0 = nn.Identity()
else:
self.conv0 = nn.Conv2d(in_planes, out_planes,
3, stride, 1, bias=False)
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv_wo_act(out_planes, out_planes, 3, 1, 1)
self.relu1 = nn.PReLU(1)
self.relu2 = nn.PReLU(out_planes)
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
def forward(self, x):
y = self.conv0(x)
x = self.conv1(x)
x = self.conv2(x)
w = x.mean(3, True).mean(2, True)
w = self.relu1(self.fc1(w))
w = torch.sigmoid(self.fc2(w))
x = self.relu2(x * w + y)
return x
class IFBlock(nn.Module): class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64): def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__() super(IFBlock, self).__init__()
self.scale = scale self.scale = scale
self.conv0 = conv(in_planes, c, 3, 2, 1) self.conv0 = nn.Sequential(
self.res0 = ResBlock(c, c) conv(in_planes, c//2, 3, 2, 1),
self.res1 = ResBlock(c, c) conv(c//2, c, 3, 2, 1),
self.res2 = ResBlock(c, c) )
self.res3 = ResBlock(c, c) self.convblock = nn.Sequential(
self.res4 = ResBlock(c, c) conv(c, c),
self.res5 = ResBlock(c, c) conv(c, c),
self.conv1 = nn.Conv2d(c, 8, 3, 1, 1) conv(c, c),
self.up = nn.PixelShuffle(2) conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.ConvTranspose2d(c, 4, 4, 2, 1)
def forward(self, x): def forward(self, x):
if self.scale != 1: if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
align_corners=False)
x = self.conv0(x) x = self.conv0(x)
x = self.res0(x) x = self.convblock(x) + x
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
x = self.res4(x)
x = self.res5(x)
x = self.conv1(x) x = self.conv1(x)
flow = self.up(x) flow = x
if self.scale != 1: if self.scale != 1:
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
align_corners=False)
return flow return flow
class IFNet(nn.Module): class IFNet(nn.Module):
def __init__(self): def __init__(self):
super(IFNet, self).__init__() super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=192) self.block0 = IFBlock(6, scale=4, c=240)
self.block1 = IFBlock(8, scale=2, c=128) self.block1 = IFBlock(10, scale=2, c=150)
self.block2 = IFBlock(8, scale=1, c=64) self.block2 = IFBlock(10, scale=1, c=90)
def forward(self, x): def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False)
flow0 = self.block0(x) flow0 = self.block0(x)
F1 = flow0 F1 = flow0
warped_img0 = warp(x[:, :3], F1) F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img1 = warp(x[:, 3:], -F1) warped_img0 = warp(x[:, :3], F1_large[:, :2])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1), 1)) warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1) F2 = (flow0 + flow1)
warped_img0 = warp(x[:, :3], F2) F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img1 = warp(x[:, 3:], -F2) warped_img0 = warp(x[:, :3], F2_large[:, :2])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2), 1)) warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2) F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3] return F3, [F1, F2, F3]
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
flownet = IFNet()
flow, _ = flownet(imgs)
print(flow.shape)

View File

@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__()
self.scale = scale
self.conv0 = nn.Sequential(
conv(in_planes, c//2, 3, 2, 1),
conv(c//2, c, 3, 2, 1),
)
self.convblock = nn.Sequential(
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.ConvTranspose2d(c, 4, 4, 2, 1)
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
x = self.conv0(x)
x = self.convblock(x) + x
x = self.conv1(x)
flow = x
if self.scale != 1:
flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
return flow
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=320)
self.block1 = IFBlock(10, scale=2, c=225)
self.block2 = IFBlock(10, scale=1, c=135)
def forward(self, x):
flow0 = self.block0(x)
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3]

View File

@@ -1,115 +1,75 @@
import torch import torch
import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from model.warplayer import warp from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False), padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes), )
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False), padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes),
nn.PReLU(out_planes) nn.PReLU(out_planes)
) )
class ResBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride=1):
super(ResBlock, self).__init__()
if in_planes == out_planes and stride == 1:
self.conv0 = nn.Identity()
else:
self.conv0 = nn.Conv2d(in_planes, out_planes,
3, stride, 1, bias=False)
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv_wo_act(out_planes, out_planes, 3, 1, 1)
self.relu1 = nn.PReLU(1)
self.relu2 = nn.PReLU(out_planes)
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
def forward(self, x):
y = self.conv0(x)
x = self.conv1(x)
x = self.conv2(x)
w = x.mean(3, True).mean(2, True)
w = self.relu1(self.fc1(w))
w = torch.sigmoid(self.fc2(w))
x = self.relu2(x * w + y)
return x
class IFBlock(nn.Module): class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64): def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__() super(IFBlock, self).__init__()
self.scale = scale self.scale = scale
self.conv0 = conv(in_planes, c, 3, 1, 1) self.conv0 = nn.Sequential(
self.res0 = ResBlock(c, c) conv(in_planes, c, 3, 2, 1),
self.res1 = ResBlock(c, c) )
self.res2 = ResBlock(c, c) self.convblock = nn.Sequential(
self.res3 = ResBlock(c, c) conv(c, c),
self.res4 = ResBlock(c, c) conv(c, c),
self.res5 = ResBlock(c, c) conv(c, c),
self.conv1 = nn.Conv2d(c, 2, 3, 1, 1) conv(c, c),
self.up = nn.PixelShuffle(2) conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, x): def forward(self, x):
if self.scale != 1: if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
align_corners=False)
x = self.conv0(x) x = self.conv0(x)
x = self.res0(x) x = self.convblock(x) + x
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
x = self.res4(x)
x = self.res5(x)
x = self.conv1(x) x = self.conv1(x)
flow = x # self.up(x) flow = x
if self.scale != 1: if self.scale != 1:
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
align_corners=False)
return flow return flow
class IFNet(nn.Module): class IFNet(nn.Module):
def __init__(self): def __init__(self):
super(IFNet, self).__init__() super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=192) self.block0 = IFBlock(6, scale=4, c=240)
self.block1 = IFBlock(8, scale=2, c=128) self.block1 = IFBlock(10, scale=2, c=150)
self.block2 = IFBlock(8, scale=1, c=64) self.block2 = IFBlock(10, scale=1, c=90)
def forward(self, x): def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False)
flow0 = self.block0(x) flow0 = self.block0(x)
F1 = flow0 F1 = flow0
warped_img0 = warp(x[:, :3], F1) F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img1 = warp(x[:, 3:], -F1) warped_img0 = warp(x[:, :3], F1_large[:, :2])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1), 1)) warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1) F2 = (flow0 + flow1)
warped_img0 = warp(x[:, :3], F2) F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img1 = warp(x[:, 3:], -F2) warped_img0 = warp(x[:, :3], F2_large[:, :2])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2), 1)) warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2) F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3] return F3, [F1, F2, F3]
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
flownet = IFNet()
flow, _ = flownet(imgs)
print(flow.shape)

View File

@@ -1,115 +1,75 @@
import torch import torch
import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from model.warplayer import warp from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False), padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes), )
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1): def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential( return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False), padding=padding, dilation=dilation, bias=True),
nn.BatchNorm2d(out_planes),
nn.PReLU(out_planes) nn.PReLU(out_planes)
) )
class ResBlock(nn.Module):
def __init__(self, in_planes, out_planes, stride=1):
super(ResBlock, self).__init__()
if in_planes == out_planes and stride == 1:
self.conv0 = nn.Identity()
else:
self.conv0 = nn.Conv2d(in_planes, out_planes,
3, stride, 1, bias=False)
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv_wo_act(out_planes, out_planes, 3, 1, 1)
self.relu1 = nn.PReLU(1)
self.relu2 = nn.PReLU(out_planes)
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
def forward(self, x):
y = self.conv0(x)
x = self.conv1(x)
x = self.conv2(x)
w = x.mean(3, True).mean(2, True)
w = self.relu1(self.fc1(w))
w = torch.sigmoid(self.fc2(w))
x = self.relu2(x * w + y)
return x
class IFBlock(nn.Module): class IFBlock(nn.Module):
def __init__(self, in_planes, scale=1, c=64): def __init__(self, in_planes, scale=1, c=64):
super(IFBlock, self).__init__() super(IFBlock, self).__init__()
self.scale = scale self.scale = scale
self.conv0 = conv(in_planes, c, 3, 1, 1) self.conv0 = nn.Sequential(
self.res0 = ResBlock(c, c) conv(in_planes, c, 3, 2, 1),
self.res1 = ResBlock(c, c) )
self.res2 = ResBlock(c, c) self.convblock = nn.Sequential(
self.res3 = ResBlock(c, c) conv(c, c),
self.res4 = ResBlock(c, c) conv(c, c),
self.res5 = ResBlock(c, c) conv(c, c),
self.conv1 = nn.Conv2d(c, 2, 3, 1, 1) conv(c, c),
self.up = nn.PixelShuffle(2) conv(c, c),
conv(c, c),
conv(c, c),
conv(c, c),
)
self.conv1 = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, x): def forward(self, x):
if self.scale != 1: if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False)
align_corners=False)
x = self.conv0(x) x = self.conv0(x)
x = self.res0(x) x = self.convblock(x) + x
x = self.res1(x)
x = self.res2(x)
x = self.res3(x)
x = self.res4(x)
x = self.res5(x)
x = self.conv1(x) x = self.conv1(x)
flow = x # self.up(x) flow = x
if self.scale != 1: if self.scale != 1:
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False)
align_corners=False)
return flow return flow
class IFNet(nn.Module): class IFNet(nn.Module):
def __init__(self): def __init__(self):
super(IFNet, self).__init__() super(IFNet, self).__init__()
self.block0 = IFBlock(6, scale=4, c=288) self.block0 = IFBlock(6, scale=4, c=360)
self.block1 = IFBlock(8, scale=2, c=192) self.block1 = IFBlock(10, scale=2, c=225)
self.block2 = IFBlock(8, scale=1, c=96) self.block2 = IFBlock(10, scale=1, c=135)
def forward(self, x): def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False)
flow0 = self.block0(x) flow0 = self.block0(x)
F1 = flow0 F1 = flow0
warped_img0 = warp(x[:, :3], F1) F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img1 = warp(x[:, 3:], -F1) warped_img0 = warp(x[:, :3], F1_large[:, :2])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1), 1)) warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1) F2 = (flow0 + flow1)
warped_img0 = warp(x[:, :3], F2) F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img1 = warp(x[:, 3:], -F2) warped_img0 = warp(x[:, :3], F2_large[:, :2])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2), 1)) warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2) F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3] return F3, [F1, F2, F3]
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
flownet = IFNet()
flow, _ = flownet(imgs)
print(flow.shape)

View File

@@ -91,12 +91,9 @@ class IFNet(nn.Module):
self.block2 = IFBlock(8, scale=2, c=96) self.block2 = IFBlock(8, scale=2, c=96)
self.block3 = IFBlock(8, scale=1, c=48) self.block3 = IFBlock(8, scale=1, c=48)
def forward(self, x, UHD=False): def forward(self, x, scale=1.0):
if UHD: x = F.interpolate(x, scale_factor=0.5 * scale, mode="bilinear",
x = F.interpolate(x, scale_factor=0.25, mode="bilinear", align_corners=False) align_corners=False)
else:
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False)
flow0 = self.block0(x) flow0 = self.block0(x)
F1 = flow0 F1 = flow0
warped_img0 = warp(x[:, :3], F1) warped_img0 = warp(x[:, :3], F1)
@@ -111,6 +108,8 @@ class IFNet(nn.Module):
warped_img1 = warp(x[:, 3:], -F3) warped_img1 = warp(x[:, 3:], -F3)
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1)) flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3), 1))
F4 = (flow0 + flow1 + flow2 + flow3) F4 = (flow0 + flow1 + flow2 + flow3)
F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear",
align_corners=False) / scale
return F4, [F1, F2, F3, F4] return F4, [F1, F2, F3, F4]
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -61,26 +61,28 @@ class IFNet(nn.Module):
self.block2 = IFBlock(10, scale=2, c=96) self.block2 = IFBlock(10, scale=2, c=96)
self.block3 = IFBlock(10, scale=1, c=48) self.block3 = IFBlock(10, scale=1, c=48)
def forward(self, x, UHD=False): def forward(self, x, scale=1.0):
if UHD: if scale != 1.0:
x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False) x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False)
flow0 = self.block0(x) flow0 = self.block0(x)
F1 = flow0 F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2]) warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4]) warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1)) flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1))
F2 = (flow0 + flow1) F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2]) warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4]) warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1)) flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1))
F3 = (flow0 + flow1 + flow2) F3 = (flow0 + flow1 + flow2)
F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F3_large[:, :2]) warped_img0 = warp(x[:, :3], F3_large[:, :2])
warped_img1 = warp(x[:, 3:], F3_large[:, 2:4]) warped_img1 = warp(x[:, 3:], F3_large[:, 2:4])
flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1)) flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1))
F4 = (flow0 + flow1 + flow2 + flow3) F4 = (flow0 + flow1 + flow2 + flow3)
if scale != 1.0:
F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale
return F4, [F1, F2, F3, F4] return F4, [F1, F2, F3, F4]
if __name__ == '__main__': if __name__ == '__main__':

View File

@@ -0,0 +1,81 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.warplayer import warp
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
nn.PReLU(out_planes)
)
def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
class IFBlock(nn.Module):
def __init__(self, in_planes, c=64):
super(IFBlock, self).__init__()
self.conv0 = nn.Sequential(
conv(in_planes, c, 3, 2, 1),
conv(c, 2*c, 3, 2, 1),
)
self.convblock0 = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
)
self.convblock1 = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
)
self.convblock2 = nn.Sequential(
conv(2*c, 2*c),
conv(2*c, 2*c),
)
self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1)
def forward(self, x, flow=None, scale=1):
x = F.interpolate(x, scale_factor= 1. / scale, mode="bilinear", align_corners=False)
if flow != None:
flow = F.interpolate(flow, scale_factor= 1. / scale, mode="bilinear", align_corners=False) * (1. / scale)
x = torch.cat((x, flow), 1)
x = self.conv0(x)
x = self.convblock0(x) + x
x = self.convblock1(x) + x
x = self.convblock2(x) + x
x = self.conv1(x)
flow = x
if scale != 1:
flow = F.interpolate(flow, scale_factor= scale, mode="bilinear", align_corners=False) * scale
return flow
class IFNet(nn.Module):
def __init__(self):
super(IFNet, self).__init__()
self.block0 = IFBlock(6, c=80)
self.block1 = IFBlock(10, c=80)
self.block2 = IFBlock(10, c=80)
def forward(self, x, scale_list=[4,2,1]):
flow0 = self.block0(x, scale=scale_list[0])
F1 = flow0
F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F1_large[:, :2])
warped_img1 = warp(x[:, 3:], F1_large[:, 2:4])
flow1 = self.block1(torch.cat((warped_img0, warped_img1), 1), F1_large, scale=scale_list[1])
F2 = (flow0 + flow1)
F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
warped_img0 = warp(x[:, :3], F2_large[:, :2])
warped_img1 = warp(x[:, 3:], F2_large[:, 2:4])
flow2 = self.block2(torch.cat((warped_img0, warped_img1), 1), F2_large, scale=scale_list[2])
F3 = (flow0 + flow1 + flow2)
return F3, [F1, F2, F3]

View File

@@ -34,29 +34,15 @@ def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilati
padding=padding, dilation=dilation, bias=True), padding=padding, dilation=dilation, bias=True),
) )
class ResBlock(nn.Module): class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2): def __init__(self, in_planes, out_planes, stride=2):
super(ResBlock, self).__init__() super(Conv2, self).__init__()
if in_planes == out_planes and stride == 1:
self.conv0 = nn.Identity()
else:
self.conv0 = nn.Conv2d(in_planes, out_planes,
3, stride, 1, bias=False)
self.conv1 = conv(in_planes, out_planes, 3, stride, 1) self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1) self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
self.relu1 = nn.PReLU(1)
self.relu2 = nn.PReLU(out_planes)
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
def forward(self, x): def forward(self, x):
y = self.conv0(x)
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
w = x.mean(3, True).mean(2, True)
w = self.relu1(self.fc1(w))
w = torch.sigmoid(self.fc2(w))
x = self.relu2(x * w + y)
return x return x
c = 16 c = 16
@@ -64,36 +50,32 @@ c = 16
class ContextNet(nn.Module): class ContextNet(nn.Module):
def __init__(self): def __init__(self):
super(ContextNet, self).__init__() super(ContextNet, self).__init__()
self.conv1 = ResBlock(3, c) self.conv1 = Conv2(3, c)
self.conv2 = ResBlock(c, 2*c) self.conv2 = Conv2(c, 2*c)
self.conv3 = ResBlock(2*c, 4*c) self.conv3 = Conv2(2*c, 4*c)
self.conv4 = ResBlock(4*c, 8*c) self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow): def forward(self, x, flow):
x = self.conv1(x) x = self.conv1(x)
f1 = warp(x, flow) f1 = warp(x, flow)
x = self.conv2(x) x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f2 = warp(x, flow) f2 = warp(x, flow)
x = self.conv3(x) x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f3 = warp(x, flow) f3 = warp(x, flow)
x = self.conv4(x) x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f4 = warp(x, flow) f4 = warp(x, flow)
return [f1, f2, f3, f4] return [f1, f2, f3, f4]
class FusionNet(nn.Module): class FusionNet(nn.Module):
def __init__(self): def __init__(self):
super(FusionNet, self).__init__() super(FusionNet, self).__init__()
self.down0 = ResBlock(8, 2*c) self.down0 = Conv2(12, 2*c)
self.down1 = ResBlock(4*c, 4*c) self.down1 = Conv2(4*c, 4*c)
self.down2 = ResBlock(8*c, 8*c) self.down2 = Conv2(8*c, 8*c)
self.down3 = ResBlock(16*c, 16*c) self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c) self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c) self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c) self.up2 = deconv(8*c, 2*c)
@@ -101,14 +83,14 @@ class FusionNet(nn.Module):
self.conv = nn.Conv2d(c, 4, 3, 1, 1) self.conv = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt): def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow) warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, -flow) warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None: if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None warped_img0_gt, warped_img1_gt = None, None
else: else:
warped_img0_gt = warp(img0, flow_gt[:, :2]) warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4]) warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1)) s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
@@ -119,7 +101,6 @@ class FusionNet(nn.Module):
x = self.conv(x) x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model: class Model:
def __init__(self, local_rank=-1): def __init__(self, local_rank=-1):
self.flownet = IFNet() self.flownet = IFNet()
@@ -129,12 +110,13 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
self.ter = Ternary() self.ter = Ternary()
self.sobel = SOBEL() self.sobel = SOBEL()
# self.vgg = VGGPerceptualLoss().to(device)
if local_rank != -1: if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[ self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank) local_rank], output_device=local_rank)
@@ -158,7 +140,7 @@ class Model:
self.contextnet.to(device) self.contextnet.to(device)
self.fusionnet.to(device) self.fusionnet.to(device)
def load_model(self, path, rank): def load_model(self, path, rank=-1):
def convert(param): def convert(param):
if rank == -1: if rank == -1:
return { return {
@@ -185,8 +167,8 @@ class Model:
def predict(self, imgs, flow, training=True, flow_gt=None): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
c0 = self.contextnet(img0, flow) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, -flow) c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0 align_corners=False) * 2.0
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet( refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
@@ -203,7 +185,7 @@ class Model:
def inference(self, img0, img1): def inference(self, img0, img1):
imgs = torch.cat((img0, img1), 1) imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs) flow, _ = self.flownet(torch.cat((img0, img1), 1))
return self.predict(imgs, flow, training=False) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
@@ -228,8 +210,8 @@ class Model:
align_corners=False) * 0.5).detach() align_corners=False) * 0.5).detach()
loss_cons = 0 loss_cons = 0
for i in range(3): for i in range(3):
loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1) loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1) loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01 loss_cons = loss_cons.mean() * 0.01
else: else:
loss_cons = torch.tensor([0]) loss_cons = torch.tensor([0])
@@ -239,6 +221,7 @@ class Model:
if training: if training:
self.optimG.zero_grad() self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter loss_G = loss_l1 + loss_cons + loss_ter
# loss_G = self.vgg(pred, gt) + loss_cons + loss_ter
loss_G.backward() loss_G.backward()
self.optimG.step() self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask

View File

@@ -0,0 +1,235 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet15C import *
import torch.nn.functional as F
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 24
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
self.conv1 = Conv2(3, c)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv1(x)
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class FusionNet(nn.Module):
def __init__(self):
super(FusionNet, self).__init__()
self.down0 = Conv2(12, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.Conv2d(c, 4, 3, 1, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None
else:
warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.contextnet = ContextNet()
self.fusionnet = FusionNet()
self.device()
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
self.ter = Ternary()
self.sobel = SOBEL()
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
self.contextnet = DDP(self.contextnet, device_ids=[
local_rank], output_device=local_rank)
self.fusionnet = DDP(self.fusionnet, device_ids=[
local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
self.contextnet.train()
self.fusionnet.train()
def eval(self):
self.flownet.eval()
self.contextnet.eval()
self.fusionnet.eval()
def device(self):
self.flownet.to(device)
self.contextnet.to(device)
self.fusionnet.to(device)
def load_model(self, path, rank=-1):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank):
if rank == 0:
torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
mask = torch.sigmoid(refine_output[:, 3:4])
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = merged_img + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
else:
return pred
def inference(self, img0, img1):
imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs)
return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group['lr'] = learning_rate
if training:
self.train()
else:
self.eval()
flow, flow_list = self.flownet(imgs)
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
imgs, flow, flow_gt=flow_gt)
loss_ter = self.ter(pred, gt).mean()
if training:
with torch.no_grad():
loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(3):
loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01
else:
loss_cons = torch.tensor([0])
loss_flow = torch.abs(warped_img0 - gt).mean()
loss_mask = 1
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter
loss_G.backward()
self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
model = Model()
model.eval()
print(model.inference(imgs).shape)

View File

@@ -34,29 +34,15 @@ def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilati
padding=padding, dilation=dilation, bias=True), padding=padding, dilation=dilation, bias=True),
) )
class ResBlock(nn.Module): class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2): def __init__(self, in_planes, out_planes, stride=2):
super(ResBlock, self).__init__() super(Conv2, self).__init__()
if in_planes == out_planes and stride == 1:
self.conv0 = nn.Identity()
else:
self.conv0 = nn.Conv2d(in_planes, out_planes,
3, stride, 1, bias=False)
self.conv1 = conv(in_planes, out_planes, 3, stride, 1) self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1) self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
self.relu1 = nn.PReLU(1)
self.relu2 = nn.PReLU(out_planes)
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
def forward(self, x): def forward(self, x):
y = self.conv0(x)
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
w = x.mean(3, True).mean(2, True)
w = self.relu1(self.fc1(w))
w = torch.sigmoid(self.fc2(w))
x = self.relu2(x * w + y)
return x return x
c = 16 c = 16
@@ -64,36 +50,32 @@ c = 16
class ContextNet(nn.Module): class ContextNet(nn.Module):
def __init__(self): def __init__(self):
super(ContextNet, self).__init__() super(ContextNet, self).__init__()
self.conv1 = ResBlock(3, c, 1) self.conv1 = Conv2(3, c, 1)
self.conv2 = ResBlock(c, 2*c) self.conv2 = Conv2(c, 2*c)
self.conv3 = ResBlock(2*c, 4*c) self.conv3 = Conv2(2*c, 4*c)
self.conv4 = ResBlock(4*c, 8*c) self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow): def forward(self, x, flow):
x = self.conv1(x) x = self.conv1(x)
f1 = warp(x, flow) f1 = warp(x, flow)
x = self.conv2(x) x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f2 = warp(x, flow) f2 = warp(x, flow)
x = self.conv3(x) x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f3 = warp(x, flow) f3 = warp(x, flow)
x = self.conv4(x) x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f4 = warp(x, flow) f4 = warp(x, flow)
return [f1, f2, f3, f4] return [f1, f2, f3, f4]
class FusionNet(nn.Module): class FusionNet(nn.Module):
def __init__(self): def __init__(self):
super(FusionNet, self).__init__() super(FusionNet, self).__init__()
self.down0 = ResBlock(8, 2*c, 1) self.down0 = Conv2(12, 2*c, 1)
self.down1 = ResBlock(4*c, 4*c) self.down1 = Conv2(4*c, 4*c)
self.down2 = ResBlock(8*c, 8*c) self.down2 = Conv2(8*c, 8*c)
self.down3 = ResBlock(16*c, 16*c) self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c) self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c) self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c) self.up2 = deconv(8*c, 2*c)
@@ -101,14 +83,14 @@ class FusionNet(nn.Module):
self.conv = nn.Conv2d(c, 4, 3, 2, 1) self.conv = nn.Conv2d(c, 4, 3, 2, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt): def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow) warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, -flow) warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None: if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None warped_img0_gt, warped_img1_gt = None, None
else: else:
warped_img0_gt = warp(img0, flow_gt[:, :2]) warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4]) warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1)) s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
@@ -119,7 +101,6 @@ class FusionNet(nn.Module):
x = self.conv(x) x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model: class Model:
def __init__(self, local_rank=-1): def __init__(self, local_rank=-1):
self.flownet = IFNet() self.flownet = IFNet()
@@ -129,7 +110,7 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
@@ -158,14 +139,17 @@ class Model:
self.contextnet.to(device) self.contextnet.to(device)
self.fusionnet.to(device) self.fusionnet.to(device)
def load_model(self, path, rank=0): def load_model(self, path, rank=-1):
def convert(param): def convert(param):
return { if rank == -1:
k.replace("module.", ""): v return {
for k, v in param.items() k.replace("module.", ""): v
if "module." in k for k, v in param.items()
} if "module." in k
if rank == 0: }
else:
return param
if rank <= 0:
self.flownet.load_state_dict( self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device))) convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict( self.contextnet.load_state_dict(
@@ -173,21 +157,19 @@ class Model:
self.fusionnet.load_state_dict( self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device))) convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank=0): def save_model(self, path, rank):
if rank == 0: if rank == 0:
torch.save(self.flownet.state_dict(), torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
'{}/flownet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.contextnet.state_dict(),
'{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0 align_corners=False) * 2.0
c0 = self.contextnet(img0, flow) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, -flow) c1 = self.contextnet(img1, flow[:, 2:4])
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet( refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt) img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1 res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
@@ -201,9 +183,8 @@ class Model:
return pred return pred
def inference(self, img0, img1): def inference(self, img0, img1):
with torch.no_grad(): imgs = torch.cat((img0, img1), 1)
imgs = torch.cat((img0, img1), 1) flow, _ = self.flownet(imgs)
flow, _ = self.flownet(imgs)
return self.predict(imgs, flow, training=False) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
@@ -222,10 +203,14 @@ class Model:
loss_flow = torch.abs(warped_img0_gt - gt).mean() loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs( loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach() merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0 loss_cons = 0
for i in range(3): for i in range(3):
loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1) loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1) loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01 loss_cons = loss_cons.mean() * 0.01
else: else:
loss_cons = torch.tensor([0]) loss_cons = torch.tensor([0])

View File

@@ -34,29 +34,15 @@ def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilati
padding=padding, dilation=dilation, bias=True), padding=padding, dilation=dilation, bias=True),
) )
class ResBlock(nn.Module): class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2): def __init__(self, in_planes, out_planes, stride=2):
super(ResBlock, self).__init__() super(Conv2, self).__init__()
if in_planes == out_planes and stride == 1:
self.conv0 = nn.Identity()
else:
self.conv0 = nn.Conv2d(in_planes, out_planes,
3, stride, 1, bias=False)
self.conv1 = conv(in_planes, out_planes, 3, stride, 1) self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1) self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
self.relu1 = nn.PReLU(1)
self.relu2 = nn.PReLU(out_planes)
self.fc1 = nn.Conv2d(out_planes, 16, kernel_size=1, bias=False)
self.fc2 = nn.Conv2d(16, out_planes, kernel_size=1, bias=False)
def forward(self, x): def forward(self, x):
y = self.conv0(x)
x = self.conv1(x) x = self.conv1(x)
x = self.conv2(x) x = self.conv2(x)
w = x.mean(3, True).mean(2, True)
w = self.relu1(self.fc1(w))
w = torch.sigmoid(self.fc2(w))
x = self.relu2(x * w + y)
return x return x
c = 24 c = 24
@@ -64,36 +50,32 @@ c = 24
class ContextNet(nn.Module): class ContextNet(nn.Module):
def __init__(self): def __init__(self):
super(ContextNet, self).__init__() super(ContextNet, self).__init__()
self.conv1 = ResBlock(3, c, 1) self.conv1 = Conv2(3, c, 1)
self.conv2 = ResBlock(c, 2*c) self.conv2 = Conv2(c, 2*c)
self.conv3 = ResBlock(2*c, 4*c) self.conv3 = Conv2(2*c, 4*c)
self.conv4 = ResBlock(4*c, 8*c) self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow): def forward(self, x, flow):
x = self.conv1(x) x = self.conv1(x)
f1 = warp(x, flow) f1 = warp(x, flow)
x = self.conv2(x) x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f2 = warp(x, flow) f2 = warp(x, flow)
x = self.conv3(x) x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f3 = warp(x, flow) f3 = warp(x, flow)
x = self.conv4(x) x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
align_corners=False) * 0.5
f4 = warp(x, flow) f4 = warp(x, flow)
return [f1, f2, f3, f4] return [f1, f2, f3, f4]
class FusionNet(nn.Module): class FusionNet(nn.Module):
def __init__(self): def __init__(self):
super(FusionNet, self).__init__() super(FusionNet, self).__init__()
self.down0 = ResBlock(8, 2*c, 1) self.down0 = Conv2(12, 2*c, 1)
self.down1 = ResBlock(4*c, 4*c) self.down1 = Conv2(4*c, 4*c)
self.down2 = ResBlock(8*c, 8*c) self.down2 = Conv2(8*c, 8*c)
self.down3 = ResBlock(16*c, 16*c) self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c) self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c) self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c) self.up2 = deconv(8*c, 2*c)
@@ -101,14 +83,14 @@ class FusionNet(nn.Module):
self.conv = nn.Conv2d(c, 4, 3, 2, 1) self.conv = nn.Conv2d(c, 4, 3, 2, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt): def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow) warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, -flow) warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None: if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None warped_img0_gt, warped_img1_gt = None, None
else: else:
warped_img0_gt = warp(img0, flow_gt[:, :2]) warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4]) warped_img1_gt = warp(img1, flow_gt[:, 2:4])
s0 = self.down0(torch.cat((warped_img0, warped_img1, flow), 1)) s0 = self.down0(torch.cat((img0, img1, warped_img0, warped_img1), 1))
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1)) s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1)) s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1)) s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
@@ -119,7 +101,6 @@ class FusionNet(nn.Module):
x = self.conv(x) x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model: class Model:
def __init__(self, local_rank=-1): def __init__(self, local_rank=-1):
self.flownet = IFNet() self.flownet = IFNet()
@@ -129,7 +110,7 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
@@ -158,14 +139,17 @@ class Model:
self.contextnet.to(device) self.contextnet.to(device)
self.fusionnet.to(device) self.fusionnet.to(device)
def load_model(self, path, rank=0): def load_model(self, path, rank=-1):
def convert(param): def convert(param):
return { if rank == -1:
k.replace("module.", ""): v return {
for k, v in param.items() k.replace("module.", ""): v
if "module." in k for k, v in param.items()
} if "module." in k
if rank == 0: }
else:
return param
if rank <= 0:
self.flownet.load_state_dict( self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device))) convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict( self.contextnet.load_state_dict(
@@ -173,21 +157,19 @@ class Model:
self.fusionnet.load_state_dict( self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device))) convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank=0): def save_model(self, path, rank):
if rank == 0: if rank == 0:
torch.save(self.flownet.state_dict(), torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
'{}/flownet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.contextnet.state_dict(),
'{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0 align_corners=False) * 2.0
c0 = self.contextnet(img0, flow) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, -flow) c1 = self.contextnet(img1, flow[:, 2:4])
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet( refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt) img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1 res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
@@ -201,9 +183,8 @@ class Model:
return pred return pred
def inference(self, img0, img1): def inference(self, img0, img1):
with torch.no_grad(): imgs = torch.cat((img0, img1), 1)
imgs = torch.cat((img0, img1), 1) flow, _ = self.flownet(imgs)
flow, _ = self.flownet(imgs)
return self.predict(imgs, flow, training=False) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
@@ -222,10 +203,14 @@ class Model:
loss_flow = torch.abs(warped_img0_gt - gt).mean() loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs( loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach() merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0 loss_cons = 0
for i in range(3): for i in range(3):
loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1) loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(-flow_list[i], flow_gt[:, 2:4], 1) loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01 loss_cons = loss_cons.mean() * 0.01
else: else:
loss_cons = torch.tensor([0]) loss_cons = torch.tensor([0])

View File

@@ -135,7 +135,7 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
@@ -188,11 +188,9 @@ class Model:
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow) c0 = self.contextnet(img0, flow)
c1 = self.contextnet(img1, -flow) c1 = self.contextnet(img1, -flow)
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
@@ -209,10 +207,10 @@ class Model:
else: else:
return pred return pred
def inference(self, img0, img1, UHD=False): def inference(self, img0, img1, scale=1.0):
imgs = torch.cat((img0, img1), 1) imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs, UHD) flow, _ = self.flownet(imgs, scale)
return self.predict(imgs, flow, training=False, UHD=UHD) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups: for param_group in self.optimG.param_groups:

View File

@@ -120,7 +120,7 @@ class Model:
self.optimG = AdamW(itertools.chain( self.optimG = AdamW(itertools.chain(
self.flownet.parameters(), self.flownet.parameters(),
self.contextnet.parameters(), self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5) self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-4)
self.schedulerG = optim.lr_scheduler.CyclicLR( self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False) self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE() self.epe = EPE()
@@ -173,11 +173,9 @@ class Model:
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path)) torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path)) torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False): def predict(self, imgs, flow, training=True, flow_gt=None):
img0 = imgs[:, :3] img0 = imgs[:, :3]
img1 = imgs[:, 3:] img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow[:, :2]) c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4]) c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
@@ -194,10 +192,10 @@ class Model:
else: else:
return pred return pred
def inference(self, img0, img1, UHD=False): def inference(self, img0, img1, scale=1.0):
imgs = torch.cat((img0, img1), 1) imgs = torch.cat((img0, img1), 1)
flow, _ = self.flownet(imgs, UHD) flow, _ = self.flownet(imgs, scale)
return self.predict(imgs, flow, training=False, UHD=UHD) return self.predict(imgs, flow, training=False)
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None): def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups: for param_group in self.optimG.param_groups:

View File

@@ -0,0 +1,249 @@
import torch
import torch.nn as nn
import numpy as np
from torch.optim import AdamW
import torch.optim as optim
import itertools
from model.warplayer import warp
from torch.nn.parallel import DistributedDataParallel as DDP
from model.IFNet_HDv3 import *
import torch.nn.functional as F
from model.loss import *
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
nn.PReLU(out_planes)
)
def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
return nn.Sequential(
torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
kernel_size=4, stride=2, padding=1, bias=True),
nn.PReLU(out_planes)
)
def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
return nn.Sequential(
nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=True),
)
class Conv2(nn.Module):
def __init__(self, in_planes, out_planes, stride=2):
super(Conv2, self).__init__()
self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
c = 32
class ContextNet(nn.Module):
def __init__(self):
super(ContextNet, self).__init__()
self.conv0 = Conv2(3, c)
self.conv1 = Conv2(c, c)
self.conv2 = Conv2(c, 2*c)
self.conv3 = Conv2(2*c, 4*c)
self.conv4 = Conv2(4*c, 8*c)
def forward(self, x, flow):
x = self.conv0(x)
x = self.conv1(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False) * 0.5
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
class FusionNet(nn.Module):
def __init__(self):
super(FusionNet, self).__init__()
self.conv0 = Conv2(10, c)
self.down0 = Conv2(c, 2*c)
self.down1 = Conv2(4*c, 4*c)
self.down2 = Conv2(8*c, 8*c)
self.down3 = Conv2(16*c, 16*c)
self.up0 = deconv(32*c, 8*c)
self.up1 = deconv(16*c, 4*c)
self.up2 = deconv(8*c, 2*c)
self.up3 = deconv(4*c, c)
self.conv = nn.ConvTranspose2d(c, 4, 4, 2, 1)
def forward(self, img0, img1, flow, c0, c1, flow_gt):
warped_img0 = warp(img0, flow[:, :2])
warped_img1 = warp(img1, flow[:, 2:4])
if flow_gt == None:
warped_img0_gt, warped_img1_gt = None, None
else:
warped_img0_gt = warp(img0, flow_gt[:, :2])
warped_img1_gt = warp(img1, flow_gt[:, 2:4])
x = self.conv0(torch.cat((warped_img0, warped_img1, flow), 1))
s0 = self.down0(x)
s1 = self.down1(torch.cat((s0, c0[0], c1[0]), 1))
s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
x = self.up1(torch.cat((x, s2), 1))
x = self.up2(torch.cat((x, s1), 1))
x = self.up3(torch.cat((x, s0), 1))
x = self.conv(x)
return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
class Model:
def __init__(self, local_rank=-1):
self.flownet = IFNet()
self.contextnet = ContextNet()
self.fusionnet = FusionNet()
self.device()
self.optimG = AdamW(itertools.chain(
self.flownet.parameters(),
self.contextnet.parameters(),
self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
self.schedulerG = optim.lr_scheduler.CyclicLR(
self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
self.epe = EPE()
self.ter = Ternary()
self.sobel = SOBEL()
if local_rank != -1:
self.flownet = DDP(self.flownet, device_ids=[
local_rank], output_device=local_rank)
self.contextnet = DDP(self.contextnet, device_ids=[
local_rank], output_device=local_rank)
self.fusionnet = DDP(self.fusionnet, device_ids=[
local_rank], output_device=local_rank)
def train(self):
self.flownet.train()
self.contextnet.train()
self.fusionnet.train()
def eval(self):
self.flownet.eval()
self.contextnet.eval()
self.fusionnet.eval()
def device(self):
self.flownet.to(device)
self.contextnet.to(device)
self.fusionnet.to(device)
def load_model(self, path, rank):
def convert(param):
if rank == -1:
return {
k.replace("module.", ""): v
for k, v in param.items()
if "module." in k
}
else:
return param
if rank <= 0:
self.flownet.load_state_dict(
convert(torch.load('{}/flownet.pkl'.format(path), map_location=device)))
self.contextnet.load_state_dict(
convert(torch.load('{}/contextnet.pkl'.format(path), map_location=device)))
self.fusionnet.load_state_dict(
convert(torch.load('{}/unet.pkl'.format(path), map_location=device)))
def save_model(self, path, rank):
if rank == 0:
torch.save(self.flownet.state_dict(), '{}/flownet.pkl'.format(path))
torch.save(self.contextnet.state_dict(), '{}/contextnet.pkl'.format(path))
torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))
def predict(self, imgs, flow, training=True, flow_gt=None, UHD=False):
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
if UHD:
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False) * 2.0
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False) * 2.0
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
mask = torch.sigmoid(refine_output[:, 3:4])
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = merged_img + res
pred = torch.clamp(pred, 0, 1)
if training:
return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
else:
return pred
def inference(self, img0, img1, UHD=False):
imgs = torch.cat((img0, img1), 1)
scale_list = [8, 4, 2]
flow, _ = self.flownet(imgs, scale_list)
res = self.predict(imgs, flow, training=False, UHD=False)
return res
def update(self, imgs, gt, learning_rate=0, mul=1, training=True, flow_gt=None):
for param_group in self.optimG.param_groups:
param_group['lr'] = learning_rate
if training:
self.train()
else:
self.eval()
flow, flow_list = self.flownet(imgs)
pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
imgs, flow, flow_gt=flow_gt)
loss_ter = self.ter(pred, gt).mean()
if training:
with torch.no_grad():
loss_flow = torch.abs(warped_img0_gt - gt).mean()
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(4):
loss_cons += self.epe(flow_list[i][:, :2], flow_gt[:, :2], 1)
loss_cons += self.epe(flow_list[i][:, 2:4], flow_gt[:, 2:4], 1)
loss_cons = loss_cons.mean() * 0.01
else:
loss_cons = torch.tensor([0])
loss_flow = torch.abs(warped_img0 - gt).mean()
loss_mask = 1
loss_l1 = (((pred - gt) ** 2 + 1e-6) ** 0.5).mean()
if training:
self.optimG.zero_grad()
loss_G = loss_l1 + loss_cons + loss_ter
loss_G.backward()
self.optimG.step()
return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask
if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device)
img1 = torch.tensor(np.random.normal(
0, 1, (3, 3, 256, 256))).float().to(device)
imgs = torch.cat((img0, img1), 1)
model = Model()
model.eval()
print(model.inference(imgs).shape)

View File

@@ -2,6 +2,7 @@ import torch
import numpy as np import numpy as np
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchvision.models as models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -79,6 +80,45 @@ class SOBEL(nn.Module):
loss = (L1X+L1Y) loss = (L1X+L1Y)
return loss return loss
class MeanShift(nn.Conv2d):
def __init__(self, data_mean, data_std, data_range=1, norm=True):
c = len(data_mean)
super(MeanShift, self).__init__(c, c, kernel_size=1)
std = torch.Tensor(data_std)
self.weight.data = torch.eye(c).view(c, c, 1, 1)
if norm:
self.weight.data.div_(std.view(c, 1, 1, 1))
self.bias.data = -1 * data_range * torch.Tensor(data_mean)
self.bias.data.div_(std)
else:
self.weight.data.mul_(std.view(c, 1, 1, 1))
self.bias.data = data_range * torch.Tensor(data_mean)
self.requires_grad = False
class VGGPerceptualLoss(torch.nn.Module):
def __init__(self, rank=0):
super(VGGPerceptualLoss, self).__init__()
blocks = []
pretrained = True
self.vgg_pretrained_features = models.vgg19(pretrained=pretrained).features
self.normalize = MeanShift([0.485, 0.456, 0.406], [0.229, 0.224, 0.225], norm=True).cuda()
for param in self.parameters():
param.requires_grad = False
def forward(self, X, Y, indices=None):
X = self.normalize(X)
Y = self.normalize(Y)
indices = [2, 7, 12, 21, 30]
weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10/1.5]
k = 0
loss = 0
for i in range(indices[-1]):
X = self.vgg_pretrained_features[i](X)
Y = self.vgg_pretrained_features[i](Y)
if (i+1) in indices:
loss += weights[k] * (X - Y.detach()).abs().mean() * 0.1
k += 1
return loss
if __name__ == '__main__': if __name__ == '__main__':
img0 = torch.zeros(3, 3, 256, 256).float().to(device) img0 = torch.zeros(3, 3, 256, 256).float().to(device)

View File

@@ -19,4 +19,4 @@ def warp(tenInput, tenFlow):
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1) tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1) g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
return torch.nn.functional.grid_sample(input=tenInput, grid=torch.clamp(g, -1, 1), mode='bilinear', padding_mode='zeros', align_corners=True) return torch.nn.functional.grid_sample(input=tenInput, grid=g, mode='bilinear', padding_mode='border', align_corners=True)

View File

@@ -18,6 +18,11 @@
"name": "RIFE 3.0", "name": "RIFE 3.0",
"desc": "Latest General Model", "desc": "Latest General Model",
"dir": "RIFE30", "dir": "RIFE30",
},
{
"name": "RIFE 3.1",
"desc": "Latest General Model",
"dir": "RIFE31",
"isDefault": "true" "isDefault": "true"
} }
] ]

View File

@@ -36,7 +36,6 @@ parser.add_argument('--exp', dest='exp', type=int, default=1)
args = parser.parse_args() args = parser.parse_args()
assert (not args.input is None) assert (not args.input is None)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False) torch.set_grad_enabled(False)
if torch.cuda.is_available(): if torch.cuda.is_available():
@@ -56,13 +55,21 @@ except:
print("Failed to get hardware info!") print("Failed to get hardware info!")
try: try:
try:
from model.RIFE_HDv2 import Model
model = Model()
model.load_model(os.path.join(dname, args.model), -1)
print("Loaded v2.x HD model.")
except:
from model.RIFE_HDv3 import Model
model = Model()
model.load_model(os.path.join(dname, args.model), -1)
print("Loaded v3.x HD model.")
except:
from model.RIFE_HD import Model from model.RIFE_HD import Model
model = Model() model = Model()
model.load_model(os.path.join(dname, args.model), -1) model.load_model(os.path.join(dname, args.model), -1)
except: print("Loaded v1.x HD model")
from model.RIFE_HDv2 import Model
model = Model()
model.load_model(os.path.join(dname, args.model), -1)
model.eval() model.eval()
model.device() model.device()