mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 00:17:46 +01:00
106 lines
3.7 KiB
Python
import os
|
|
import sys
|
|
sys.path.append('.')
|
|
import cv2
|
|
import math
|
|
import torch
|
|
import argparse
|
|
import numpy as np
|
|
from torch.nn import functional as F
|
|
from model.pytorch_msssim import ssim_matlab
|
|
from model.RIFE import Model
|
|
from skimage.color import rgb2yuv, yuv2rgb
|
|
from yuv_frame_io import YUV_Read,YUV_Write
|
|
# Pick the compute device: the GPU when CUDA is visible, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the arbitrary-timestep RIFE model and restore its weights from
# 'RIFE_m_train_log' (presumably a checkpoint directory — see model.RIFE).
model = Model(arbitrary=True)
model.load_model('RIFE_m_train_log')
# Inference only: switch to evaluation mode.
model.eval()
# NOTE(review): Model.device() is defined in model.RIFE; it presumably moves
# the networks onto that module's compute device — confirm there.
model.device()
|
|
|
|
# Benchmark sequences as (path, height, width), grouped by resolution.
name_list = (
    # 1280x720 sequences
    [
        ('HD_dataset/HD720p_GT/parkrun_1280x720_50.yuv', 720, 1280),
        ('HD_dataset/HD720p_GT/shields_1280x720_60.yuv', 720, 1280),
        ('HD_dataset/HD720p_GT/stockholm_1280x720_60.yuv', 720, 1280),
    ]
    # 1920x1080 sequences
    + [
        ('HD_dataset/HD1080p_GT/BlueSky.yuv', 1080, 1920),
        ('HD_dataset/HD1080p_GT/Kimono1_1920x1080_24.yuv', 1080, 1920),
        ('HD_dataset/HD1080p_GT/ParkScene_1920x1080_24.yuv', 1080, 1920),
        ('HD_dataset/HD1080p_GT/sunflower_1080p25.yuv', 1080, 1920),
    ]
    # 1280x544 Sintel sequences
    + [
        ('HD_dataset/HD544p_GT/Sintel_Alley2_1280x544.yuv', 544, 1280),
        ('HD_dataset/HD544p_GT/Sintel_Market5_1280x544.yuv', 544, 1280),
        ('HD_dataset/HD544p_GT/Sintel_Temple1_1280x544.yuv', 544, 1280),
        ('HD_dataset/HD544p_GT/Sintel_Temple2_1280x544.yuv', 544, 1280),
    ]
)
|
|
def inference(I0, I1, pad, multi=2, arbitrary=True, net=None):
    """Interpolate the 2**multi - 1 intermediate frames between I0 and I1.

    Args:
        I0, I1: the two input frames. The code takes ``frame[0]`` (dropping the
            leading axis) and then crops rows on axis 1 of the result, so they
            are presumably (1, C, H, W) tensors already padded by ``pad`` rows
            at top and bottom — TODO confirm against callers.
        pad: rows to crop from the top and bottom of every returned frame.
            ``0`` means no cropping (a plain ``[pad:-pad]`` slice would return
            an empty tensor in that case).
        multi: exponent; ``2**multi - 1`` intermediate frames are produced.
        arbitrary: if True, query the model once per timestep ``k / 2**multi``.
            If False, repeatedly insert a midpoint between every adjacent pair,
            ``multi`` times.
        net: object providing ``inference(a, b[, timestep=...])``. Defaults to
            the module-level ``model``.

    Returns:
        List of the cropped intermediate frames (the two endpoints excluded),
        each of shape ``frame[0]`` minus ``2 * pad`` rows on axis 1.
    """
    if net is None:
        net = model
    if arbitrary:
        steps = 2 ** multi
        frames = [I0]
        frames.extend(
            net.inference(I0, I1, timestep=(k + 1) * (1. / steps))
            for k in range(steps - 1)
        )
        frames.append(I1)
    else:
        frames = [I0, I1]
        for _ in range(multi):
            refined = [I0]
            # Insert a model-predicted midpoint between every adjacent pair.
            for left, right in zip(frames[:-1], frames[1:]):
                refined.append(net.inference(left, right))
                refined.append(right)
            frames = refined
    # `-0` as a slice end would give an empty slice, so use None when pad == 0.
    stop = -pad if pad else None
    cropped = [frame[0][:, pad:stop] for frame in frames]
    return cropped[1:-1]
|
|
|
|
def _pad_rows(h):
    """Return the rows of replication padding added above and below a frame of height ``h``.

    Hard-coded per benchmark resolution: 720 -> 24, 1080 -> 4, anything else -> 16.
    """
    if h == 720:
        return 24
    if h == 1080:
        return 4
    return 16


def _to_input(image):
    """Convert an HWC image array into a (1, C, H, W) float32 tensor scaled by 1/255 on ``device``."""
    chw = np.transpose(image, (2, 0, 1)).astype("float32") / 255.
    # Use the selected device instead of a hard-coded .cuda() so the script
    # also runs when no GPU is available.
    return torch.from_numpy(chw).to(device).unsqueeze(0)


def _to_uint8(frame):
    """Convert a (C, H, W) tensor (expected in [0, 1]) back to an HWC uint8 array.

    Values are clipped to [0, 255] before the cast so out-of-range model outputs
    saturate instead of wrapping around in uint8.
    """
    scaled = frame.detach().cpu().numpy().transpose(1, 2, 0) * 255
    return np.clip(np.round(scaled), 0, 255).astype('uint8')


def _psnr_y(gt, out):
    """Return the PSNR in dB between two RGB frames, measured on the Y channel of ``rgb2yuv``.

    Both frames are scaled by 1/255 before conversion and the luma is rescaled
    by 255, so the peak value is 255. Identical frames give ``inf`` instead of
    dividing by zero.
    """
    y_gt = rgb2yuv(gt / 255.)[:, :, 0] * 255
    y_out = rgb2yuv(out / 255.)[:, :, 0] * 255
    mse = np.mean((y_gt - y_out) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


tot = []
for name, h, w in name_list:
    if 'yuv' not in name:
        # Previously this only printed 'Not Implement' and then silently reused
        # the frames/PSNR left over from the preceding sequence.
        raise NotImplementedError('Not Implement: {}'.format(name))
    reader = YUV_Read(name, h, w, toRGB=True)
    # Padding depends only on the resolution, so build it once per sequence.
    pad = _pad_rows(h)
    pader = torch.nn.ReplicationPad2d([0, 0, pad, pad])
    psnr_list = []
    # Frames index and index + 4 are the inputs; the three frames in between
    # are the ground truth for the three interpolated outputs.
    for index in range(0, 100, 4):
        image1, _ = reader.read(index)
        image2, success = reader.read(index + 4)
        if not success:
            break
        gt = [reader.read(index + k)[0] for k in range(1, 4)]
        I0 = pader(_to_input(image1))
        I1 = pader(_to_input(image2))
        with torch.no_grad():
            pred = inference(I0, I1, pad)
        psnr_list.extend(_psnr_y(g, _to_uint8(p)) for g, p in zip(gt, pred))
    mean_psnr = np.mean(psnr_list)
    print(mean_psnr)
    tot.append(mean_psnr)


def _mean_for_height(height):
    """Average the per-sequence PSNRs in ``tot`` over the sequences with the given frame height."""
    return np.mean([score for (_, seq_h, _), score in zip(name_list, tot) if seq_h == height])


print('PSNR: {}(544*1280), {}(720p), {}(1080p)'.format(
    _mean_for_height(544), _mean_for_height(720), _mean_for_height(1080)))
|