# Mirror of https://github.com/hzwer/ECCV2022-RIFE.git
# (synced 2025-12-16 08:27:45 +01:00; 38 lines, 1.3 KiB, Python)
"""Report the mean interpolation error of the RIFE model on a fixed image set.

For each sequence listed in ``name`` the script reads
``other-data/<seq>/frame10.png`` and ``other-data/<seq>/frame11.png``, asks
the model for the in-between frame, and compares it with
``other-gt-interp/<seq>/frame10i11.png``.  The per-sequence score is the mean
absolute difference on the 0-255 scale; the average over all sequences is
printed at the end.

Run from the repository root so that the ``model`` package, the ``train_log``
checkpoint directory and the two data directories resolve.
"""
import os
import sys
sys.path.append('.')  # must precede the `model.*` imports below
import cv2
import math
import torch
import argparse
import numpy as np
from torch.nn import functional as F
from model.pytorch_msssim import ssim_matlab
from model.RIFE import Model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Model()
model.load_model('train_log')
model.eval()
# NOTE(review): project helper; presumably moves the network onto `device`
# -- confirm in model/RIFE.py.
model.device()

# Both input frames are copied into the top-left corner of a zero canvas of
# this fixed size before being handed to the model.
PAD_H, PAD_W = 480, 640


def _read_image(path):
    """Load an image with OpenCV, failing loudly if it cannot be read.

    ``cv2.imread`` returns ``None`` instead of raising when the file is
    missing or undecodable; without this check the failure would surface
    later as an opaque ``AttributeError`` on ``.transpose``.

    Raises:
        FileNotFoundError: if OpenCV could not load ``path``.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('cannot read image: {}'.format(path))
    return img


name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
IE_list = []
for seq in name:
    # HWC uint8 -> CHW float in [0, 1].
    i0 = _read_image('other-data/{}/frame10.png'.format(seq)).transpose(2, 0, 1) / 255.
    i1 = _read_image('other-data/{}/frame11.png'.format(seq)).transpose(2, 0, 1) / 255.
    gt = _read_image('other-gt-interp/{}/frame10i11.png'.format(seq))

    h, w = i0.shape[1], i0.shape[2]
    if h > PAD_H or w > PAD_W:
        raise ValueError(
            '{}: frame size {}x{} exceeds the {}x{} canvas'.format(
                seq, h, w, PAD_H, PAD_W))

    # Channels 0-2 hold the first frame, channels 3-5 the second one.
    imgs = torch.zeros([1, 6, PAD_H, PAD_W]).to(device)
    imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device)
    imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device)
    I0 = imgs[:, :3]
    I2 = imgs[:, 3:]

    # Pure evaluation: no autograd graph is needed.
    with torch.no_grad():
        pred = model.inference(I0, I2)

    # Back to HWC, drop the padding, and rescale to the 0-255 range of `gt`.
    out = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
    out = np.round(out[:h, :w] * 255)
    # `gt * 1.0` promotes the uint8 ground truth to float before subtracting.
    IE_list.append(np.abs((out - gt * 1.0)).mean())

# Average of the per-sequence errors.
print(np.mean(IE_list))