mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 16:37:51 +01:00
Update IFNet_m.py
This commit is contained in:
@@ -60,7 +60,7 @@ class IFNet_m(nn.Module):
|
|||||||
self.contextnet = Contextnet()
|
self.contextnet = Contextnet()
|
||||||
self.unet = Unet()
|
self.unet = Unet()
|
||||||
|
|
||||||
def forward(self, x, scale=[4,2,1], timestep=0.5):
|
def forward(self, x, scale=[4,2,1], timestep=0.5, returnflow=False):
|
||||||
timestep = (x[:, :1].clone() * 0 + 1) * timestep
|
timestep = (x[:, :1].clone() * 0 + 1) * timestep
|
||||||
img0 = x[:, :3]
|
img0 = x[:, :3]
|
||||||
img1 = x[:, 3:6]
|
img1 = x[:, 3:6]
|
||||||
@@ -101,9 +101,12 @@ class IFNet_m(nn.Module):
|
|||||||
if gt.shape[1] == 3:
|
if gt.shape[1] == 3:
|
||||||
loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach()
|
loss_mask = ((merged[i] - gt).abs().mean(1, True) > (merged_teacher - gt).abs().mean(1, True) + 0.01).float().detach()
|
||||||
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
|
loss_distill += (((flow_teacher.detach() - flow_list[i]) ** 2).mean(1, True) ** 0.5 * loss_mask).mean()
|
||||||
c0 = self.contextnet(img0, flow[:, :2])
|
if returnflow:
|
||||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
return flow
|
||||||
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
else:
|
||||||
res = tmp[:, :3] * 2 - 1
|
c0 = self.contextnet(img0, flow[:, :2])
|
||||||
merged[2] = torch.clamp(merged[2] + res, 0, 1)
|
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||||
|
tmp = self.unet(img0, img1, warped_img0, warped_img1, mask, flow, c0, c1)
|
||||||
|
res = tmp[:, :3] * 2 - 1
|
||||||
|
merged[2] = torch.clamp(merged[2] + res, 0, 1)
|
||||||
return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill
|
return flow_list, mask_list[2], merged, flow_teacher, merged_teacher, loss_distill
|
||||||
|
|||||||
Reference in New Issue
Block a user