Files
ECCV2022-RIFE/train_WIP.py

149 lines
6.8 KiB
Python
Raw Normal View History

2020-11-16 18:34:09 +08:00
import os
import cv2
import math
import time
import torch
import numpy as np
import random
import argparse
import torch.distributed as dist
# Join the NCCL process group at import time. The world size is fixed at 4;
# no rank or init method is passed here, so they are presumably supplied by
# the launcher through environment variables -- TODO confirm.
torch.distributed.init_process_group(backend="nccl", world_size=4)
# NOTE(review): the process-group rank is used directly as the CUDA device
# index below; this presumably only holds for single-node runs -- confirm.
local_rank = torch.distributed.get_rank()
torch.cuda.set_device(local_rank)
device = torch.device("cuda", local_rank)
# Seed the Python, NumPy and PyTorch (CPU and all CUDA devices) generators
# with the same fixed value on every rank.
seed = 1234
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Enable the cuDNN autotuner.
torch.backends.cudnn.benchmark = True
from model import Model
from dataset import *
from torch.utils.data import DataLoader, Dataset
from torch.utils.tensorboard import SummaryWriter
from util import *
from torch.utils.data.distributed import DistributedSampler
# Directory that receives the TensorBoard event files; train() also passes it
# to model.save_model / model.load_model.
log_path = 'train_log'
# Only rank 0 creates the TensorBoard writers; on every other rank the names
# `writer` and `writer_val` stay undefined, so all uses must be rank-guarded.
if local_rank == 0:
    writer, writer_val = [
        SummaryWriter(log_path + suffix) for suffix in ('/train', '/validate')
    ]
def get_learning_rate(step):
    """Return the learning rate to use at global training step ``step``.

    For the first 2000 steps the rate grows linearly from 0 towards 5e-4.
    From step 2000 on it follows a half-cosine curve that starts at 5e-4 and
    reaches 0 when ``step`` equals ``args.epoch * args.step_per_epoch``.

    The cosine branch reads the module-level ``args``; ``args.step_per_epoch``
    is set inside ``train`` before this function is called from there.
    """
    if step < 2000:
        # Linear warm-up: scale factor goes 0 -> 1 over the first 2000 steps.
        mul = step / 2000.
        return 5e-4 * mul

    # Fraction of the post-warm-up schedule that has elapsed.
    total_after_warmup = args.epoch * args.step_per_epoch - 2000.
    progress = (step - 2000) / total_after_warmup
    mul = np.cos(progress * math.pi) * 0.5 + 0.5
    return 5e-4 * mul
