diff --git a/inference_mp4_2x.py b/inference_mp4_2x.py index b66576a..fb3b1ad 100644 --- a/inference_mp4_2x.py +++ b/inference_mp4_2x.py @@ -36,8 +36,8 @@ while success: I1 = torch.from_numpy(np.transpose(frame, (2,0,1)).astype("float32") / 255.).to(device).unsqueeze(0) I0 = F.pad(I0, padding) I1 = F.pad(I1, padding) - if (F.interpolate(I0, (16, 16), mode='bilinear') - - F.interpolate(I1, (16, 16), mode='bilinear')).abs().mean() > 0.2: + if (F.interpolate(I0, (16, 16), mode='bilinear', recompute_scale_factor=False) + - F.interpolate(I1, (16, 16), mode='bilinear', recompute_scale_factor=False)).abs().mean() > 0.2: mid = lastframe else: mid = model.inference(I0, I1)