mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-23 20:09:40 +01:00
Support old Pytorch version
This commit is contained in:
@@ -79,7 +79,7 @@ class IFBlock(nn.Module):
|
||||
flow = self.up(x)
|
||||
if self.scale != 1:
|
||||
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False)
|
||||
align_corners=False)
|
||||
return flow
|
||||
|
||||
|
||||
@@ -92,7 +92,7 @@ class IFNet(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False)
|
||||
align_corners=False)
|
||||
flow0 = self.block0(x)
|
||||
F1 = flow0
|
||||
warped_img0 = warp(x[:, :3], F1)
|
||||
|
||||
@@ -74,15 +74,15 @@ class ContextNet(nn.Module):
|
||||
f1 = warp(x, flow)
|
||||
x = self.conv2(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False) * 0.5
|
||||
align_corners=False) * 0.5
|
||||
f2 = warp(x, flow)
|
||||
x = self.conv3(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False) * 0.5
|
||||
align_corners=False) * 0.5
|
||||
f3 = warp(x, flow)
|
||||
x = self.conv4(x)
|
||||
flow = F.interpolate(flow, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False) * 0.5
|
||||
align_corners=False) * 0.5
|
||||
f4 = warp(x, flow)
|
||||
return [f1, f2, f3, f4]
|
||||
|
||||
@@ -187,7 +187,7 @@ class Model:
|
||||
c0 = self.contextnet(img0, flow)
|
||||
c1 = self.contextnet(img1, -flow)
|
||||
flow = F.interpolate(flow, scale_factor=2.0, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False) * 2.0
|
||||
align_corners=False) * 2.0
|
||||
refine_output, warped_img0, warped_img1, warped_img0_gt, warped_img1_gt = self.fusionnet(
|
||||
img0, img1, flow, c0, c1, flow_gt)
|
||||
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
|
||||
@@ -222,9 +222,9 @@ class Model:
|
||||
loss_mask = torch.abs(
|
||||
merged_img - gt).sum(1, True).float().detach()
|
||||
loss_mask = F.interpolate(loss_mask, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False).detach()
|
||||
align_corners=False).detach()
|
||||
flow_gt = (F.interpolate(flow_gt, scale_factor=0.5, mode="bilinear",
|
||||
align_corners=False, recompute_scale_factor=False) * 0.5).detach()
|
||||
align_corners=False) * 0.5).detach()
|
||||
loss_cons = 0
|
||||
for i in range(3):
|
||||
loss_cons += self.epe(flow_list[i], flow_gt[:, :2], 1)
|
||||
|
||||
Reference in New Issue
Block a user