From 981ae76ca2023f8213a24568ebfb00c4d0a8a76b Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Mon, 11 Apr 2022 11:43:19 +0800 Subject: [PATCH] Update dataset.py --- dataset.py | 42 ++++++++++++++++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 8 deletions(-) diff --git a/dataset.py b/dataset.py index 392253b..8221e0d 100644 --- a/dataset.py +++ b/dataset.py @@ -35,9 +35,8 @@ class VimeoDataset(Dataset): self.meta_data = self.testlist else: self.meta_data = self.trainlist[cnt:] - - - def aug(self, img0, gt, img1, h, w): + + def crop(self, img0, gt, img1, h, w): ih, iw, _ = img0.shape x = np.random.randint(0, ih - h + 1) y = np.random.randint(0, iw - w + 1) @@ -54,12 +53,24 @@ class VimeoDataset(Dataset): img0 = cv2.imread(imgpaths[0]) gt = cv2.imread(imgpaths[1]) img1 = cv2.imread(imgpaths[2]) - return img0, gt, img1 + timestep = 0.5 + return img0, gt, img1, timestep + + # RIFEm with Vimeo-Septuplet + # imgpaths = [imgpath + '/im1.png', imgpath + '/im2.png', imgpath + '/im3.png', imgpath + '/im4.png', imgpath + '/im5.png', imgpath + '/im6.png', imgpath + '/im7.png'] + # ind = [0, 1, 2, 3, 4, 5, 6] + # random.shuffle(ind) + # ind = ind[:3] + # ind.sort() + # img0 = cv2.imread(imgpaths[ind[0]]) + # gt = cv2.imread(imgpaths[ind[1]]) + # img1 = cv2.imread(imgpaths[ind[2]]) + # timestep = (ind[1] - ind[0]) * 1.0 / (ind[2] - ind[0] + 1e-6) def __getitem__(self, index): - img0, gt, img1 = self.getimg(index) + img0, gt, img1, timestep = self.getimg(index) if self.dataset_name == 'train': - img0, gt, img1 = self.aug(img0, gt, img1, 224, 224) + img0, gt, img1 = self.crop(img0, gt, img1, 224, 224) if random.uniform(0, 1) < 0.5: img0 = img0[:, :, ::-1] img1 = img1[:, :, ::-1] @@ -76,8 +87,23 @@ class VimeoDataset(Dataset): tmp = img1 img1 = img0 img0 = tmp - # timestep = 1 - timestep + timestep = 1 - timestep + # random rotation + p = random.uniform(0, 1) + if p < 0.25: + img0 = cv2.rotate(img0, cv2.ROTATE_90_CLOCKWISE) + gt = cv2.rotate(gt, cv2.ROTATE_90_CLOCKWISE) + img1 = cv2.rotate(img1, cv2.ROTATE_90_CLOCKWISE) + elif p < 0.5: + img0 = cv2.rotate(img0, cv2.ROTATE_180) + gt = cv2.rotate(gt, cv2.ROTATE_180) + img1 = cv2.rotate(img1, cv2.ROTATE_180) + elif p < 0.75: + img0 = cv2.rotate(img0, cv2.ROTATE_90_COUNTERCLOCKWISE) + gt = cv2.rotate(gt, cv2.ROTATE_90_COUNTERCLOCKWISE) + img1 = cv2.rotate(img1, cv2.ROTATE_90_COUNTERCLOCKWISE) img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1) img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1) gt = torch.from_numpy(gt.copy()).permute(2, 0, 1) - return torch.cat((img0, img1, gt), 0) + timestep = torch.tensor(timestep).reshape(1, 1, 1) + return torch.cat((img0, img1, gt), 0), timestep