Update train.py

This commit is contained in:
hzwer
2022-04-15 11:21:20 +08:00
committed by GitHub
parent 0f0b1ef0c6
commit dd941f3ae6

View File

@@ -60,7 +60,7 @@ def train(model, local_rank):
timestep = timestep.to(device, non_blocking=True)
imgs = data_gpu[:, :6]
gt = data_gpu[:, 6:9]
learning_rate = get_learning_rate(step)
learning_rate = get_learning_rate(step) / args.world_size
pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm
train_time_interval = time.time() - time_stamp
time_stamp = time.time()