Update train.py

This commit is contained in:
Zhewei Huang
2022-01-04 16:29:13 +08:00
committed by GitHub
parent b64256c134
commit 8eb9503046

View File

@@ -16,7 +16,7 @@ from torch.utils.data.distributed import DistributedSampler
device = torch.device("cuda")
exp = os.path.abspath('.').split('/')[-1]
log_path = 'train_log'
def get_learning_rate(step):
if step < 2000:
@@ -85,11 +85,11 @@ def train(model, local_rank):
step += 1
nr_eval += 1
if nr_eval % 5 == 0:
evaluate(model, val_data, step)
evaluate(model, val_data, step, local_rank, writer_val)
model.save_model(log_path, local_rank)
dist.barrier()
def evaluate(model, val_data, nr_eval):
def evaluate(model, val_data, nr_eval, local_rank, writer_val):
loss_l1_list = []
loss_distill_list = []
loss_tea_list = []
@@ -119,7 +119,7 @@ def evaluate(model, val_data, nr_eval):
if i == 0 and local_rank == 0:
for j in range(10):
imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
..add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')
eval_time_interval = time.time() - time_stamp