Files
ECCV2022-RIFE/dataset.py

75 lines
2.5 KiB
Python
Raw Normal View History

2020-11-23 18:46:04 +08:00
import cv2
import ast
import torch
import numpy as np
import random
from torch.utils.data import DataLoader, Dataset
# Restrict OpenCV to a single internal thread.
# NOTE(review): presumably to avoid CPU oversubscription when several
# DataLoader worker processes run at once - confirm against the training script.
cv2.setNumThreads(1)
# Use CUDA when PyTorch reports it is available, otherwise fall back to the CPU.
# This name is not referenced anywhere else in the visible part of this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class VimeoDataset(Dataset):
    """Frame-triplet dataset with random augmentation on the training split.

    Every entry of ``meta_data`` is indexed as a channel-first array whose
    first nine channels hold three 3-channel images: channels 0-2 are
    ``img0``, channels 3-5 are ``img1`` and channels 6-8 are ``gt``.
    NOTE(review): ``gt`` is presumably the ground-truth intermediate frame -
    confirm against the code that fills ``train_data`` / ``val_data``.
    """

    def __init__(self, dataset_name, batch_size=32):
        """Create the dataset.

        Args:
            dataset_name: ``'train'`` selects ``train_data`` and enables
                augmentation in ``__getitem__``; any other value selects
                ``val_data`` and disables augmentation.
            batch_size: Stored on the instance; not used by this class.
        """
        super().__init__()
        self.batch_size = batch_size
        self.path = './dataset/'
        self.dataset_name = dataset_name
        self.load_data()
        self.h = 256
        self.w = 448
        # (h, w, 2) array holding the (x, y) pixel coordinate of every
        # position. Nothing in this class reads it; kept for external users.
        xx = np.arange(0, self.w).reshape(1, -1).repeat(self.h, 0)
        yy = np.arange(0, self.h).reshape(-1, 1).repeat(self.w, 1)
        self.grid = np.stack((xx, yy), 2).copy()

    def __len__(self):
        """Return the number of samples in the active split."""
        return len(self.meta_data)

    def load_data(self):
        """Initialise the sample containers and select the active split.

        This version only creates empty lists; nothing is read from
        ``self.path`` here, so the lists must be filled by other code.
        """
        self.train_data = []
        self.flow_data = []
        self.val_data = []
        if self.dataset_name == 'train':
            self.meta_data = self.train_data
        else:
            self.meta_data = self.val_data
        self.nr_sample = len(self.meta_data)

    def aug(self, img0, gt, img1, h, w):
        """Cut the same random ``h`` x ``w`` window out of all three images.

        Args:
            img0, gt, img1: Height-width-channel arrays. The window position
                is derived from ``img0.shape`` only, so the three images are
                assumed to share their spatial size.
            h: Crop height.
            w: Crop width.

        Returns:
            The cropped ``(img0, gt, img1)`` tuple (NumPy views).

        Raises:
            ValueError: If the requested crop is larger than ``img0``.
        """
        ih, iw, _ = img0.shape
        if h > ih or w > iw:
            # np.random.randint would also raise ValueError here, but with an
            # opaque "low >= high" message; say what is actually wrong.
            raise ValueError(
                f"crop size ({h}, {w}) exceeds image size ({ih}, {iw})"
            )
        # Draw the row offset first, then the column offset, so that seeded
        # runs keep producing the same crops.
        top = np.random.randint(0, ih - h + 1)
        left = np.random.randint(0, iw - w + 1)
        rows = slice(top, top + h)
        cols = slice(left, left + w)
        return img0[rows, cols, :], gt[rows, cols, :], img1[rows, cols, :]

    def getimg(self, index):
        """Split sample ``index`` into three height-width-channel images.

        Returns:
            ``(img0, gt, img1)`` taken from channels 0-2, 6-8 and 3-5 of the
            stored array respectively.
        """
        data = self.meta_data[index]
        img0 = data[0:3].transpose(1, 2, 0)
        img1 = data[3:6].transpose(1, 2, 0)
        gt = data[6:9].transpose(1, 2, 0)
        return img0, gt, img1

    def __getitem__(self, index):
        """Return sample ``index`` as one channel-first tensor.

        For the ``'train'`` split the sample is cropped to 224 x 224 and then
        four augmentations are applied independently, each with probability
        0.5: channel-order reversal, vertical flip, horizontal flip, and
        swapping ``img0`` with ``img1``.

        Returns:
            ``torch.cat((img0, img1, gt), 0)``: the three images stacked
            along the channel axis.
        """
        img0, gt, img1 = self.getimg(index)
        if self.dataset_name == 'train':
            img0, gt, img1 = self.aug(img0, gt, img1, 224, 224)
            if random.uniform(0, 1) < 0.5:
                # Reverse the channel order.
                img0 = img0[:, :, ::-1]
                img1 = img1[:, :, ::-1]
                gt = gt[:, :, ::-1]
            if random.uniform(0, 1) < 0.5:
                # Flip top-to-bottom.
                img0 = img0[::-1]
                img1 = img1[::-1]
                gt = gt[::-1]
            if random.uniform(0, 1) < 0.5:
                # Flip left-to-right.
                img0 = img0[:, ::-1]
                img1 = img1[:, ::-1]
                gt = gt[:, ::-1]
            if random.uniform(0, 1) < 0.5:
                # Swap the two input frames; gt stays where it is.
                img0, img1 = img1, img0
        # .copy() materialises the (possibly negatively strided) views, which
        # torch.from_numpy cannot wrap directly.
        img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
        img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
        gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
        return torch.cat((img0, img1, gt), 0)