mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Update train.py
This commit is contained in:
8
train.py
8
train.py
@@ -16,7 +16,7 @@ from torch.utils.data.distributed import DistributedSampler
|
||||
|
||||
device = torch.device("cuda")
|
||||
|
||||
exp = os.path.abspath('.').split('/')[-1]
|
||||
log_path = 'train_log'
|
||||
|
||||
def get_learning_rate(step):
|
||||
if step < 2000:
|
||||
@@ -85,11 +85,11 @@ def train(model, local_rank):
|
||||
step += 1
|
||||
nr_eval += 1
|
||||
if nr_eval % 5 == 0:
|
||||
evaluate(model, val_data, step)
|
||||
evaluate(model, val_data, step, local_rank, writer_val)
|
||||
model.save_model(log_path, local_rank)
|
||||
dist.barrier()
|
||||
|
||||
def evaluate(model, val_data, nr_eval):
|
||||
def evaluate(model, val_data, nr_eval, local_rank, writer_val):
|
||||
loss_l1_list = []
|
||||
loss_distill_list = []
|
||||
loss_tea_list = []
|
||||
@@ -119,7 +119,7 @@ def evaluate(model, val_data, nr_eval):
|
||||
if i == 0 and local_rank == 0:
|
||||
for j in range(10):
|
||||
imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
|
||||
writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
|
||||
..add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
|
||||
writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')
|
||||
|
||||
eval_time_interval = time.time() - time_stamp
|
||||
|
||||
Reference in New Issue
Block a user