diff --git a/model/RIFE.py b/model/RIFE.py index d3d7e6f..6abf090 100644 --- a/model/RIFE.py +++ b/model/RIFE.py @@ -53,7 +53,9 @@ class Model: if rank == 0: torch.save(self.flownet.state_dict(),'{}/flownet.pkl'.format(path)) - def inference(self, img0, img1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): + def inference(self, img0, img1, scale=1, scale_list=[4, 2, 1], TTA=False, timestep=0.5): + for i in range(3): + scale_list[i] = scale_list[i] * 1.0 / scale imgs = torch.cat((img0, img1), 1) flow, mask, merged, flow_teacher, merged_teacher, loss_distill = self.flownet(imgs, scale_list, timestep=timestep) if TTA == False: