Support old Pytorch version

This commit is contained in:
hzwer
2020-11-17 10:26:11 +08:00
parent e0eb86d56b
commit 382904d0b7
2 changed files with 8 additions and 8 deletions

View File

@@ -79,7 +79,7 @@ class IFBlock(nn.Module):
flow = self.up(x)
if self.scale != 1:
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
align_corners=False, recompute_scale_factor=False)
align_corners=False)
return flow
@@ -92,7 +92,7 @@ class IFNet(nn.Module):
def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False)
align_corners=False)
flow0 = self.block0(x)
F1 = flow0
warped_img0 = warp(x[:, :3], F1)

View File

@@ -74,15 +74,15 @@ class ContextNet(nn.Module):
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5
align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5
align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5
align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
@@ -187,7 +187,7 @@ class Model:
c0 = self.contextnet(img0, flow)
c1 = self.contextnet(img1, -flow)
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 2.0
align_corners=False) * 2.0
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
img0, img1, flow, c0, c1, flow_gt)
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
@@ -222,9 +222,9 @@ class Model:
loss_mask = torch.abs(
merged_img - gt).sum(1, True).float().detach()
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False).detach()
align_corners=False).detach()
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5).detach()
align_corners=False) * 0.5).detach()
loss_cons = 0
for i in range(3):
loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)