Update train.py

This commit is contained in:
hzwer
2022-04-11 11:45:56 +08:00
committed by GitHub
parent 981ae76ca2
commit 0f0b1ef0c6

View File

@@ -55,11 +55,13 @@ def train(model, local_rank):
for i, data in enumerate(train_data):
data_time_interval = time.time() - time_stamp
time_stamp = time.time()
data_gpu = data.to(device, non_blocking=True) / 255.
data_gpu, timestep = data
data_gpu = data_gpu.to(device, non_blocking=True) / 255.
timestep = timestep.to(device, non_blocking=True)
imgs = data_gpu[:, :6]
gt = data_gpu[:, 6:9]
learning_rate = get_learning_rate(step)
pred, info = model.update(imgs, gt, learning_rate, training=True)
pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm
train_time_interval = time.time() - time_stamp
time_stamp = time.time()
if step % 200 == 1 and local_rank == 0:
@@ -97,7 +99,8 @@ def evaluate(model, val_data, nr_eval, local_rank, writer_val):
psnr_list_teacher = []
time_stamp = time.time()
for i, data in enumerate(val_data):
data_gpu = data.to(device, non_blocking=True) / 255.
data_gpu, timestep = data
data_gpu = data_gpu.to(device, non_blocking=True) / 255.
imgs = data_gpu[:, :6]
gt = data_gpu[:, 6:9]
with torch.no_grad():