New dataset WIP

This commit is contained in:
hzwer
2021-08-13 16:58:05 +08:00
parent a9ed2bac23
commit 3ba310299d

View File

@@ -26,41 +26,32 @@ class VimeoDataset(Dataset):
self.train_data = []
self.flow_data = []
self.val_data = []
for i in range(100):
f = np.load('dataset/{}.npz'.format(i))
if i < 80:
self.train_data.append(f['i0i1gt'])
self.flow_data.append(f['ft0ft1'])
else:
self.val_data.append(f['i0i1gt'])
if self.dataset_name == 'train':
self.meta_data = self.train_data
else:
self.meta_data = self.val_data
self.nr_sample = len(self.meta_data)
def aug(self, img0, gt, img1, flow_gt, h, w):
def aug(self, img0, gt, img1, h, w):
ih, iw, _ = img0.shape
x = np.random.randint(0, ih - h + 1)
y = np.random.randint(0, iw - w + 1)
img0 = img0[x:x+h, y:y+w, :]
img1 = img1[x:x+h, y:y+w, :]
gt = gt[x:x+h, y:y+w, :]
flow_gt = flow_gt[x:x+h, y:y+w, :]
return img0, gt, img1, flow_gt
return img0, gt, img1
def getimg(self, index):
data = self.meta_data[index]
img0 = data[0:3].transpose(1, 2, 0)
img1 = data[3:6].transpose(1, 2, 0)
gt = data[6:9].transpose(1, 2, 0)
flow_gt = (self.flow_data[index]).transpose(1, 2, 0)
return img0, gt, img1, flow_gt
return img0, gt, img1
def __getitem__(self, index):
img0, gt, img1, flow_gt = self.getimg(index)
img0, gt, img1 = self.getimg(index)
if self.dataset_name == 'train':
img0, gt, img1, flow_gt = self.aug(img0, gt, img1, flow_gt, 224, 224)
img0, gt, img1 = self.aug(img0, gt, img1, 224, 224)
if random.uniform(0, 1) < 0.5:
img0 = img0[:, :, ::-1]
img1 = img1[:, :, ::-1]
@@ -69,23 +60,15 @@ class VimeoDataset(Dataset):
img0 = img0[::-1]
img1 = img1[::-1]
gt = gt[::-1]
flow_gt = flow_gt[::-1]
flow_gt = np.concatenate((flow_gt[:, :, 0:1], -flow_gt[:, :, 1:2], flow_gt[:, :, 2:3], -flow_gt[:, :, 3:4]), 2)
if random.uniform(0, 1) < 0.5:
img0 = img0[:, ::-1]
img1 = img1[:, ::-1]
gt = gt[:, ::-1]
flow_gt = flow_gt[:, ::-1]
flow_gt = np.concatenate((-flow_gt[:, :, 0:1], flow_gt[:, :, 1:2], -flow_gt[:, :, 2:3], flow_gt[:, :, 3:4]), 2)
if random.uniform(0, 1) < 0.5:
tmp = img1
img1 = img0
img0 = tmp
flow_gt = np.concatenate((flow_gt[:, :, 2:4], flow_gt[:, :, 0:2]), 2)
else:
flow_gt = np.zeros((256, 448, 4))
flow_gt = torch.from_numpy(flow_gt.copy()).permute(2, 0, 1)
img0 = torch.from_numpy(img0.copy()).permute(2, 0, 1)
img1 = torch.from_numpy(img1.copy()).permute(2, 0, 1)
gt = torch.from_numpy(gt.copy()).permute(2, 0, 1)
return torch.cat((img0, img1, gt), 0), flow_gt
return torch.cat((img0, img1, gt), 0)