Update train.py

This commit is contained in:
hzwer
2021-09-09 14:42:09 +08:00
committed by GitHub
parent c663bfeaff
commit de92bf2f92

View File

@@ -74,7 +74,6 @@ def train(model, local_rank):
merged_img = (info['merged_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()
flow_gt = flow_gt.permute(0, 2, 3, 1).detach().cpu().numpy()
for i in range(5):
imgs = np.concatenate((merged_img[i], pred[i], gt[i]), 1)[:, :, ::-1]
writer.add_image(str(i) + '/img', imgs, step, dataformats='HWC')
@@ -98,9 +97,7 @@ def evaluate(model, val_data, nr_eval):
psnr_list_teacher = []
time_stamp = time.time()
for i, data in enumerate(val_data):
data_gpu, flow_gt = data
data_gpu = data_gpu.to(device, non_blocking=True) / 255.
flow_gt = flow_gt.to(device, non_blocking=True)
data_gpu = data.to(device, non_blocking=True) / 255.
imgs = data_gpu[:, :6]
gt = data_gpu[:, 6:9]
with torch.no_grad():
@@ -124,7 +121,6 @@ def evaluate(model, val_data, nr_eval):
imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')
writer_val.add_image(str(j) + '/flow_gt', flow2rgb(flow1[j][:, :, ::-1]), nr_eval, dataformats='HWC')
eval_time_interval = time.time() - time_stamp