mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 16:37:51 +01:00
Update RIFE.py
This commit is contained in:
@@ -53,7 +53,9 @@ class Model:
|
|||||||
if rank == 0:
|
if rank == 0:
|
||||||
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
|
||||||
|
|
||||||
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5):
|
def inference(self, img0, img1, scale=1, scale_list=[4, 2, 1], TTA=False, timestep=0.5):
|
||||||
|
for i in range(3):
|
||||||
|
scale_list[i] = scale_list[i] * 1.0 / scale
|
||||||
imgs = torch.cat((img0, img1), 1)
|
imgs = torch.cat((img0, img1), 1)
|
||||||
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep)
|
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep)
|
||||||
if TTA == False:
|
if TTA == False:
|
||||||
|
|||||||
Reference in New Issue
Block a user