diff --git a/README.md b/README.md
index 3be78ab..e42ffb2 100644
--- a/README.md
+++ b/README.md
@@ -90,11 +90,13 @@ First you should download [RIFE model reported by our paper](https://drive.googl
 
 We will release our training and benchmark validation code soon.
 
-**Vimeo90K**
-Download [Vimeo90K dataset](http://toflow.csail.mit.edu/) at ./vimeo_interp_test
+**Vimeo90K**: Download the [Vimeo90K dataset](http://toflow.csail.mit.edu/) to ./vimeo_interp_test
+**MiddleBury**: Download the [MiddleBury OTHER dataset](https://vision.middlebury.edu/flow/data/) to ./other-data and ./other-gt-interp
 ```
-$ python3 benchmark/Vimeo90K_benchmark.py
+$ python3 benchmark/Vimeo90K.py
 (Final result: "Avg PSNR: 35.695 SSIM: 0.9788")
+$ python3 benchmark/MiddleBury_Other.py
+(Final result: "2.058")
 ```
 
 ## Citation
diff --git a/benchmark/MiddleBury_Other.py b/benchmark/MiddleBury_Other.py
new file mode 100644
index 0000000..18b4e16
--- /dev/null
+++ b/benchmark/MiddleBury_Other.py
@@ -0,0 +1,40 @@
+import os
+import sys
+sys.path.append('.')
+import cv2
+import math
+import torch
+import argparse
+import numpy as np
+from torch.nn import functional as F
+from pytorch_msssim import ssim_matlab
+from model.RIFE import Model
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+model = Model()
+model.load_model('./train_log')
+model.eval()
+model.device()
+
+name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
+IE_list = []
+for i in name:
+    # Read the two input frames (CHW, [0, 1]) and the ground-truth middle frame (HWC, [0, 255]).
+    i0 = cv2.imread('other-data/{}/frame10.png'.format(i)).transpose(2, 0, 1) / 255.
+    i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255.
+    gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i))
+    h, w = i0.shape[1], i0.shape[2]
+    # Pad both frames into a fixed 480x640 canvas on the same device as the model.
+    imgs = torch.zeros([1, 6, 480, 640]).to(device)
+    ph = (480 - h) // 2
+    pw = (640 - w) // 2
+    imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float().to(device)
+    imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float().to(device)
+    I0 = imgs[:, :3]
+    I2 = imgs[:, 3:]
+    pred = model.inference(I0, I2)
+    out = pred[0].detach().cpu().numpy().transpose(1, 2, 0)
+    out = np.round(out[:h, :w] * 255)
+    # Interpolation Error (IE): mean absolute difference to the ground truth in 8-bit space.
+    IE_list.append(np.abs((out - gt * 1.0)).mean())
+    print(np.mean(IE_list))
diff --git a/benchmark/Vimeo90K_benchmark.py b/benchmark/Vimeo90K.py
similarity index 100%
rename from benchmark/Vimeo90K_benchmark.py
rename to benchmark/Vimeo90K.py
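
Note: the "2.058" printed by the new MiddleBury script is the Interpolation Error (IE), i.e. the mean absolute difference between the predicted and ground-truth middle frames in 8-bit space, averaged over the twelve OTHER sequences. A minimal sketch of that metric, assuming `pred` and `gt` are same-sized uint8 HWC arrays; the helper name `interpolation_error` is illustrative and not part of the repository:

```python
import numpy as np

def interpolation_error(pred: np.ndarray, gt: np.ndarray) -> float:
    """Mean absolute difference between two 8-bit frames (the IE metric)."""
    return float(np.abs(pred.astype(np.float64) - gt.astype(np.float64)).mean())

# Per-sequence IEs are averaged to produce the single reported number.
# ie_per_sequence = [interpolation_error(p, g) for p, g in zip(preds, gts)]
# print(np.mean(ie_per_sequence))
```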