# ECCV2022-RIFE/train.py
import os
import cv2
import math
import time
import torch
import torch.distributed as dist
import numpy as np
import random
import argparse
from model.RIFE import Model
from dataset import *
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data.distributed import DistributedSampler

device = torch.device("cuda")
log_path = 'train_log'

def get_learning_rate(step):
    # Linear warmup over the first 2000 steps, then cosine annealing
    # from 3e-4 down to 3e-6 across the remaining training steps.
    if step < 2000:
        mul = step / 2000.
        return 3e-4 * mul
    else:
        mul = np.cos((step - 2000) / (args.epoch * args.step_per_epoch - 2000.) * math.pi) * 0.5 + 0.5
        return (3e-4 - 3e-6) * mul + 3e-6
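
# Endpoint behavior of the schedule above (illustrative, derived from the
# formula; writing total_steps for args.epoch * args.step_per_epoch):
#   get_learning_rate(0)           -> 0.0   (start of warmup)
#   get_learning_rate(2000)        -> 3e-4  (cos(0) gives mul = 1)
#   get_learning_rate(total_steps) -> 3e-6  (cos(pi) gives mul = 0)
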
def flow2rgb(flow_map_np):
    # Map an (H, W, 2) optical-flow field to an RGB image for TensorBoard.
    h, w, _ = flow_map_np.shape
    rgb_map = np.ones((h, w, 3)).astype(np.float32)
    normalized_flow_map = flow_map_np / (np.abs(flow_map_np).max())
    rgb_map[:, :, 0] += normalized_flow_map[:, :, 0]
    rgb_map[:, :, 1] -= 0.5 * (normalized_flow_map[:, :, 0] + normalized_flow_map[:, :, 1])
    rgb_map[:, :, 2] += normalized_flow_map[:, :, 1]
    return rgb_map.clip(0, 1)
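
# Reading the mapping above: rgb_map starts at all-ones, so zero flow renders
# white; positive horizontal flow pushes red, positive vertical flow pushes
# blue, and green decreases with their sum. The division by max(|flow|)
# normalizes per call, so colors indicate direction, not absolute magnitude.
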
def train(model, local_rank):
    # Only rank 0 writes TensorBoard logs.
    if local_rank == 0:
        writer = SummaryWriter('train')
        writer_val = SummaryWriter('validate')
    else:
        writer = None
        writer_val = None
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = train_data.__len__()
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, timestep = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            timestep = timestep.to(device, non_blocking=True)
            # Each sample packs two input frames and the ground-truth middle
            # frame along the channel axis: [I0 (3ch), I1 (3ch), gt (3ch)].
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            # Linear LR scaling relative to the 4-GPU baseline.
            learning_rate = get_learning_rate(step) * args.world_size / 4
            pred, info = model.update(imgs, gt, learning_rate, training=True) # pass timestep if you are training RIFEm
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if step % 200 == 1 and local_rank == 0:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss/l1', info['loss_l1'], step)
                writer.add_scalar('loss/tea', info['loss_tea'], step)
                writer.add_scalar('loss/distill', info['loss_distill'], step)
            if step % 1000 == 1 and local_rank == 0:
                gt = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                mask = (torch.cat((info['mask'], info['mask_tea']), 3).permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                pred = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                merged_img = (info['merged_tea'].permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                flow0 = info['flow'].permute(0, 2, 3, 1).detach().cpu().numpy()
                flow1 = info['flow_tea'].permute(0, 2, 3, 1).detach().cpu().numpy()
                # Log the first 5 samples of the batch: teacher output,
                # student prediction, and ground truth side by side.
                # (Loop variable renamed to j so it no longer shadows the
                # dataloader index i used by the progress print below.)
                for j in range(5):
                    imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
                    writer.add_image(str(j) + '/img', imgs, step, dataformats='HWC')
                    writer.add_image(str(j) + '/flow', np.concatenate((flow2rgb(flow0[j]), flow2rgb(flow1[j])), 1), step, dataformats='HWC')
                    writer.add_image(str(j) + '/mask', mask[j], step, dataformats='HWC')
                writer.flush()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, info['loss_l1']))
            step += 1
        nr_eval += 1
        # Run validation every 5 epochs; save a checkpoint every epoch.
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step, local_rank, writer_val)
        model.save_model(log_path, local_rank)
        dist.barrier()
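
# Note: train() passes the global training step as evaluate()'s nr_eval
# argument, so validation scalars and images are indexed by step on the
# TensorBoard x-axis rather than by evaluation count.
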
def evaluate(model, val_data, nr_eval, local_rank, writer_val):
    loss_l1_list = []
    loss_distill_list = []
    loss_tea_list = []
    psnr_list = []
    psnr_list_teacher = []
    time_stamp = time.time()
    for i, data in enumerate(val_data):
        data_gpu, timestep = data
        data_gpu = data_gpu.to(device, non_blocking=True) / 255.
        imgs = data_gpu[:, :6]
        gt = data_gpu[:, 6:9]
        with torch.no_grad():
            pred, info = model.update(imgs, gt, training=False)
            merged_img = info['merged_tea']
        loss_l1_list.append(info['loss_l1'].cpu().numpy())
        loss_tea_list.append(info['loss_tea'].cpu().numpy())
        loss_distill_list.append(info['loss_distill'].cpu().numpy())
        # PSNR = -10 * log10(MSE) for images normalized to [0, 1].
        for j in range(gt.shape[0]):
            psnr = -10 * math.log10(torch.mean((gt[j] - pred[j]) * (gt[j] - pred[j])).cpu().data)
            psnr_list.append(psnr)
            psnr = -10 * math.log10(torch.mean((merged_img[j] - gt[j]) * (merged_img[j] - gt[j])).cpu().data)
            psnr_list_teacher.append(psnr)
        gt = (gt.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
        pred = (pred.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
        merged_img = (merged_img.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
        flow0 = info['flow'].permute(0, 2, 3, 1).cpu().numpy()
        flow1 = info['flow_tea'].permute(0, 2, 3, 1).cpu().numpy()
        # Log the first 10 samples of the first batch once per evaluation.
        if i == 0 and local_rank == 0:
            for j in range(10):
                imgs = np.concatenate((merged_img[j], pred[j], gt[j]), 1)[:, :, ::-1]
                writer_val.add_image(str(j) + '/img', imgs.copy(), nr_eval, dataformats='HWC')
                writer_val.add_image(str(j) + '/flow', flow2rgb(flow0[j][:, :, ::-1]), nr_eval, dataformats='HWC')
    eval_time_interval = time.time() - time_stamp

    if local_rank != 0:
        return
    writer_val.add_scalar('psnr', np.array(psnr_list).mean(), nr_eval)
    writer_val.add_scalar('psnr_teacher', np.array(psnr_list_teacher).mean(), nr_eval)
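
# Note: evaluate() runs on every rank -- val_data has no DistributedSampler,
# so each process scores the full validation set -- but only rank 0 writes
# the resulting scalars and images to TensorBoard.
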
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--epoch', default=300, type=int)
    parser.add_argument('--batch_size', default=16, type=int, help='minibatch size')
    parser.add_argument('--local_rank', default=0, type=int, help='local rank')
    parser.add_argument('--world_size', default=4, type=int, help='world size')
    args = parser.parse_args()
    torch.distributed.init_process_group(backend="nccl", world_size=args.world_size)
    torch.cuda.set_device(args.local_rank)
    # Fix all RNG seeds for reproducibility; cudnn.benchmark trades exact
    # determinism for speed on fixed-size inputs.
    seed = 1234
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = True
    model = Model(args.local_rank)
    train(model, args.local_rank)
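
# Typical multi-GPU launch (a sketch based on the repository's README; exact
# syntax depends on your PyTorch version -- newer releases replace
# torch.distributed.launch with torchrun and read LOCAL_RANK from the
# environment instead of --local_rank):
#
#   python3 -m torch.distributed.launch --nproc_per_node=4 train.py --world_size=4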