diff --git a/model/IFNet.py b/model/IFNet.py index eff6c21..6ba503d 100644 --- a/model/IFNet.py +++ b/model/IFNet.py @@ -99,7 +99,7 @@ class IFNet(nn.Module): merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) if gt.shape[1] == 3: loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach() - loss_distill += ((flow_teacher.detach() - flow_list[i]).abs() * loss_mask).mean() + loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5) * loss_mask).mean() c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) diff --git a/model/IFNet_m.py b/model/IFNet_m.py index d42e32c..29bed16 100644 --- a/model/IFNet_m.py +++ b/model/IFNet_m.py @@ -100,7 +100,7 @@ class IFNet_m(nn.Module): merged[i] = merged[i][0] * mask_list[i] + merged[i][1] * (1 - mask_list[i]) if gt.shape[1] == 3: loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach() - loss_distill += ((flow_teacher.detach() - flow_list[i]).abs() * loss_mask).mean() + loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5) * loss_mask).mean() c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1) diff --git a/train.py b/train.py index 93a4dc4..5f67dda 100644 --- a/train.py +++ b/train.py @@ -21,10 +21,10 @@ log_path = 'train_log' def get_learning_rate(step): if step < 2000: mul = step / 2000. - return 3e-4 * mul + return 1e-4 * mul else: mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5 - return (3e-4 - 3e-5) * mul + 3e-5 + return (1e-4 - 1e-5) * mul + 1e-5 def flow2rgb(flow_map_np): h, w, _ = flow_map_np.shape