mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Update train.py
This commit is contained in:
9
train.py
9
train.py
@@ -55,11 +55,13 @@ def train(model, local_rank):
|
||||
for i, data in enumerate(train_data):
|
||||
data_time_interval = time.time() - time_stamp
|
||||
time_stamp = time.time()
|
||||
data_gpu = data.to(device, non_blocking=True) / 255.
|
||||
data_gpu, timestep = data
|
||||
data_gpu = data_gpu.to(device, non_blocking=True) / 255.
|
||||
timestep = timestep.to(device, non_blocking=True)
|
||||
imgs = data_gpu[:, :6]
|
||||
gt = data_gpu[:, 6:9]
|
||||
learning_rate = get_learning_rate(step)
|
||||
pred, info = model.update(imgs, gt, learning_rate, training=True)
|
||||
pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm
|
||||
train_time_interval = time.time() - time_stamp
|
||||
time_stamp = time.time()
|
||||
if step % 200 == 1 and local_rank == 0:
|
||||
@@ -97,7 +99,8 @@ def evaluate(model, val_data, nr_eval, local_rank, writer_val):
|
||||
psnr_list_teacher = []
|
||||
time_stamp = time.time()
|
||||
for i, data in enumerate(val_data):
|
||||
data_gpu = data.to(device, non_blocking=True) / 255.
|
||||
data_gpu, timestep = data
|
||||
data_gpu = data_gpu.to(device, non_blocking=True) / 255.
|
||||
imgs = data_gpu[:, :6]
|
||||
gt = data_gpu[:, 6:9]
|
||||
with torch.no_grad():
|
||||
|
||||
Reference in New Issue
Block a user