mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 08:27:45 +01:00
Fix warmup
This commit is contained in:
@@ -50,22 +50,6 @@ class Model:
|
||||
if rank == 0:
|
||||
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
||||
|
||||
'''
|
||||
def predict(self, imgs, flow, merged, training=True, flow_gt=None):
|
||||
img0 = imgs[:, :3]
|
||||
img1 = imgs[:, 3:]
|
||||
c0 = self.contextnet(img0, flow[:, :2])
|
||||
c1 = self.contextnet(img1, flow[:, 2:4])
|
||||
refine_output = self.unet(img0, img1, flow, merged, c0, c1, flow_gt)
|
||||
res = torch.sigmoid(refine_output[:, :3]) * 2 - 1
|
||||
pred = merged + res
|
||||
pred = torch.clamp(pred, 0, 1)
|
||||
if training:
|
||||
return pred, merged
|
||||
else:
|
||||
return pred
|
||||
'''
|
||||
|
||||
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False):
|
||||
imgs = torch.cat((img0, img1), 1)
|
||||
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list)
|
||||
|
||||
Reference in New Issue
Block a user