diff --git a/model/IFNet_HDv2.py b/model/IFNet_HDv2.py index c8243a8..67658f9 100644 --- a/model/IFNet_HDv2.py +++ b/model/IFNet_HDv2.py @@ -39,15 +39,16 @@ class IFBlock(nn.Module): ) self.conv1 = nn.ConvTranspose2d(2*c, 4, 4, 2, 1) - def forward(self, x): - if self.scale != 1: + def forward(self, x, scale=1.0): + scale = self.scale / scale + if scale != 1.0: x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", align_corners=False) x = self.conv0(x) x = self.convblock(x) x = self.conv1(x) flow = x - if self.scale != 1: + if scale != 1.0: flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", align_corners=False) return flow @@ -62,27 +63,23 @@ class IFNet(nn.Module): self.block3 = IFBlock(10, scale=1, c=48) def forward(self, x, scale=1.0): - if scale != 1.0: - x = F.interpolate(x, scale_factor=scale, mode="bilinear", align_corners=False) - flow0 = self.block0(x) + flow0 = self.block0(x, scale) F1 = flow0 F1_large = F.interpolate(F1, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 warped_img0 = warp(x[:, :3], F1_large[:, :2]) warped_img1 = warp(x[:, 3:], F1_large[:, 2:4]) - flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1)) + flow1 = self.block1(torch.cat((warped_img0, warped_img1, F1_large), 1), scale) F2 = (flow0 + flow1) F2_large = F.interpolate(F2, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 warped_img0 = warp(x[:, :3], F2_large[:, :2]) warped_img1 = warp(x[:, 3:], F2_large[:, 2:4]) - flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1)) + flow2 = self.block2(torch.cat((warped_img0, warped_img1, F2_large), 1), scale) F3 = (flow0 + flow1 + flow2) F3_large = F.interpolate(F3, scale_factor=2.0, mode="bilinear", align_corners=False, recompute_scale_factor=False) * 2.0 warped_img0 = warp(x[:, :3], F3_large[:, :2]) warped_img1 = warp(x[:, 3:], F3_large[:, 2:4]) - flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1)) + flow3 = self.block3(torch.cat((warped_img0, warped_img1, F3_large), 1), scale) F4 = (flow0 + flow1 + flow2 + flow3) - if scale != 1.0: - F4 = F.interpolate(F4, scale_factor=1 / scale, mode="bilinear", align_corners=False) / scale return F4, [F1, F2, F3, F4] if __name__ == '__main__':