Remove recompute_scale_factor

This commit is contained in:
hzwer
2020-11-18 13:55:27 +08:00
parent d45cdea9b1
commit b5a2edb04c
3 changed files with 8 additions and 8 deletions

View File

@@ -67,7 +67,7 @@ class IFBlock(nn.Module):
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
align_corners=False, recompute_scale_factor=False)
align_corners=False)
x = self.conv0(x)
x = self.res0(x)
x = self.res1(x)

View File

@@ -67,7 +67,7 @@ class IFBlock(nn.Module):
def forward(self, x):
if self.scale != 1:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear",
align_corners=False, recompute_scale_factor=False)
align_corners=False)
x = self.conv0(x)
x = self.res0(x)
x = self.res1(x)
@@ -79,7 +79,7 @@ class IFBlock(nn.Module):
flow = x # self.up(x)
if self.scale != 1:
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
align_corners=False, recompute_scale_factor=False)
align_corners=False)
return flow
@@ -92,7 +92,7 @@ class IFNet(nn.Module):
def forward(self, x):
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False)
align_corners=False)
flow0 = self.block0(x)
F1 = flow0
warped_img0 = warp(x[:, :3], F1)

View File

@@ -74,15 +74,15 @@ class ContextNet(nn.Module):
f1 = warp(x, flow)
x = self.conv2(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5
align_corners=False) * 0.5
f2 = warp(x, flow)
x = self.conv3(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5
align_corners=False) * 0.5
f3 = warp(x, flow)
x = self.conv4(x)
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 0.5
align_corners=False) * 0.5
f4 = warp(x, flow)
return [f1, f2, f3, f4]
@@ -185,7 +185,7 @@ class Model:
img0 = imgs[:, :3]
img1 = imgs[:, 3:]
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
align_corners=False, recompute_scale_factor=False) * 2.0
align_corners=False) * 2.0
c0 = self.contextnet(img0, flow)
c1 = self.contextnet(img1, -flow)
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(