def train(model):
    """Train ``model`` on the Vimeo 'train' split with periodic evaluation.

    Reads the module-level ``args`` (``epoch``, ``batch_size``) and stores the
    number of batches per epoch on it as ``args.step_per_epoch``, which
    ``get_learning_rate`` reads. After every epoch the model is saved and all
    ranks synchronise on a barrier; every fifth epoch ``evaluate`` is run.
    TensorBoard scalars, image previews and progress lines are emitted from
    rank 0 only.

    NOTE(review): the nesting of this function was reconstructed from a copy
    of the file that had lost its indentation (epoch-level bookkeeping is
    placed after the batch loop) -- confirm against the original.

    Args:
        model: object exposing ``update``, ``save_model`` and ``load_model``
            as called below.
    """
    step = 0
    nr_eval = 0
    dataset = VimeoDataset('train')
    sampler = DistributedSampler(dataset)
    train_data = DataLoader(dataset, batch_size=args.batch_size, num_workers=8, pin_memory=True, drop_last=True, sampler=sampler)
    args.step_per_epoch = len(train_data)
    dataset_val = VimeoDataset('validation')
    val_data = DataLoader(dataset_val, batch_size=16, pin_memory=True, num_workers=8)
    evaluate(model, val_data, nr_eval)
    # Save and immediately reload before training starts.
    # NOTE(review): presumably this aligns the weights across ranks -- confirm.
    model.save_model(log_path, local_rank)
    model.load_model(log_path, local_rank)
    print('training...')
    time_stamp = time.time()
    for epoch in range(args.epoch):
        # Give the distributed sampler the epoch number before iterating.
        sampler.set_epoch(epoch)
        for i, data in enumerate(train_data):
            data_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            data_gpu, flow_gt = data
            data_gpu = data_gpu.to(device, non_blocking=True) / 255.
            flow_gt = flow_gt.to(device, non_blocking=True)
            # Channels 0-5 are the model input, channels 6-8 the target.
            imgs = data_gpu[:, :6]
            gt = data_gpu[:, 6:9]
            # Cosine factor decaying from 1 to 0 over the whole run, passed to
            # model.update as its fourth positional argument.
            mul = np.cos(step / (args.epoch * args.step_per_epoch) * math.pi) * 0.5 + 0.5
            learning_rate = get_learning_rate(step)
            pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, learning_rate, mul, True, flow_gt)
            train_time_interval = time.time() - time_stamp
            time_stamp = time.time()
            if step % 100 == 1 and local_rank == 0:
                writer.add_scalar('learning_rate', learning_rate, step)
                writer.add_scalar('loss_l1', loss_l1, step)
                writer.add_scalar('loss_flow', loss_flow, step)
                writer.add_scalar('loss_cons', loss_cons, step)
                writer.add_scalar('loss_ter', loss_ter, step)
            if step % 1000 == 1 and local_rank == 0:
                # Move everything to channel-last numpy arrays for TensorBoard.
                gt_np = (gt.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                pred_np = (pred.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                merged_np = (merged_img.permute(0, 2, 3, 1).detach().cpu().numpy() * 255).astype('uint8')
                flow_np = flow.permute(0, 2, 3, 1).detach().cpu().numpy()
                flow_mask_np = flow_mask.permute(0, 2, 3, 1).detach().cpu().numpy()
                flow_gt_np = flow_gt.permute(0, 2, 3, 1).detach().cpu().numpy()
                # Use a dedicated index so the batch counter `i` printed below
                # is not overwritten, and never index past the batch size.
                for k in range(min(5, pred_np.shape[0])):
                    # Concatenate along width and reverse the channel order.
                    preview = np.concatenate((merged_np[k], pred_np[k], gt_np[k]), 1)[:, :, ::-1]
                    writer.add_image(str(k) + '/img', preview, step, dataformats='HWC')
                    writer.add_image(str(k) + '/flow', flow2rgb(flow_np[k]), step, dataformats='HWC')
                    writer.add_image(str(k) + '/flow_gt', flow2rgb(flow_gt_np[k]), step, dataformats='HWC')
                    writer.add_image(str(k) + '/flow_mask', flow2rgb(flow_np[k] * flow_mask_np[k]), step, dataformats='HWC')
                writer.flush()
            if local_rank == 0:
                print('epoch:{} {}/{} time:{:.2f}+{:.2f} loss_l1:{:.4e}'.format(epoch, i, args.step_per_epoch, data_time_interval, train_time_interval, loss_l1))
            step += 1
        nr_eval += 1
        if nr_eval % 5 == 0:
            evaluate(model, val_data, step)
        model.save_model(log_path, local_rank)
        # Wait for every rank before starting the next epoch.
        dist.barrier()
def evaluate(model, val_data, nr_eval):
    """Run ``model`` over ``val_data`` and log validation metrics.

    Every rank iterates the full loader and accumulates the four losses
    returned by ``model.update(..., training=False)`` plus a per-sample PSNR.
    Rank 0 additionally writes up to five image/flow previews from the first
    batch and the mean of each metric to ``writer_val``.

    NOTE(review): the nesting of this function was reconstructed from a copy
    of the file that had lost its indentation -- confirm against the original.

    Args:
        model: object exposing ``update`` as called below.
        val_data: iterable yielding ``(data, flow_gt)`` batches.
        nr_eval: value used as the TensorBoard global step for this run.
    """
    loss_l1_list = []
    loss_cons_list = []
    loss_ter_list = []
    loss_flow_list = []
    psnr_list = []
    time_stamp = time.time()
    for i, data in enumerate(val_data):
        # The flow ground truth is not passed to model.update here, so it is
        # not transferred to the GPU.
        data_gpu, _ = data
        data_gpu = data_gpu.to(device, non_blocking=True) / 255.
        # Channels 0-5 are the model input, channels 6-8 the target.
        imgs = data_gpu[:, :6]
        gt = data_gpu[:, 6:9]
        with torch.no_grad():
            pred, merged_img, flow, loss_l1, loss_flow, loss_cons, loss_ter, flow_mask = model.update(imgs, gt, training=False)
        loss_l1_list.append(loss_l1.cpu().numpy())
        loss_flow_list.append(loss_flow.cpu().numpy())
        loss_ter_list.append(loss_ter.cpu().numpy())
        loss_cons_list.append(loss_cons.cpu().numpy())
        for j in range(gt.shape[0]):
            # PSNR = -10 * log10(MSE); the inputs were divided by 255 above.
            diff = gt[j] - pred[j]
            psnr = -10 * math.log10(torch.mean(diff * diff).cpu().data)
            psnr_list.append(psnr)
        if i == 0 and local_rank == 0:
            # Convert to channel-last numpy arrays only when previews are
            # actually written (first batch, rank 0).
            gt_np = (gt.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
            pred_np = (pred.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
            merged_np = (merged_img.permute(0, 2, 3, 1).cpu().numpy() * 255).astype('uint8')
            flow_np = flow.permute(0, 2, 3, 1).cpu().numpy()
            # Index and tag by the sample counter `j` (not the batch index),
            # and never index past the batch size.
            for j in range(min(5, gt_np.shape[0])):
                preview = np.concatenate((merged_np[j], pred_np[j], gt_np[j]), 1)[:, :, ::-1]
                # .copy() turns the reversed view into a contiguous array.
                writer_val.add_image(str(j) + '/img', preview.copy(), nr_eval, dataformats='HWC')
                writer_val.add_image(str(j) + '/flow', flow2rgb(flow_np[j][:, :, ::-1]), nr_eval, dataformats='HWC')
    eval_time_interval = time.time() - time_stamp
    if local_rank == 0:
        print('eval time: {}'.format(eval_time_interval))
        writer_val.add_scalar('loss_l1', np.array(loss_l1_list).mean(), nr_eval)
        writer_val.add_scalar('loss_flow', np.array(loss_flow_list).mean(), nr_eval)
        writer_val.add_scalar('loss_cons', np.array(loss_cons_list).mean(), nr_eval)
        writer_val.add_scalar('loss_ter', np.array(loss_ter_list).mean(), nr_eval)
        writer_val.add_scalar('psnr', np.array(psnr_list).mean(), nr_eval)
if __name__ == "__main__":
    # Parse command-line options. `args` is deliberately left as a
    # module-level name: train() and get_learning_rate() read it as a global.
    parser = argparse.ArgumentParser(description='slomo')
    parser.add_argument('--epoch', type=int, default=300)
    parser.add_argument('--batch_size', type=int, default=16, help='minibatch size')
    parser.add_argument('--local_rank', type=int, default=0, help='local rank')
    args = parser.parse_args()

    # Build the model from the --local_rank option and start training.
    model = Model(args.local_rank)
    train(model)