mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2025-12-16 08:27:45 +01:00
Update dataset.py
This commit is contained in:
42
dataset.py
42
dataset.py
@@ -35,9 +35,8 @@ class VimeoDataset(Dataset):
|
|||||||
self.meta_data = self.testlist
|
self.meta_data = self.testlist
|
||||||
else:
|
else:
|
||||||
self.meta_data = self.trainlist[cnt:]
|
self.meta_data = self.trainlist[cnt:]
|
||||||
|
|
||||||
|
def crop(self, img0, gt, img1, h, w):
|
||||||
def aug(self, img0, gt, img1, h, w):
|
|
||||||
ih, iw, _ = img0.shape
|
ih, iw, _ = img0.shape
|
||||||
x = np.random.randint(0, ih - h + 1)
|
x = np.random.randint(0, ih - h + 1)
|
||||||
y = np.random.randint(0, iw - w + 1)
|
y = np.random.randint(0, iw - w + 1)
|
||||||
@@ -54,12 +53,24 @@ class VimeoDataset(Dataset):
|
|||||||
img0 = cv2.imread(imgpaths[0])
|
img0 = cv2.imread(imgpaths[0])
|
||||||
gt = cv2.imread(imgpaths[1])
|
gt = cv2.imread(imgpaths[1])
|
||||||
img1 = cv2.imread(imgpaths[2])
|
img1 = cv2.imread(imgpaths[2])
|
||||||
return img0, gt, img1
|
timestep = 0.5
|
||||||
|
return img0, gt, img1, timestep
|
||||||
|
|
||||||
|
# RIFEm with Vimeo-Septuplet
|
||||||
|
# imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png']
|
||||||
|
# ind = [0, 1, 2, 3, 4, 5, 6]
|
||||||
|
# random.shuffle(ind)
|
||||||
|
# ind = ind[:3]
|
||||||
|
# ind.sort()
|
||||||
|
# img0 = cv2.imread(imgpaths[ind[0]])
|
||||||
|
# gt = cv2.imread(imgpaths[ind[1]])
|
||||||
|
# img1 = cv2.imread(imgpaths[ind[2]])
|
||||||
|
# timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6)
|
||||||
|
|
||||||
def __getitem__(self, index):
|
def __getitem__(self, index):
|
||||||
img0, gt, img1 = self.getimg(index)
|
img0, gt, img1, timestep = self.getimg(index)
|
||||||
if self.dataset_name == 'train':
|
if self.dataset_name == 'train':
|
||||||
img0, gt, img1 = self.aug(img0, gt, img1, 224, 224)
|
img0, gt, img1 = self.crop(img0, gt, img1, 224, 224)
|
||||||
if random.uniform(0, 1) < 0.5:
|
if random.uniform(0, 1) < 0.5:
|
||||||
img0 = img0[:, :, ::-1]
|
img0 = img0[:, :, ::-1]
|
||||||
img1 = img1[:, :, ::-1]
|
img1 = img1[:, :, ::-1]
|
||||||
@@ -76,8 +87,23 @@ class VimeoDataset(Dataset):
|
|||||||
tmp = img1
|
tmp = img1
|
||||||
img1 = img0
|
img1 = img0
|
||||||
img0 = tmp
|
img0 = tmp
|
||||||
# timestep = 1 - timestep
|
timestep = 1 - timestep
|
||||||
|
# random rotation
|
||||||
|
p = random.uniform(0, 1)
|
||||||
|
if p < 0.25:
|
||||||
|
img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE)
|
||||||
|
gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE)
|
||||||
|
img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE)
|
||||||
|
elif p < 0.5:
|
||||||
|
img0 = cv2.rotate(img0, cv2.ROTATE_180)
|
||||||
|
gt = cv2.rotate(gt, cv2.ROTATE_180)
|
||||||
|
img1 = cv2.rotate(img1, cv2.ROTATE_180)
|
||||||
|
elif p < 0.75:
|
||||||
|
img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||||
|
gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||||
|
img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE)
|
||||||
img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
|
img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
|
||||||
img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
|
img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
|
||||||
gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
|
gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
|
||||||
return torch.cat((img0, img1, gt), 0)
|
timestep = torch.tensor(timestep).reshape(1, 1, 1)
|
||||||
|
return torch.cat((img0, img1, gt), 0), timestep
|
||||||
|
|||||||
Reference in New Issue
Block a user