Add Vimeo90K benchmark

This commit is contained in:
hzwer
2020-11-14 16:56:09 +08:00
parent 4f7c00f589
commit ab91f19a49

34
Vimeo90K_benchmark.py Normal file
View File

@@ -0,0 +1,34 @@
import os
import cv2
import math
import torch
import argparse
import numpy as np
from torch.nn import functional as F
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.eval()
model.device()
path = 'vimeo_interp_test/'
f = open(path + 'tri_testlist.txt', 'r')
psnr_list = []
for i in f:
name = str(i).strip()
if(len(name) <= 1):
continue
print(path + 'target/' + name + '/im1.png')
I0 = cv2.imread(path + 'target/' + name + '/im1.png')
I1 = cv2.imread(path + 'target/' + name + '/im2.png')
I2 = cv2.imread(path + 'target/' + name + '/im3.png')
I0 = (torch.tensor(I0.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
I2 = (torch.tensor(I2.transpose(2, 0, 1)).to(device) / 255.).unsqueeze(0)
mid = model.inference(I0, I2)[0]
I1 = torch.tensor(I1.transpose(2, 0, 1)).to(device) / 255.
psnr = -10 * math.log10(torch.mean((I1 - mid) * (I1 - mid)).cpu().data)
psnr_list.append(psnr)
print(np.mean(psnr_list))