diff --git a/model/IFNet_HDv2.py b/model/IFNet_HDv2.py index 67658f9..f8f7de3 100644 --- a/model/IFNet_HDv2.py +++ b/model/IFNet_HDv2.py @@ -42,14 +42,14 @@ class IFBlock(nn.Module): def forward(self, x, scale=1.0): scale = self.scale / scale if scale != 1.0: - x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", + x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear", align_corners=False) x = self.conv0(x) x = self.convblock(x) x = self.conv1(x) flow = x if scale != 1.0: - flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", + flow = F.interpolate(flow, scale_factor=scale, mode="bilinear", align_corners=False) return flow