Fix scale mode

This commit is contained in:
hzwer
2021-02-27 19:54:01 +08:00
parent f75bbacbd1
commit e59b084b7d

View File

@@ -42,14 +42,14 @@ class IFBlock(nn.Module):
def forward(self, x, scale=1.0): def forward(self, x, scale=1.0):
scale = self.scale / scale scale = self.scale / scale
if scale != 1.0: if scale != 1.0:
x = F.interpolate(x, scale_factor=1. / self.scale, mode="bilinear", x = F.interpolate(x, scale_factor=1. / scale, mode="bilinear",
align_corners=False) align_corners=False)
x = self.conv0(x) x = self.conv0(x)
x = self.convblock(x) x = self.convblock(x)
x = self.conv1(x) x = self.conv1(x)
flow = x flow = x
if scale != 1.0: if scale != 1.0:
flow = F.interpolate(flow, scale_factor=self.scale, mode="bilinear", flow = F.interpolate(flow, scale_factor=scale, mode="bilinear",
align_corners=False) align_corners=False)
return flow return flow