diff --git a/Vimeo90K_benchmark.py b/Vimeo90K_benchmark.py index 2b00706..d5c9d54 100644 --- a/Vimeo90K_benchmark.py +++ b/Vimeo90K_benchmark.py @@ -33,7 +33,7 @@ for i in f: mid = np.round((mid * 255).cpu().numpy()).astype('uint8').transpose(1, 2, 0) / 255. I1 = I1 / 255. psnr = -10 * math.log10(((I1 - mid) * (I1 - mid)).mean()) - ssim = ssim_matlab(torch.tensor(I1).unsqueeze(0).float(), torch.tensor(mid).unsqueeze(0).float()) + ssim = ssim_matlab(torch.tensor(I1).unsqueeze(0).float().to(device), torch.tensor(mid).unsqueeze(0).float().to(device)).cpu().numpy() psnr_list.append(psnr) ssim_list.append(ssim) print(np.mean(psnr_list), np.mean(ssim_list))