mirror of
https://github.com/hzwer/ECCV2022-RIFE.git
synced 2026-02-24 04:19:41 +01:00
Update train.py
This commit is contained in:
6
train.py
6
train.py
@@ -74,7 +74,6 @@ def train(model, local_rank):
|
||||
merged_img = (info['merged_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
|
||||
flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
|
||||
flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()
|
||||
flow_gt = flow_gt.permute(0, 2, 3, 1).detach().cpu().numpy()
|
||||
for i in range(5):
|
||||
imgs = np.concatenate((merged_img[i], pred[i], gt[i]), 1)[:, :, ::-1]
|
||||
writer.add_image(str(i) + '/img', imgs, step, dataformats='HWC')
|
||||
@@ -98,9 +97,7 @@ def evaluate(model, val_data, nr_eval):
|
||||
psnr_list_teacher = []
|
||||
time_stamp = time.time()
|
||||
for i, data in enumerate(val_data):
|
||||
data_gpu, flow_gt = data
|
||||
data_gpu = data_gpu.to(device, non_blocking=True) / 255.
|
||||
flow_gt = flow_gt.to(device, non_blocking=True)
|
||||
data_gpu = data.to(device, non_blocking=True) / 255.
|
||||
imgs = data_gpu[:, :6]
|
||||
gt = data_gpu[:, 6:9]
|
||||
with torch.no_grad():
|
||||
@@ -124,7 +121,6 @@ def evaluate(model, val_data, nr_eval):
|
||||
imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
|
||||
writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
|
||||
writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')
|
||||
writer_val.add_image(str(j) + '/flow_gt', flow2rgb(flow1[j][:, :, ::-1]), nr_eval, dataformats='HWC')
|
||||
|
||||
eval_time_interval = time.time() - time_stamp
|
||||
|
||||
|
||||
Reference in New Issue
Block a user