mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 16:37:51 +01:00
Add MiddleBury
This commit is contained in:
@@ -90,11 +90,13 @@ First you should download [RIFE model reported by our paper](https://drive.googl
|
|||||||
|
|
||||||
We will release our training and benchmark validation code soon.
|
We will release our training and benchmark validation code soon.
|
||||||
|
|
||||||
**Vimeo90K**
|
**Vimeo90K**: Download [Vimeo90K dataset](http://toflow.csail.mit.edu/) at ./vimeo_interp_test
|
||||||
Download [Vimeo90K dataset](http://toflow.csail.mit.edu/) at ./vimeo_interp_test
|
**MiddleBury**: Download [MiddleBury OTHER dataset](https://vision.middlebury.edu/flow/data/) at ./other-data and ./other-gt-interp
|
||||||
```
|
```
|
||||||
$ python3 benchmark/Vimeo90K_benchmark.py
|
$ python3 benchmark/Vimeo90K.py
|
||||||
(Final result: "Avg PSNR: 35.695 SSIM: 0.9788")
|
(Final result: "Avg PSNR: 35.695 SSIM: 0.9788")
|
||||||
|
$ python3 benchmark/MiddelBury_Other.py
|
||||||
|
(Final result: "2.058")
|
||||||
```
|
```
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|||||||
37
benchmark/MiddleBury_Other.py
Normal file
37
benchmark/MiddleBury_Other.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
sys.path.append('.')
|
||||||
|
import cv2
|
||||||
|
import math
|
||||||
|
import torch
|
||||||
|
import argparse
|
||||||
|
import numpy as np
|
||||||
|
from torch.nn import functional as F
|
||||||
|
from pytorch_msssim import ssim_matlab
|
||||||
|
from model.RIFE import Model
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
model = Model()
|
||||||
|
model.load_model('./train_log')
|
||||||
|
model.eval()
|
||||||
|
model.device()
|
||||||
|
|
||||||
|
name = ['Beanbags', 'Dimetrodon', 'DogDance', 'Grove2', 'Grove3', 'Hydrangea', 'MiniCooper', 'RubberWhale', 'Urban2', 'Urban3', 'Venus', 'Walking']
|
||||||
|
IE_list = []
|
||||||
|
for i in name:
|
||||||
|
i0 = cv2.imread('other-data/{}/frame10.png'.format(i)).transpose(2, 0, 1) / 255.
|
||||||
|
i1 = cv2.imread('other-data/{}/frame11.png'.format(i)).transpose(2, 0, 1) / 255.
|
||||||
|
gt = cv2.imread('other-gt-interp/{}/frame10i11.png'.format(i))
|
||||||
|
h, w = i0.shape[1], i0.shape[2]
|
||||||
|
imgs = torch.zeros([1, 6, 480, 640])
|
||||||
|
ph = (480 - h) // 2
|
||||||
|
pw = (640 - w) // 2
|
||||||
|
imgs[:, :3, :h, :w] = torch.from_numpy(i0).unsqueeze(0).float()
|
||||||
|
imgs[:, 3:, :h, :w] = torch.from_numpy(i1).unsqueeze(0).float()
|
||||||
|
I0 = imgs[:, :3]
|
||||||
|
I2 = imgs[:, 3:]
|
||||||
|
pred = model.inference(I0, I2)
|
||||||
|
out = pred[0].cpu().numpy().transpose(1, 2, 0)
|
||||||
|
out = np.round(out[:h, :w] * 255)
|
||||||
|
IE_list.append(np.abs((out - gt * 1.0)).mean())
|
||||||
|
print(np.mean(IE_list))
|
||||||
Reference in New Issue
Block a user