Add MiddleBury

This commit is contained in:
hzwer
2020-11-26 19:11:17 +08:00
parent 8715002aab
commit cbc3e3b005
3 changed files with 42 additions and 3 deletions

View File

@@ -90,11 +90,13 @@ First you should download [RIFE model reported by our paper](https://drive.googl
We will release our training and benchmark validation code soon.
**Vimeo90K**
Download [Vimeo90K dataset](http://toflow.csail.mit.edu/) at ./vimeo_interp_test
**Vimeo90K**: Download [Vimeo90K dataset](http://toflow.csail.mit.edu/) at ./vimeo_interp_test
**MiddleBury**: Download [MiddleBury OTHER dataset](https://vision.middlebury.edu/flow/data/) at ./other-data and ./other-gt-interp
```
$ python3 benchmark/Vimeo90K_benchmark.py
$ python3 benchmark/Vimeo90K.py
(Final result: "Avg PSNR: 35.695 SSIM: 0.9788")
$ python3 benchmark/MiddelBury_Other.py
(Final result: "2.058")
```
## Citation

View File

@@ -0,0 +1,37 @@
import os
import sys
sys.path.append('.')
import cv2
import math
import torch
import argparse
import numpy as np
from torch.nn import functional as F
from pytorch_msssim import ssim_matlab
from model.RIFE import Model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Model()
model.load_model('./train_log')
model.eval()
model.device()
name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
IE_list = []
for i in name:
i0 = cv2.imread('other-data/{}/frame10.png'.format(i)).transpose(2, 0, 1) / 255.
i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255.
gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i))
h, w = i0.shape[1], i0.shape[2]
imgs = torch.zeros([1, 6, 480, 640])
ph = (480 - h) // 2
pw = (640 - w) // 2
imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float()
imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float()
I0 = imgs[:, :3]
I2 = imgs[:, 3:]
pred = model.inference(I0, I2)
out = pred[0].cpu().numpy().transpose(1, 2, 0)
out = np.round(out[:h, :w] * 255)
IE_list.append(np.abs((out - gt * 1.0)).mean())
print(np.mean(IE_list))