From 1ea9b9aefeaa0b4e622d696394ff763b51e80e2c Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Tue, 16 Aug 2022 15:09:01 +0800 Subject: [PATCH] Update IFNet_m.py --- model/IFNet_m.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/model/IFNet_m.py b/model/IFNet_m.py index 85eea88..9997b3f 100644 --- a/model/IFNet_m.py +++ b/model/IFNet_m.py @@ -60,7 +60,7 @@ class IFNet_m(nn.Module): self.contextnet = Contextnet() self.unet = Unet() - def forward(self, x, scale=[4,2,1], timestep=0.5): + def forward(self, x, scale=[4,2,1], timestep=0.5, returnflow=False): timestep = (x[:, :1].clone() * 0 + 1) * timestep img0 = x[:, :3] img1 = x[:, 3:6] @@ -101,9 +101,12 @@ class IFNet_m(nn.Module): if gt.shape[1] == 3: loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach() loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean() - c0 = self.contextnet(img0, flow[:, :2]) - c1 = self.contextnet(img1, flow[:, 2:4]) - tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) - res = tmp[:, :3] * 2 - 1 - merged[2] = torch.clamp(merged[2] + res, 0, 1) + if returnflow: + return flow + else: + c0 = self.contextnet(img0, flow[:, :2]) + c1 = self.contextnet(img1, flow[:, 2:4]) + tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) + res = tmp[:, :3] * 2 - 1 + merged[2] = torch.clamp(merged[2] + res, 0, 1) return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill