import argparse, os, shutil, time, random, torch, cv2, datetime, torch.utils.data, math
import torch.backends.cudnn as cudnn
import torch.optim as optim
import numpy as np
import sys
abspath = os.path.abspath(__file__)
wrkdir = os.path.dirname(abspath)
print("Changing working dir to {0}".format(wrkdir))
os.chdir(os.path.dirname(wrkdir))
print("Added {0} to temporary PATH".format(wrkdir))
sys.path.append(wrkdir)
from torch.autograd import Variable
from utils import *
from XVFInet import *
from collections import Counter
def parse_args():
desc = "PyTorch implementation for XVFI"
parser = argparse.ArgumentParser(description=desc)
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument('--net_type', type=str, default='XVFInet', choices=['XVFInet'], help='The type of Net')
parser.add_argument('--net_object', default=XVFInet, choices=[XVFInet], help='The type of Net')
parser.add_argument('--exp_num', type=int, default=1, help='The experiment number')
parser.add_argument('--phase', type=str, default='test_custom', choices=['train', 'test', 'test_custom', 'metrics_evaluation',])
parser.add_argument('--continue_training', action='store_true', default=False, help='continue the training')
""" Information of directories """
parser.add_argument('--test_img_dir', type=str, default='./test_img_dir', help='test_img_dir path')
parser.add_argument('--text_dir', type=str, default='./text_dir', help='text_dir path')
parser.add_argument('--checkpoint_dir', type=str, default='./checkpoint_dir', help='checkpoint_dir')
parser.add_argument('--log_dir', type=str, default='./log_dir', help='Directory name to save training logs')
parser.add_argument('--dataset', default='X4K1000FPS', choices=['X4K1000FPS', 'Vimeo'],
help='Training/test Dataset')
# parser.add_argument('--train_data_path', type=str, default='./X4K1000FPS/train')
# parser.add_argument('--val_data_path', type=str, default='./X4K1000FPS/val')
# parser.add_argument('--test_data_path', type=str, default='./X4K1000FPS/test')
parser.add_argument('--train_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/train')
parser.add_argument('--val_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/val')
parser.add_argument('--test_data_path', type=str, default='../Datasets/VIC_4K_1000FPS/test')
parser.add_argument('--vimeo_data_path', type=str, default='./vimeo_triplet')
""" Hyperparameters for Training (when [phase=='train']) """
parser.add_argument('--epochs', type=int, default=200, help='The number of epochs to run')
parser.add_argument('--freq_display', type=int, default=100, help='The number of iterations frequency for display')
parser.add_argument('--save_img_num', type=int, default=4,
help='The number of saved images while training, for visualization. It should be smaller than the batch_size')
parser.add_argument('--init_lr', type=float, default=1e-4, help='The initial learning rate')
parser.add_argument('--lr_dec_fac', type=float, default=0.25, help='step - lr_decreasing_factor')
parser.add_argument('--lr_milestones', type=int, nargs='+', default=[100, 150, 180], help='Epochs at which the learning rate is decreased')
parser.add_argument('--lr_dec_start', type=int, default=0,
help='When scheduler is StepLR, lr decreases from epoch at lr_dec_start')
parser.add_argument('--batch_size', type=int, default=8, help='The batch size')
parser.add_argument('--weight_decay', type=float, default=0, help='for optim., weight decay (default: 0)')
parser.add_argument('--need_patch', default=True, help='get patch from image while training')
parser.add_argument('--img_ch', type=int, default=3, help='base number of channels for image')
parser.add_argument('--nf', type=int, default=64, help='base number of channels for feature maps') # 64
parser.add_argument('--module_scale_factor', type=int, default=4, help='spatial reduction for pixelshuffle')
parser.add_argument('--patch_size', type=int, default=384, help='patch size')
parser.add_argument('--num_thrds', type=int, default=4, help='number of threads for data loading')
parser.add_argument('--loss_type', default='L1', choices=['L1', 'MSE', 'L1_Charbonnier_loss'], help='Loss type')
parser.add_argument('--S_trn', type=int, default=3, help='The lowest scale depth for training')
parser.add_argument('--S_tst', type=int, default=5, help='The lowest scale depth for test')
""" Weighting Parameters Lambda for Losses (when [phase=='train']) """
parser.add_argument('--rec_lambda', type=float, default=1.0, help='Lambda for Reconstruction Loss')
""" Settings for Testing (when [phase=='test' or 'test_custom']) """
parser.add_argument('--saving_flow_flag', default=False)
parser.add_argument('--multiple', type=int, default=8, help='Due to the indexing problem of the file names, we recommend using a power of 2 (e.g. 2, 4, 8, 16, ...). CAUTION: for the provided X-TEST, multiple should be one of [2, 4, 8, 16, 32].')
parser.add_argument('--metrics_types', type=str, nargs='+', default=["PSNR", "SSIM", "tOF"], choices=["PSNR", "SSIM", "tOF"])
""" Settings for test_custom (when [phase=='test_custom']) """
parser.add_argument('--custom_path', type=str, default='./custom_path', help='path for custom video containing frames')
parser.add_argument('--output', type=str, default='./interp', help='output path')
parser.add_argument('--input', type=str, default='./frames', help='input path')
parser.add_argument('--img_format', type=str, default="png")
parser.add_argument('--mdl_dir', type=str)
return check_args(parser.parse_args())
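# Example invocation for interpolating a custom frame sequence (illustrative only; the script
# name, paths and checkpoint location are placeholders, adjust them to your setup):
#
#   python main.py --phase test_custom --gpu 0 --multiple 8 \
#       --custom_path ./my_clip --input frames --output interp \
#       --img_format png --mdl_dir <checkpoint directory relative to this script>
#
# The interpolated result is written to <custom_path>/<output> as a numbered image sequence.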
def check_args(args):
# --checkpoint_dir
check_folder(args.checkpoint_dir)
# --text_dir
check_folder(args.text_dir)
# --log_dir
check_folder(args.log_dir)
# --test_img_dir
check_folder(args.test_img_dir)
return args
def main():
args = parse_args()
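# Vimeo triplets contain only 3 frames (the middle one is the ground truth), so for the Vimeo
# dataset the defaults below are overridden: x2 interpolation, a shallower scale pyramid
# (S_trn = S_tst = 1), a smaller pixel-shuffle factor and patch size, and a larger batch.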
if args.dataset == 'Vimeo':
if args.phase != 'test_custom':
args.multiple = 2
args.S_trn = 1
args.S_tst = 1
args.module_scale_factor = 2
args.patch_size = 256
args.batch_size = 16
print('vimeo triplet data dir : ', args.vimeo_data_path)
print("Exp:", args.exp_num)
args.model_dir = args.net_type + '_' + args.dataset + '_exp' + str(
args.exp_num) # ex) model_dir = "XVFInet_X4K1000FPS_exp1"
if args is None:
exit()
for arg in vars(args):
print('# {} : {}'.format(arg, getattr(args, arg)))
device = torch.device(
'cuda:' + str(args.gpu) if torch.cuda.is_available() else 'cpu') # will be used as "x.to(device)"
torch.cuda.set_device(device) # change allocation of current GPU
# caution!!!! if not "torch.cuda.set_device()":
# RuntimeError: grid_sampler(): expected input and grid to be on same device, but input is on cuda:1 and grid is on cuda:0
print('Available devices: ', torch.cuda.device_count())
print('Current cuda device: ', torch.cuda.current_device())
print('Current cuda device name: ', torch.cuda.get_device_name(device))
if args.gpu is not None:
print("Use GPU: {} is used".format(args.gpu))
SM = save_manager(args)
""" Initialize a model """
model_net = args.net_object(args).apply(weights_init).to(device)
criterion = [set_rec_loss(args).to(device), set_smoothness_loss().to(device)]
# to enable the inbuilt cudnn auto-tuner
# to find the best algorithm to use for your hardware.
cudnn.benchmark = True
if args.phase == "train":
train(model_net, criterion, device, SM, args)
epoch = args.epochs - 1
elif args.phase == "test" or args.phase == "metrics_evaluation" or args.phase == 'test_custom':
checkpoint = SM.load_model(os.path.join(wrkdir, args.mdl_dir))
model_net.load_state_dict(checkpoint['state_dict_Model'])
epoch = checkpoint['last_epoch']
postfix = '_final_x' + str(args.multiple) + '_S_tst' + str(args.S_tst)
if args.phase != "metrics_evaluation":
print("\n-------------------------------------- Final Test starts -------------------------------------- ")
print('Evaluate on test set (final test) with multiple = %d ' % (args.multiple))
final_test_loader = get_test_data(args, multiple=args.multiple,
validation=False) # multiple is only used for X4K1000FPS
final_pred_save_path = test(final_test_loader, model_net,
criterion, epoch,
args, device,
multiple=args.multiple,
postfix=postfix, validation=False)
#SM.write_info('Final 4k frames PSNR : {:.4}\n'.format(testPSNR))
if args.dataset == 'X4K1000FPS' and args.phase != 'test_custom':
final_pred_save_path = os.path.join(args.test_img_dir, args.model_dir, 'epoch_' + str(epoch).zfill(5)) + postfix
metrics_evaluation_X_Test(final_pred_save_path, args.test_data_path, args.metrics_types,
flow_flag=args.saving_flow_flag, multiple=args.multiple)
print("------------------------- Test has been ended. -------------------------\n")
quit()
print("Exp:", args.exp_num)
SM = save_manager
multi_scale_recon_loss = criterion[0]
smoothness_loss = criterion[1]
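# Adam with the configured initial LR and weight decay; MultiStepLR multiplies the learning rate
# by lr_dec_fac at every epoch listed in lr_milestones (stepped once per epoch further below).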
optimizer = optim.Adam(model_net.parameters(), lr=args.init_lr, betas=(0.9, 0.999),
weight_decay=args.weight_decay) # optimizer
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_milestones, gamma=args.lr_dec_fac)
last_epoch = 0
best_PSNR = 0.0
if args.continue_training:
checkpoint = SM.load_model()
last_epoch = checkpoint['last_epoch'] + 1
best_PSNR = checkpoint['best_PSNR']
model_net.load_state_dict(checkpoint['state_dict_Model'])
optimizer.load_state_dict(checkpoint['state_dict_Optimizer'])
scheduler.load_state_dict(checkpoint['state_dict_Scheduler'])
print("Optimizer and Scheduler have been reloaded. ")
scheduler.milestones = Counter(args.lr_milestones)
scheduler.gamma = args.lr_dec_fac
print("scheduler.milestones : {}, scheduler.gamma : {}".format(scheduler.milestones, scheduler.gamma))
start_epoch = last_epoch
# switch to train mode
model_net.train()
start_time = time.time()
#SM.write_info('Epoch\ttrainLoss\ttestPSNR\tbest_PSNR\n')
#print("[*] Training starts")
# Main training loop for total epochs (start from 'epoch=0')
valid_loader = get_test_data(args, multiple=4, validation=True) # multiple is only used for X4K1000FPS
for epoch in range(start_epoch, args.epochs):
train_loader = get_train_data(args,
max_t_step_size=32) # max_t_step_size (temporal distance) is only used for X4K1000FPS
batch_time = AverageClass('batch_time[s]:', ':6.3f')
losses = AverageClass('Loss:', ':.4e')
progress = ProgressMeter(len(train_loader), batch_time, losses, prefix="Epoch: [{}]".format(epoch))
print('Start epoch {} at [{:s}], learning rate : [{}]'.format(epoch, (str(datetime.datetime.now())[:-7]),
optimizer.param_groups[0]['lr']))
# train for one epoch
for trainIndex, (frames, t_value) in enumerate(train_loader):
input_frames = frames[:, :, :-1, :] # [B, C, T, H, W]
frameT = frames[:, :, -1, :] # [B, C, H, W]
# Getting the input and the target from the training set
input_frames = Variable(input_frames.to(device))
frameT = Variable(frameT.to(device)) # ground truth for frameT
t_value = Variable(t_value.to(device)) # [B,1]
optimizer.zero_grad()
# compute output
pred_frameT_pyramid, pred_flow_pyramid, occ_map, simple_mean = model_net(input_frames, t_value)
rec_loss = 0.0
smooth_loss = 0.0
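# Multi-scale reconstruction loss: the prediction at pyramid level l is compared against the
# ground-truth frame bicubically downscaled by 2**l; the edge-aware smoothness term uses the
# finest-level flow (pred_flow_pyramid[0]) together with a downscaled ground-truth frame.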
for l, pred_frameT_l in enumerate(pred_frameT_pyramid):
rec_loss += args.rec_lambda * multi_scale_recon_loss(pred_frameT_l,
F.interpolate(frameT, scale_factor=1 / (2 ** l),
mode='bicubic', align_corners=False))
smooth_loss += 0.5 * smoothness_loss(pred_flow_pyramid[0],
F.interpolate(frameT, scale_factor=1 / args.module_scale_factor,
mode='bicubic',
align_corners=False)) # Apply 1st-order edge-aware smoothness loss to the finest level
rec_loss /= len(pred_frameT_pyramid)
pred_frameT = pred_frameT_pyramid[0] # final result I^0_t at original scale (s=0)
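# Upsample the coarsest and finest predicted flows back to the input resolution; these are only
# used for the training visualizations written via get_batch_images() below.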
pred_coarse_flow = 2 ** (args.S_trn) * F.interpolate(pred_flow_pyramid[-1], scale_factor=2 ** (
args.S_trn) * args.module_scale_factor, mode='bicubic', align_corners=False)
pred_fine_flow = F.interpolate(pred_flow_pyramid[0], scale_factor=args.module_scale_factor, mode='bicubic',
align_corners=False)
total_loss = rec_loss + smooth_loss
# compute gradient and do SGD step
total_loss.backward() # Backpropagate
optimizer.step() # Optimizer update
# measure accumulated time and update average "batch" time consumptions via "AverageClass"
# update average values via "AverageClass"
losses.update(total_loss.item(), 1)
batch_time.update(time.time() - start_time)
start_time = time.time()
if trainIndex % args.freq_display == 0:
progress.print(trainIndex)
batch_images = get_batch_images(args, save_img_num=args.save_img_num,
save_images=[pred_frameT, pred_coarse_flow, pred_fine_flow, frameT,
simple_mean, occ_map])
cv2.imwrite(os.path.join(args.log_dir, '{:03d}_{:04d}_training.png'.format(epoch, trainIndex)), batch_images)
if epoch >= args.lr_dec_start:
scheduler.step()
# if (epoch + 1) % 10 == 0 or epoch==0:
val_multiple = 4 if args.dataset == 'X4K1000FPS' else 2
print('\nEvaluate on test set (validation while training) with multiple = {}'.format(val_multiple))
postfix = '_val_' + str(val_multiple) + '_S_tst' + str(args.S_tst)
final_pred_save_path = test(valid_loader, model_net, criterion, epoch, args,
device, multiple=val_multiple, postfix=postfix,
validation=True)
# remember best best_PSNR and best_SSIM and save checkpoint
#print("best_PSNR : {:.3f}, testPSNR : {:.3f}".format(best_PSNR, testPSNR))
best_PSNR_flag = testPSNR > best_PSNR
best_PSNR = max(testPSNR, best_PSNR)
# save checkpoint.
combined_state_dict = {
'net_type': args.net_type,
'last_epoch': epoch,
'batch_size': args.batch_size,
'trainLoss': losses.avg,
'testLoss': testLoss,
'testPSNR': testPSNR,
'best_PSNR': best_PSNR,
'state_dict_Model': model_net.state_dict(),
'state_dict_Optimizer': optimizer.state_dict(),
'state_dict_Scheduler': scheduler.state_dict()}
SM.save_best_model(combined_state_dict, best_PSNR_flag)
if (epoch + 1) % 10 == 0:
SM.save_epc_model(combined_state_dict, epoch)
SM.write_info('{}\t{:.4}\t{:.4}\t{:.4}\n'.format(epoch, losses.avg, testPSNR, best_PSNR))
print("------------------------- Training has been ended. -------------------------\n")
print("information of model:", args.model_dir)
print("best_PSNR of model:", best_PSNR)
def write_src_frame(src_path, target_path, args):
filename, file_ext = os.path.splitext(src_path)
if file_ext == f".{args.img_format}":
shutil.copy(src_path, target_path)
else:
cv2.imwrite(target_path, cv2.imread(src_path))
def test(test_loader, model_net, criterion, epoch, args, device, multiple, postfix, validation):
#os.chdir(interp_output_path)
#batch_time = AverageClass('Time:', ':6.3f')
#losses = AverageClass('testLoss:', ':.4e')
#PSNRs = AverageClass('testPSNR:', ':.4e')
#SSIMs = AverageClass('testSSIM:', ':.4e')
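# args.divide is the spatial multiple the test input must satisfy (scale pyramid depth of 2**S_tst
# combined with the pixel-shuffle factor); inputs are zero-padded below to the next multiple of
# args.divide and the prediction is cropped back to the original size.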
args.divide = 2 ** (args.S_tst) * args.module_scale_factor * 4
# progress = ProgressMeter(len(test_loader), batch_time, accm_time, losses, PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
#progress = ProgressMeter(len(test_loader), PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
#multi_scale_recon_loss = criterion[0]
# switch to evaluate mode
model_net.eval()
counter = 1
copied_src_frames = list()
last_frame = ""
print("------------------------------------------- Test ----------------------------------------------")
with torch.no_grad():
start_time = time.time()
for testIndex, (frames, t_value, scene_name, frameRange) in enumerate(test_loader):
# Shape of 'frames' : [1,C,T+1,H,W]
frameT = frames[:, :, -1, :, :] # [1,C,H,W]
It_Path, I0_Path, I1_Path = frameRange
#print(I0_Path)
#print(I1_Path)
input_filename = str(I0_Path).split("'")[1]
input_filename_next = str(I1_Path).split("'")[1]
last_frame = input_filename_next
frameT = Variable(frameT.to(device)) # ground truth for frameT
t_value = Variable(t_value.to(device))
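# The test loader yields (multiple - 1) intermediate time steps per neighbouring frame pair, so
# the padded input tensor is only rebuilt when a new frame pair starts.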
if (testIndex % (multiple - 1)) == 0:
input_frames = frames[:, :, :-1, :, :] # [1,C,T,H,W]
input_frames = Variable(input_frames.to(device))
B, C, T, H, W = input_frames.size()
H_padding = (args.divide - H % args.divide) % args.divide
W_padding = (args.divide - W % args.divide) % args.divide
if H_padding != 0 or W_padding != 0:
input_frames = F.pad(input_frames, (0, W_padding, 0, H_padding), "constant")
pred_frameT = model_net(input_frames, t_value, is_training=False)
if H_padding != 0 or W_padding != 0:
pred_frameT = pred_frameT[:, :, :H, :W]
epoch_save_path = args.custom_path
scene_save_path = os.path.join(epoch_save_path, scene_name[0])
pred_frameT = np.squeeze(pred_frameT.detach().cpu().numpy())
gt_frameT = np.squeeze(frameT.detach().cpu().numpy()) # ground-truth frame (currently unused)
output_img = np.around(denorm255_np(np.transpose(pred_frameT, [1, 2, 0]))) # [h,w,c] and [-1,1] to [0,255]
#print(os.path.join(scene_save_path, It_Path[0]))
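# Source and interpolated frames are interleaved into a single zero-padded numbered sequence
# ('{counter:08d}.<img_format>') under args.custom_path/args.output; each source frame is copied only once.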
frame_src_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
src_frame_path = os.path.join(args.custom_path, args.input, input_filename)
if os.path.isfile(src_frame_path):
if src_frame_path in copied_src_frames:
#print(f"Not copying source frame '{src_frame_path}' because it has already been copied before! - {len(copied_src_frames)}")
pass
else:
print(f"S => {os.path.basename(src_frame_path)} => {os.path.basename(frame_src_path)}")
write_src_frame(src_frame_path, frame_src_path, args)
copied_src_frames.append(src_frame_path)
counter += 1
frame_interp_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
print(f"I => {os.path.basename(frame_interp_path)}")
cv2.imwrite(frame_interp_path, output_img.astype(np.uint8))
counter += 1
#losses.update(0.0, 1)
#PSNRs.update(0.0, 1)
#SSIMs.update(0.0, 1)
print("-----------------------------------------------------------------------------------------------")
frame_src_path = os.path.join(args.custom_path, args.output, '{:0>8d}.{}'.format(counter, args.img_format))
print(f"LAST S => {frame_src_path}")
src_frame_path = os.path.join(args.custom_path, args.input, last_frame)
write_src_frame(src_frame_path, frame_src_path, args)
return epoch_save_path
if __name__ == '__main__':
main()