diff --git a/train.py b/train.py index 93a4dc4..6d56203 100644 --- a/train.py +++ b/train.py @@ -55,11 +55,13 @@ def train(model, local_rank): for i, data in enumerate(train_data): data_time_interval = time.time() - time_stamp time_stamp = time.time() - data_gpu = data.to(device, non_blocking=True) / 255. + data_gpu, timestep = data + data_gpu = data_gpu.to(device, non_blocking=True) / 255. + timestep = timestep.to(device, non_blocking=True) imgs = data_gpu[:, :6] gt = data_gpu[:, 6:9] learning_rate = get_learning_rate(step) - pred, info = model.update(imgs, gt, learning_rate, training=True) + pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm train_time_interval = time.time() - time_stamp time_stamp = time.time() if step % 200 == 1 and local_rank == 0: @@ -97,7 +99,8 @@ def evaluate(model, val_data, nr_eval, local_rank, writer_val): psnr_list_teacher = [] time_stamp = time.time() for i, data in enumerate(val_data): - data_gpu = data.to(device, non_blocking=True) / 255. + data_gpu, timestep = data + data_gpu = data_gpu.to(device, non_blocking=True) / 255. imgs = data_gpu[:, :6] gt = data_gpu[:, 6:9] with torch.no_grad():