diff --git a/train.py b/train.py index b5c5479..93a4dc4 100644 --- a/train.py +++ b/train.py @@ -119,7 +119,7 @@ def evaluate(model, val_data, nr_eval, local_rank, writer_val): if i == 0 and local_rank == 0: for j in range(10): imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1] - ..add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC') + writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC') writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC') eval_time_interval = time.time() - time_stamp