Update IFNet_m.py

This commit is contained in:
hzwer
2022-08-16 15:09:01 +08:00
committed by GitHub
parent c6d5d70289
commit 1ea9b9aefe

View File

@@ -60,7 +60,7 @@ class IFNet_m(nn.Module):
self.contextnet = Contextnet()
self.unet = Unet()
def forward(self, x, scale=[4,2,1], timestep=0.5):
def forward(self, x, scale=[4,2,1], timestep=0.5, returnflow=False):
timestep = (x[:, :1].clone() * 0 + 1) * timestep
img0 = x[:, :3]
img1 = x[:, 3:6]
@@ -101,6 +101,9 @@ class IFNet_m(nn.Module):
if gt.shape[1] == 3:
loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach()
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
if returnflow:
return flow
else:
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)