This commit is contained in:
hzwer
2022-03-18 10:48:07 +08:00
parent f114857833
commit d0a5d1fb16
2 changed files with 2 additions and 2 deletions

View File

@@ -99,7 +99,7 @@ class IFNet(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach()
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5) * loss_mask).mean()
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)

View File

@@ -100,7 +100,7 @@ class IFNet_m(nn.Module):
merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i])
if gt.shape[1] == 3:
loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach()
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5) * loss_mask).mean()
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
c0 = self.contextnet(img0, flow[:, :2])
c1 = self.contextnet(img1, flow[:, 2:4])
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)