Update RIFE.py

This commit is contained in:
hzwer
2022-07-21 14:25:24 +08:00
committed by GitHub
parent 9c75be5ce6
commit 9e10b63db3

View File

@@ -53,7 +53,9 @@ class Model:
if rank == 0: if rank == 0:
torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path))
def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): def inference(self, img0, img1, scale=1, scale_list=[4, 2, 1], TTA=False, timestep=0.5):
for i in range(3):
scale_list[i] = scale_list[i] * 1.0 / scale
imgs = torch.cat((img0, img1), 1) imgs = torch.cat((img0, img1), 1)
flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep) flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep)
if TTA == False: if TTA == False: