From de92bf2f9234dfd6676828bf74592266b36b63bd Mon Sep 17 00:00:00 2001 From: hzwer <598460606@163.com> Date: Thu, 9 Sep 2021 14:42:09 +0800 Subject: [PATCH] Update train.py --- train.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/train.py b/train.py index 0c8fd61..b386606 100644 --- a/train.py +++ b/train.py @@ -74,7 +74,6 @@ def train(model, local_rank): merged_img = (info['merged_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8') flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy() flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() - flow_gt = flow_gt.permute(0, 2, 3, 1).detach().cpu().numpy() for i in range(5): imgs = np.concatenate((merged_img[i], pred[i], gt[i]), 1)[:, :, ::-1] writer.add_image(str(i) + '/img', imgs, step, dataformats='HWC') @@ -98,9 +97,7 @@ def evaluate(model, val_data, nr_eval): psnr_list_teacher = [] time_stamp = time.time() for i, data in enumerate(val_data): - data_gpu, flow_gt = data - data_gpu = data_gpu.to(device, non_blocking=True) / 255. - flow_gt = flow_gt.to(device, non_blocking=True) + data_gpu = data.to(device, non_blocking=True) / 255. imgs = data_gpu[:, :6] gt = data_gpu[:, 6:9] with torch.no_grad(): @@ -124,7 +121,6 @@ def evaluate(model, val_data, nr_eval): imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1] writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC') writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC') - writer_val.add_image(str(j) + '/flow_gt', flow2rgb(flow1[j][:, :, ::-1]), nr_eval, dataformats='HWC') eval_time_interval = time.time() - time_stamp