diff --git a/train.py b/train.py index b386606..b5c5479 100644 --- a/train.py +++ b/train.py @@ -16,7 +16,7 @@ from torch.utils.data.distributed import DistributedSampler device = torch.device("cuda") -exp = os.path.abspath('.').split('/')[-1] +log_path = 'train_log' def get_learning_rate(step): if step < 2000: @@ -85,11 +85,11 @@ def train(model, local_rank): step += 1 nr_eval += 1 if nr_eval % 5 == 0: - evaluate(model, val_data, step) + evaluate(model, val_data, step, local_rank, writer_val) model.save_model(log_path, local_rank) dist.barrier() -def evaluate(model, val_data, nr_eval): +def evaluate(model, val_data, nr_eval, local_rank, writer_val): loss_l1_list = [] loss_distill_list = [] loss_tea_list = [] @@ -119,7 +119,7 @@ def evaluate(model, val_data, nr_eval): if i == 0 and local_rank == 0: for j in range(10): imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1] - writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC') + writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC') writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC') eval_time_interval = time.time() - time_stamp