Clean code

hzwer
2020-11-12 19:57:29 +08:00
parent 20d6abefb8
commit 9ac61eac57
5 changed files with 121 additions and 65 deletions

.gitignore

@@ -1,3 +1,3 @@
 *.pyc
 *.py~
-
+*.py#

IFNet.py

@@ -1,21 +1,19 @@
 import torch
+import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
 from warplayer import warp

-def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
-    return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1),
-        nn.BatchNorm2d(out_planes),
-        nn.PReLU(out_planes)
-    )
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+

 def conv_wo_act(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
     return nn.Sequential(
         nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                   padding=padding, dilation=dilation, bias=False),
         nn.BatchNorm2d(out_planes),
     )

+
 def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
     return nn.Sequential(
@@ -25,13 +23,15 @@ def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
         nn.PReLU(out_planes)
     )

+
 class ResBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride=1):
         super(ResBlock, self).__init__()
         if in_planes == out_planes and stride == 1:
             self.conv0 = nn.Identity()
         else:
-            self.conv0 = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
+            self.conv0 = nn.Conv2d(in_planes, out_planes,
+                                   3, stride, 1, bias=False)
         self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
         self.conv2 = conv_wo_act(out_planes, out_planes, 3, 1, 1)
         self.relu1 = nn.PReLU(1)
@@ -49,6 +49,7 @@ class ResBlock(nn.Module):
         x = self.relu2(x * w + y)
         return x

+
 class IFBlock(nn.Module):
     def __init__(self, in_planes, scale=1, c=64):
         super(IFBlock, self).__init__()
@@ -65,7 +66,8 @@ class IFBlock(nn.Module):
     def forward(self, x):
         if self.scale != 1:
-            x = F.interpolate(x, scale_factor= 1. / self.scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
+                              align_corners=False, recompute_scale_factor=False)
         x = self.conv0(x)
         x = self.res0(x)
         x = self.res1(x)
@@ -76,9 +78,11 @@ class IFBlock(nn.Module):
         x = self.conv1(x)
         flow = self.up(x)
         if self.scale != 1:
-            flow = F.interpolate(flow, scale_factor= self.scale, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+            flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
+                                 align_corners=False, recompute_scale_factor=False)
         return flow

+
 class IFNet(nn.Module):
     def __init__(self):
         super(IFNet, self).__init__()
@@ -87,7 +91,8 @@ class IFNet(nn.Module):
         self.block2 = IFBlock(8, scale=1, c=64)

     def forward(self, x):
-        x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False)
+        x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
+                          align_corners=False, recompute_scale_factor=False)
         flow0 = self.block0(x)
         F1 = flow0
         warped_img0 = warp(x[:, :3], F1)
@@ -99,3 +104,12 @@ class IFNet(nn.Module):
         flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2), 1))
         F3 = (flow0 + flow1 + flow2)
         return F3, [F1, F2, F3]
+
+if __name__ == '__main__':
+    img0 = torch.zeros(3, 3, 256, 256).float().to(device)
+    img1 = torch.tensor(np.random.normal(
+        0, 1, (3, 3, 256, 256))).float().to(device)
+    imgs = torch.cat((img0, img1), 1)
+    flownet = IFNet()
+    flow, _ = flownet(imgs)
+    print(flow.shape)
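The new __main__ block doubles as a quick smoke test. A hedged reading of its output (an inference from the code, since the IFBlock heads are not shown in this hunk): forward() halves the input before estimating flow, so for a 256x256 pair the printed shape should be torch.Size([3, 2, 128, 128]), two displacement channels at half resolution. A minimal sketch of just the resolution step:

import torch
import torch.nn.functional as F

# img0 and img1 stacked on the channel axis, as in the smoke test above
x = torch.zeros(3, 6, 256, 256)
x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
print(x.shape)  # torch.Size([3, 6, 128, 128]) -- flow is estimated at this size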

RIFE.py

@@ -12,24 +12,29 @@ from loss import *
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+
 def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
     return nn.Sequential(
         nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                   padding=padding, dilation=dilation, bias=True),
         nn.PReLU(out_planes)
     )

+
 def conv_woact(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
     return nn.Sequential(
         nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                   padding=padding, dilation=dilation, bias=True),
     )

+
 def deconv(in_planes, out_planes, kernel_size=4, stride=2, padding=1):
     return nn.Sequential(
-        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes, kernel_size=4, stride=2, padding=1, bias=True),
+        torch.nn.ConvTranspose2d(in_channels=in_planes, out_channels=out_planes,
+                                 kernel_size=4, stride=2, padding=1, bias=True),
         nn.PReLU(out_planes)
     )

+
 class ResBlock(nn.Module):
     def __init__(self, in_planes, out_planes, stride=2):
@@ -37,7 +42,8 @@ class ResBlock(nn.Module):
         if in_planes == out_planes and stride == 1:
             self.conv0 = nn.Identity()
         else:
-            self.conv0 = nn.Conv2d(in_planes, out_planes, 3, stride, 1, bias=False)
+            self.conv0 = nn.Conv2d(in_planes, out_planes,
+                                   3, stride, 1, bias=False)
         self.conv1 = conv(in_planes, out_planes, 3, stride, 1)
         self.conv2 = conv_woact(out_planes, out_planes, 3, 1, 1)
         self.relu1 = nn.PReLU(1)
@@ -52,9 +58,13 @@ class ResBlock(nn.Module):
         w = x.mean(3, True).mean(2, True)
         w = self.relu1(self.fc1(w))
         w = torch.sigmoid(self.fc2(w))
         x = self.relu2(x * w + y)
         return x

+
 c = 16

+
 class ContextNet(nn.Module):
     def __init__(self):
         super(ContextNet, self).__init__()
@@ -62,21 +72,25 @@ class ContextNet(nn.Module):
         self.conv2 = ResBlock(c, 2*c)
         self.conv3 = ResBlock(2*c, 4*c)
         self.conv4 = ResBlock(4*c, 8*c)

     def forward(self, x, flow):
         x = self.conv1(x)
         f1 = warp(x, flow)
         x = self.conv2(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
+                             align_corners=False, recompute_scale_factor=False) * 0.5
         f2 = warp(x, flow)
         x = self.conv3(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
+                             align_corners=False, recompute_scale_factor=False) * 0.5
         f3 = warp(x, flow)
         x = self.conv4(x)
-        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5
+        flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
+                             align_corners=False, recompute_scale_factor=False) * 0.5
         f4 = warp(x, flow)
         return [f1, f2, f3, f4]

+
 class FusionNet(nn.Module):
     def __init__(self):
         super(FusionNet, self).__init__()
@@ -103,12 +117,13 @@ class FusionNet(nn.Module):
         s2 = self.down2(torch.cat((s1, c0[1], c1[1]), 1))
         s3 = self.down3(torch.cat((s2, c0[2], c1[2]), 1))
         x = self.up0(torch.cat((s3, c0[3], c1[3]), 1))
         x = self.up1(torch.cat((x, s2), 1))
         x = self.up2(torch.cat((x, s1), 1))
         x = self.up3(torch.cat((x, s0), 1))
         x = self.conv(x)
         return x, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt

+
 class Model:
     def __init__(self, local_rank=-1):
         self.flownet = IFNet()
@@ -119,14 +134,18 @@ class Model:
             self.flownet.parameters(),
             self.contextnet.parameters(),
             self.fusionnet.parameters()), lr=1e-6, weight_decay=1e-5)
-        self.schedulerG = optim.lr_scheduler.CyclicLR(self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
+        self.schedulerG = optim.lr_scheduler.CyclicLR(
+            self.optimG, base_lr=1e-6, max_lr=1e-3, step_size_up=8000, cycle_momentum=False)
         self.epe = EPE()
         self.ter = Ternary()
         self.sobel = SOBEL()
         if local_rank != -1:
-            self.flownet = DDP(self.flownet, device_ids=[local_rank], output_device=local_rank)
-            self.contextnet = DDP(self.contextnet, device_ids=[local_rank], output_device=local_rank)
-            self.fusionnet = DDP(self.fusionnet, device_ids=[local_rank], output_device=local_rank)
+            self.flownet = DDP(self.flownet, device_ids=[
+                               local_rank], output_device=local_rank)
+            self.contextnet = DDP(self.contextnet, device_ids=[
+                                  local_rank], output_device=local_rank)
+            self.fusionnet = DDP(self.fusionnet, device_ids=[
+                                 local_rank], output_device=local_rank)

     def train(self):
         self.flownet.train()
@@ -145,33 +164,40 @@ class Model:
     def load_model(self, path, rank=0):
         if rank == 0:
-            self.flownet.load_state_dict(torch.load('{}/flownet.pkl'.format(path)))
-            self.contextnet.load_state_dict(torch.load('{}/contextnet.pkl'.format(path)))
-            self.fusionnet.load_state_dict(torch.load('{}/unet.pkl'.format(path)))
+            self.flownet.load_state_dict(
+                torch.load('{}/flownet.pkl'.format(path)))
+            self.contextnet.load_state_dict(
+                torch.load('{}/contextnet.pkl'.format(path)))
+            self.fusionnet.load_state_dict(
+                torch.load('{}/unet.pkl'.format(path)))

     def save_model(self, path, rank=0):
         if rank == 0:
-            torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
-            torch.save(self.contextnet.state_dict(),'{}/contextnet.pkl'.format(path))
-            torch.save(self.fusionnet.state_dict(),'{}/unet.pkl'.format(path))
+            torch.save(self.flownet.state_dict(),
+                       '{}/flownet.pkl'.format(path))
+            torch.save(self.contextnet.state_dict(),
+                       '{}/contextnet.pkl'.format(path))
+            torch.save(self.fusionnet.state_dict(), '{}/unet.pkl'.format(path))

     def predict(self, imgs, flow, training=True, flow_gt=None):
         img0 = imgs[:, :3]
         img1 = imgs[:, 3:]
         c0 = self.contextnet(img0, flow)
         c1 = self.contextnet(img1, -flow)
-        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0
-        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(img0, img1, flow, c0, c1, flow_gt)
+        flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
+                             align_corners=False, recompute_scale_factor=False) * 2.0
+        refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
+            img0, img1, flow, c0, c1, flow_gt)
         res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
         mask = torch.sigmoid(refine_output[:, 3:4])
         merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
         pred = merged_img + res
         pred = torch.clamp(pred, 0, 1)
         if training:
             return pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt
         else:
             return pred

     def inference(self, imgs):
         with torch.no_grad():
             flow, _ = self.flownet(imgs)
@@ -182,19 +208,23 @@ class Model:
             param_group['lr'] = learning_rate
         if training:
             self.train()
             # with torch.no_grad():
             #    flow_gt = estimate(gt, img0)
         else:
             self.eval()
         flow, flow_list = self.flownet(imgs)
-        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(imgs, flow, flow_gt=flow_gt)
+        pred, mask, merged_img, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.predict(
+            imgs, flow, flow_gt=flow_gt)
         loss_ter = self.ter(pred, gt).mean()
         if training:
             with torch.no_grad():
                 loss_flow = torch.abs(warped_img0_gt - gt).mean()
-                loss_mask = torch.abs(merged_img - gt).sum(1, True).float().detach()
-                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False).detach()
-                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 0.5).detach()
+                loss_mask = torch.abs(
+                    merged_img - gt).sum(1, True).float().detach()
+                loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
+                                          align_corners=False, recompute_scale_factor=False).detach()
+                flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
+                                         align_corners=False, recompute_scale_factor=False) * 0.5).detach()
             loss_cons = 0
             for i in range(3):
                 loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
@@ -212,9 +242,11 @@ class Model:
             self.optimG.step()
         return pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, loss_mask

+
 if __name__ == '__main__':
     img0 = torch.zeros(3, 3, 256, 256).float().to(device)
-    img1 = torch.tensor(np.random.normal(0, 1, (3, 3, 256, 256))).float().to(device)
+    img1 = torch.tensor(np.random.normal(
+        0, 1, (3, 3, 256, 256))).float().to(device)
     imgs = torch.cat((img0, img1), 1)
     model = Model()
     model.eval()
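The heart of predict() is the blend at its end: the refinement head emits four channels, a residual image plus a soft occlusion mask, and the two warped frames are mixed by the mask before the residual is added. Isolated as a sketch below; the shapes are illustrative assumptions, and the random refine_output stands in for the real FusionNet output.

import torch

warped_img0 = torch.rand(1, 3, 256, 256)     # frame 0 warped by +flow
warped_img1 = torch.rand(1, 3, 256, 256)     # frame 1 warped by -flow
refine_output = torch.randn(1, 4, 256, 256)  # stand-in for the FusionNet output

res = torch.sigmoid(refine_output[:, :3]) * 2 - 1  # residual, squashed to [-1, 1]
mask = torch.sigmoid(refine_output[:, 3:4])        # per-pixel blend weight
merged_img = warped_img0 * mask + warped_img1 * (1 - mask)
pred = torch.clamp(merged_img + res, 0, 1)         # final frame in [0, 1]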

loss.py

@@ -4,9 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-grid = None
-Grid = {}

+
 class EPE(nn.Module):
     def __init__(self):
         super(EPE, self).__init__()
@@ -16,12 +15,14 @@ class EPE(nn.Module):
         loss_map = (loss_map.sum(1, True) + 1e-6) ** 0.5
         return (loss_map * loss_mask)

+
 class Ternary(nn.Module):
     def __init__(self):
         super(Ternary, self).__init__()
         patch_size = 7
         out_channels = patch_size * patch_size
-        self.w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
+        self.w = np.eye(out_channels).reshape(
+            (patch_size, patch_size, 1, out_channels))
         self.w = np.transpose(self.w, (3, 2, 0, 1))
         self.w = torch.tensor(self.w).float().to(device)
@@ -32,10 +33,10 @@ class Ternary(nn.Module):
         return transf_norm

     def rgb2gray(self, rgb):
-        r, g, b = rgb[:, 0:1,:,:], rgb[:, 1:2,:,:], rgb[:, 2:3,:,:]
+        r, g, b = rgb[:, 0:1, :, :], rgb[:, 1:2, :, :], rgb[:, 2:3, :, :]
         gray = 0.2989 * r + 0.5870 * g + 0.1140 * b
         return gray

     def hamming(self, t1, t2):
         dist = (t1 - t2) ** 2
         dist_norm = torch.mean(dist / (0.1 + dist), 1, True)
@@ -46,12 +47,13 @@ class Ternary(nn.Module):
         inner = torch.ones(n, 1, h - 2 * padding, w - 2 * padding).type_as(t)
         mask = F.pad(inner, [padding] * 4)
         return mask

     def forward(self, img0, img1):
         img0 = self.transform(self.rgb2gray(img0))
         img1 = self.transform(self.rgb2gray(img1))
         return self.hamming(img0, img1) * self.valid_mask(img0, 1)

+
 class SOBEL(nn.Module):
     def __init__(self):
         super(SOBEL, self).__init__()
@@ -66,18 +68,21 @@ class SOBEL(nn.Module):
     def forward(self, pred, gt):
         N, C, H, W = pred.shape[0], pred.shape[1], pred.shape[2], pred.shape[3]
-        img_stack = torch.cat([pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
+        img_stack = torch.cat(
+            [pred.reshape(N*C, 1, H, W), gt.reshape(N*C, 1, H, W)], 0)
         sobel_stack_x = F.conv2d(img_stack, self.kernelX, padding=1)
         sobel_stack_y = F.conv2d(img_stack, self.kernelY, padding=1)
         pred_X, gt_X = sobel_stack_x[:N*C], sobel_stack_x[N*C:]
         pred_Y, gt_Y = sobel_stack_y[:N*C], sobel_stack_y[N*C:]
         L1X, L1Y = torch.abs(pred_X-gt_X), torch.abs(pred_Y-gt_Y)
         loss = (L1X+L1Y)
         return loss

+
 if __name__ == '__main__':
     img0 = torch.zeros(3, 3, 256, 256).float().to(device)
-    img1 = torch.tensor(np.random.normal(0, 1, (3, 3, 256, 256))).float().to(device)
+    img1 = torch.tensor(np.random.normal(
+        0, 1, (3, 3, 256, 256))).float().to(device)
     ternary_loss = Ternary()
     print(ternary_loss(img0, img1).shape)
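The self.w construction in Ternary is the census-transform trick: reshaping a 49x49 identity matrix into 49 one-hot 7x7 kernels means F.conv2d copies each pixel's neighbors into separate channels, and the soft Hamming distance then compares whole neighborhoods. A standalone sketch of just that step; the padding=patch_size // 2 is an assumption, since transform() is not shown in this hunk.

import numpy as np
import torch
import torch.nn.functional as F

patch_size = 7
out_channels = patch_size * patch_size
w = np.eye(out_channels).reshape((patch_size, patch_size, 1, out_channels))
w = torch.tensor(np.transpose(w, (3, 2, 0, 1))).float()  # (49, 1, 7, 7) one-hot kernels

gray = torch.rand(1, 1, 32, 32)  # a grayscale image, as produced by rgb2gray
patches = F.conv2d(gray, w, padding=patch_size // 2)
print(patches.shape)  # torch.Size([1, 49, 32, 32]) -- one channel per neighbor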

warplayer.py

@@ -4,14 +4,19 @@ import torch.nn as nn
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 backwarp_tenGrid = {}

+
 def warp(tenInput, tenFlow):
     k = (str(tenFlow.device), str(tenFlow.size()))
     if k not in backwarp_tenGrid:
-        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
-        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
-        backwarp_tenGrid[k] = torch.cat([ tenHorizontal, tenVertical ], 1).to(device)
+        tenHorizontal = torch.linspace(-1.0, 1.0, tenFlow.shape[3]).view(
+            1, 1, 1, tenFlow.shape[3]).expand(tenFlow.shape[0], -1, tenFlow.shape[2], -1)
+        tenVertical = torch.linspace(-1.0, 1.0, tenFlow.shape[2]).view(
+            1, 1, tenFlow.shape[2], 1).expand(tenFlow.shape[0], -1, -1, tenFlow.shape[3])
+        backwarp_tenGrid[k] = torch.cat(
+            [tenHorizontal, tenVertical], 1).to(device)
-    tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
+    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
+                         tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0)], 1)
     g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
     return torch.nn.functional.grid_sample(input=tenInput, grid=torch.clamp(g, -1, 1), mode='bilinear', padding_mode='zeros', align_corners=True)
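warp builds a pixel grid normalized to [-1, 1] (cached per device and size), shifts it by the flow rescaled into the same units, and backward-samples with grid_sample. A quick property check, assuming warplayer.py is on the import path: with zero flow the shifted grid is the identity grid, so the warp should return its input up to float rounding, since align_corners=True puts the grid endpoints exactly on pixel centers.

import torch
from warplayer import warp  # the module patched above

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.rand(1, 3, 64, 64, device=dev)
zero_flow = torch.zeros(1, 2, 64, 64, device=dev)
print(torch.allclose(warp(x, zero_flow), x, atol=1e-6))  # expected: True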