diff --git a/benchmark/Vimeo90K.py b/benchmark/Vimeo90K.py index da7d0fe..c984b4c 100644 --- a/benchmark/Vimeo90K.py +++ b/benchmark/Vimeo90K.py @@ -8,6 +8,7 @@ import argparse import numpy as np from torch.nn import functional as F from pytorch_msssim import ssim_matlab +# from model.RIFE2F15C import Model from model.RIFE import Model device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -31,8 +32,8 @@ for i in f: I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) I2 = (torch.tensor(I2.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0) mid = model.inference(I0, I2)[0] - ssim = ssim_matlab(torch.tensor(I1.transpose(2, 0, 1)).to(device).unsqueeze(0) / 255., mid.unsqueeze(0)).cpu().numpy() - mid = np.round((mid * 255).cpu().numpy()).astype('uint8').transpose(1, 2, 0) / 255. + ssim = ssim_matlab(torch.tensor(I1.transpose(2, 0, 1)).to(device).unsqueeze(0) / 255., torch.round(mid * 255).unsqueeze(0) / 255.).detach().cpu().numpy() + mid = np.round((mid * 255).detach().cpu().numpy()).astype('uint8').transpose(1, 2, 0) / 255. I1 = I1 / 255. psnr = -10 * math.log10(((I1 - mid) * (I1 - mid)).mean()) psnr_list.append(psnr)