2021-08-15 14:32:22 +02:00
import argparse , os , shutil , time , random , torch , cv2 , datetime , torch . utils . data , math
import torch . backends . cudnn as cudnn
import torch . optim as optim
import numpy as np
2021-08-24 22:02:41 +02:00
import sys
import os
abspath = os . path . abspath ( __file__ )
wrkdir = os . path . dirname ( abspath )
print ( " Changing working dir to {0} " . format ( wrkdir ) )
os . chdir ( os . path . dirname ( wrkdir ) )
print ( " Added {0} to temporary PATH " . format ( wrkdir ) )
sys . path . append ( wrkdir )
2021-08-15 14:32:22 +02:00
from torch . autograd import Variable
from utils import *
from XVFInet import *
from collections import Counter
def parse_args ( ) :
desc = " PyTorch implementation for XVFI "
parser = argparse . ArgumentParser ( description = desc )
parser . add_argument ( ' --gpu ' , type = int , default = 0 , help = ' gpu index ' )
parser . add_argument ( ' --net_type ' , type = str , default = ' XVFInet ' , choices = [ ' XVFInet ' ] , help = ' The type of Net ' )
parser . add_argument ( ' --net_object ' , default = XVFInet , choices = [ XVFInet ] , help = ' The type of Net ' )
parser . add_argument ( ' --exp_num ' , type = int , default = 1 , help = ' The experiment number ' )
parser . add_argument ( ' --phase ' , type = str , default = ' test_custom ' , choices = [ ' train ' , ' test ' , ' test_custom ' , ' metrics_evaluation ' , ] )
parser . add_argument ( ' --continue_training ' , action = ' store_true ' , default = False , help = ' continue the training ' )
""" Information of directories """
parser . add_argument ( ' --test_img_dir ' , type = str , default = ' ./test_img_dir ' , help = ' test_img_dir path ' )
parser . add_argument ( ' --text_dir ' , type = str , default = ' ./text_dir ' , help = ' text_dir path ' )
parser . add_argument ( ' --checkpoint_dir ' , type = str , default = ' ./checkpoint_dir ' , help = ' checkpoint_dir ' )
parser . add_argument ( ' --log_dir ' , type = str , default = ' ./log_dir ' , help = ' Directory name to save training logs ' )
parser . add_argument ( ' --dataset ' , default = ' X4K1000FPS ' , choices = [ ' X4K1000FPS ' , ' Vimeo ' ] ,
help = ' Training/test Dataset ' )
# parser.add_argument('--train_data_path', type=str, default='./X4K1000FPS/train')
# parser.add_argument('--val_data_path', type=str, default='./X4K1000FPS/val')
# parser.add_argument('--test_data_path', type=str, default='./X4K1000FPS/test')
parser . add_argument ( ' --train_data_path ' , type = str , default = ' ../Datasets/VIC_4K_1000FPS/train ' )
parser . add_argument ( ' --val_data_path ' , type = str , default = ' ../Datasets/VIC_4K_1000FPS/val ' )
parser . add_argument ( ' --test_data_path ' , type = str , default = ' ../Datasets/VIC_4K_1000FPS/test ' )
parser . add_argument ( ' --vimeo_data_path ' , type = str , default = ' ./vimeo_triplet ' )
""" Hyperparameters for Training (when [phase== ' train ' ]) """
parser . add_argument ( ' --epochs ' , type = int , default = 200 , help = ' The number of epochs to run ' )
parser . add_argument ( ' --freq_display ' , type = int , default = 100 , help = ' The number of iterations frequency for display ' )
parser . add_argument ( ' --save_img_num ' , type = int , default = 4 ,
help = ' The number of saved image while training for visualization. It should smaller than the batch_size ' )
parser . add_argument ( ' --init_lr ' , type = float , default = 1e-4 , help = ' The initial learning rate ' )
parser . add_argument ( ' --lr_dec_fac ' , type = float , default = 0.25 , help = ' step - lr_decreasing_factor ' )
parser . add_argument ( ' --lr_milestones ' , type = int , default = [ 100 , 150 , 180 ] )
parser . add_argument ( ' --lr_dec_start ' , type = int , default = 0 ,
help = ' When scheduler is StepLR, lr decreases from epoch at lr_dec_start ' )
parser . add_argument ( ' --batch_size ' , type = int , default = 8 , help = ' The size of batch size. ' )
parser . add_argument ( ' --weight_decay ' , type = float , default = 0 , help = ' for optim., weight decay (default: 0) ' )
parser . add_argument ( ' --need_patch ' , default = True , help = ' get patch form image while training ' )
parser . add_argument ( ' --img_ch ' , type = int , default = 3 , help = ' base number of channels for image ' )
parser . add_argument ( ' --nf ' , type = int , default = 64 , help = ' base number of channels for feature maps ' ) # 64
parser . add_argument ( ' --module_scale_factor ' , type = int , default = 4 , help = ' sptial reduction for pixelshuffle ' )
parser . add_argument ( ' --patch_size ' , type = int , default = 384 , help = ' patch size ' )
parser . add_argument ( ' --num_thrds ' , type = int , default = 4 , help = ' number of threads for data loading ' )
parser . add_argument ( ' --loss_type ' , default = ' L1 ' , choices = [ ' L1 ' , ' MSE ' , ' L1_Charbonnier_loss ' ] , help = ' Loss type ' )
parser . add_argument ( ' --S_trn ' , type = int , default = 3 , help = ' The lowest scale depth for training ' )
parser . add_argument ( ' --S_tst ' , type = int , default = 5 , help = ' The lowest scale depth for test ' )
""" Weighting Parameters Lambda for Losses (when [phase== ' train ' ]) """
parser . add_argument ( ' --rec_lambda ' , type = float , default = 1.0 , help = ' Lambda for Reconstruction Loss ' )
""" Settings for Testing (when [phase== ' test ' or ' test_custom ' ]) """
parser . add_argument ( ' --saving_flow_flag ' , default = False )
parser . add_argument ( ' --multiple ' , type = int , default = 8 , help = ' Due to the indexing problem of the file names, we recommend to use the power of 2. (e.g. 2, 4, 8, 16 ...). CAUTION : For the provided X-TEST, multiple should be one of [2, 4, 8, 16, 32]. ' )
parser . add_argument ( ' --metrics_types ' , type = list , default = [ " PSNR " , " SSIM " , " tOF " ] , choices = [ " PSNR " , " SSIM " , " tOF " ] )
""" Settings for test_custom (when [phase== ' test_custom ' ]) """
parser . add_argument ( ' --custom_path ' , type = str , default = ' ./custom_path ' , help = ' path for custom video containing frames ' )
parser . add_argument ( ' --output ' , type = str , default = ' ./interp ' , help = ' output path ' )
parser . add_argument ( ' --input ' , type = str , default = ' ./frames ' , help = ' input path ' )
parser . add_argument ( ' --img_format ' , type = str , default = " png " )
parser . add_argument ( ' --mdl_dir ' , type = str )
return check_args ( parser . parse_args ( ) )
def check_args ( args ) :
# --checkpoint_dir
check_folder ( args . checkpoint_dir )
# --text_dir
check_folder ( args . text_dir )
# --log_dir
check_folder ( args . log_dir )
# --test_img_dir
check_folder ( args . test_img_dir )
return args
def main ( ) :
args = parse_args ( )
if args . dataset == ' Vimeo ' :
if args . phase != ' test_custom ' :
args . multiple = 2
args . S_trn = 1
args . S_tst = 1
args . module_scale_factor = 2
args . patch_size = 256
args . batch_size = 16
print ( ' vimeo triplet data dir : ' , args . vimeo_data_path )
print ( " Exp: " , args . exp_num )
args . model_dir = args . net_type + ' _ ' + args . dataset + ' _exp ' + str (
args . exp_num ) # ex) model_dir = "XVFInet_X4K1000FPS_exp1"
if args is None :
exit ( )
for arg in vars ( args ) :
print ( ' # {} : {} ' . format ( arg , getattr ( args , arg ) ) )
device = torch . device (
' cuda: ' + str ( args . gpu ) if torch . cuda . is_available ( ) else ' cpu ' ) # will be used as "x.to(device)"
torch . cuda . set_device ( device ) # change allocation of current GPU
# caution!!!! if not "torch.cuda.set_device()":
# RuntimeError: grid_sampler(): expected input and grid to be on same device, but input is on cuda:1 and grid is on cuda:0
print ( ' Available devices: ' , torch . cuda . device_count ( ) )
print ( ' Current cuda device: ' , torch . cuda . current_device ( ) )
print ( ' Current cuda device name: ' , torch . cuda . get_device_name ( device ) )
if args . gpu is not None :
print ( " Use GPU: {} is used " . format ( args . gpu ) )
SM = save_manager ( args )
""" Initialize a model """
model_net = args . net_object ( args ) . apply ( weights_init ) . to ( device )
criterion = [ set_rec_loss ( args ) . to ( device ) , set_smoothness_loss ( ) . to ( device ) ]
# to enable the inbuilt cudnn auto-tuner
# to find the best algorithm to use for your hardware.
cudnn . benchmark = True
if args . phase == " train " :
train ( model_net , criterion , device , SM , args )
epoch = args . epochs - 1
elif args . phase == " test " or args . phase == " metrics_evaluation " or args . phase == ' test_custom ' :
2021-08-24 22:02:41 +02:00
checkpoint = SM . load_model ( os . path . join ( wrkdir , args . mdl_dir ) )
2021-08-15 14:32:22 +02:00
model_net . load_state_dict ( checkpoint [ ' state_dict_Model ' ] )
epoch = checkpoint [ ' last_epoch ' ]
postfix = ' _final_x ' + str ( args . multiple ) + ' _S_tst ' + str ( args . S_tst )
if args . phase != " metrics_evaluation " :
print ( " \n -------------------------------------- Final Test starts -------------------------------------- " )
print ( ' Evaluate on test set (final test) with multiple = %d ' % ( args . multiple ) )
final_test_loader = get_test_data ( args , multiple = args . multiple ,
validation = False ) # multiple is only used for X4K1000FPS
final_pred_save_path = test ( final_test_loader , model_net ,
criterion , epoch ,
args , device ,
multiple = args . multiple ,
postfix = postfix , validation = False )
#SM.write_info('Final 4k frames PSNR : {:.4}\n'.format(testPSNR))
if args . dataset == ' X4K1000FPS ' and args . phase != ' test_custom ' :
final_pred_save_path = os . path . join ( args . test_img_dir , args . model_dir , ' epoch_ ' + str ( epoch ) . zfill ( 5 ) ) + postfix
metrics_evaluation_X_Test ( final_pred_save_path , args . test_data_path , args . metrics_types ,
flow_flag = args . saving_flow_flag , multiple = args . multiple )
print ( " ------------------------- Test has been ended. ------------------------- \n " )
quit ( )
print ( " Exp: " , args . exp_num )
SM = save_manager
multi_scale_recon_loss = criterion [ 0 ]
smoothness_loss = criterion [ 1 ]
optimizer = optim . Adam ( model_net . parameters ( ) , lr = args . init_lr , betas = ( 0.9 , 0.999 ) ,
weight_decay = args . weight_decay ) # optimizer
scheduler = optim . lr_scheduler . MultiStepLR ( optimizer , milestones = args . lr_milestones , gamma = args . lr_dec_fac )
last_epoch = 0
best_PSNR = 0.0
if args . continue_training :
checkpoint = SM . load_model ( )
last_epoch = checkpoint [ ' last_epoch ' ] + 1
best_PSNR = checkpoint [ ' best_PSNR ' ]
model_net . load_state_dict ( checkpoint [ ' state_dict_Model ' ] )
optimizer . load_state_dict ( checkpoint [ ' state_dict_Optimizer ' ] )
scheduler . load_state_dict ( checkpoint [ ' state_dict_Scheduler ' ] )
print ( " Optimizer and Scheduler have been reloaded. " )
scheduler . milestones = Counter ( args . lr_milestones )
scheduler . gamma = args . lr_dec_fac
print ( " scheduler.milestones : {} , scheduler.gamma : {} " . format ( scheduler . milestones , scheduler . gamma ) )
start_epoch = last_epoch
# switch to train mode
model_net . train ( )
start_time = time . time ( )
#SM.write_info('Epoch\ttrainLoss\ttestPSNR\tbest_PSNR\n')
#print("[*] Training starts")
# Main training loop for total epochs (start from 'epoch=0')
valid_loader = get_test_data ( args , multiple = 4 , validation = True ) # multiple is only used for X4K1000FPS
for epoch in range ( start_epoch , args . epochs ) :
train_loader = get_train_data ( args ,
max_t_step_size = 32 ) # max_t_step_size (temporal distance) is only used for X4K1000FPS
batch_time = AverageClass ( ' batch_time[s]: ' , ' :6.3f ' )
losses = AverageClass ( ' Loss: ' , ' :.4e ' )
progress = ProgressMeter ( len ( train_loader ) , batch_time , losses , prefix = " Epoch: [ {} ] " . format ( epoch ) )
print ( ' Start epoch {} at [ {:s} ], learning rate : [ {} ] ' . format ( epoch , ( str ( datetime . now ( ) ) [ : - 7 ] ) ,
optimizer . param_groups [ 0 ] [ ' lr ' ] ) )
# train for one epoch
for trainIndex , ( frames , t_value ) in enumerate ( train_loader ) :
input_frames = frames [ : , : , : - 1 , : ] # [B, C, T, H, W]
frameT = frames [ : , : , - 1 , : ] # [B, C, H, W]
# Getting the input and the target from the training set
input_frames = Variable ( input_frames . to ( device ) )
frameT = Variable ( frameT . to ( device ) ) # ground truth for frameT
t_value = Variable ( t_value . to ( device ) ) # [B,1]
optimizer . zero_grad ( )
# compute output
pred_frameT_pyramid , pred_flow_pyramid , occ_map , simple_mean = model_net ( input_frames , t_value )
rec_loss = 0.0
smooth_loss = 0.0
for l , pred_frameT_l in enumerate ( pred_frameT_pyramid ) :
rec_loss + = args . rec_lambda * multi_scale_recon_loss ( pred_frameT_l ,
F . interpolate ( frameT , scale_factor = 1 / ( 2 * * l ) ,
mode = ' bicubic ' , align_corners = False ) )
smooth_loss + = 0.5 * smoothness_loss ( pred_flow_pyramid [ 0 ] ,
F . interpolate ( frameT , scale_factor = 1 / args . module_scale_factor ,
mode = ' bicubic ' ,
align_corners = False ) ) # Apply 1st order edge-aware smoothness loss to the fineset level
rec_loss / = len ( pred_frameT_pyramid )
pred_frameT = pred_frameT_pyramid [ 0 ] # final result I^0_t at original scale (s=0)
pred_coarse_flow = 2 * * ( args . S_trn ) * F . interpolate ( pred_flow_pyramid [ - 1 ] , scale_factor = 2 * * (
args . S_trn ) * args . module_scale_factor , mode = ' bicubic ' , align_corners = False )
pred_fine_flow = F . interpolate ( pred_flow_pyramid [ 0 ] , scale_factor = args . module_scale_factor , mode = ' bicubic ' ,
align_corners = False )
total_loss = rec_loss + smooth_loss
# compute gradient and do SGD step
total_loss . backward ( ) # Backpropagate
optimizer . step ( ) # Optimizer update
# measure accumulated time and update average "batch" time consumptions via "AverageClass"
# update average values via "AverageClass"
losses . update ( total_loss . item ( ) , 1 )
batch_time . update ( time . time ( ) - start_time )
start_time = time . time ( )
if trainIndex % args . freq_display == 0 :
progress . print ( trainIndex )
batch_images = get_batch_images ( args , save_img_num = args . save_img_num ,
save_images = [ pred_frameT , pred_coarse_flow , pred_fine_flow , frameT ,
simple_mean , occ_map ] )
cv2 . imwrite ( os . path . join ( args . log_dir , ' {:03d} _ {:04d} _training.png ' . format ( epoch , trainIndex ) ) , batch_images )
if epoch > = args . lr_dec_start :
scheduler . step ( )
# if (epoch + 1) % 10 == 0 or epoch==0:
val_multiple = 4 if args . dataset == ' X4K1000FPS ' else 2
print ( ' \n Evaluate on test set (validation while training) with multiple = {} ' . format ( val_multiple ) )
postfix = ' _val_ ' + str ( val_multiple ) + ' _S_tst ' + str ( args . S_tst )
final_pred_save_path = test ( valid_loader , model_net , criterion , epoch , args ,
device , multiple = val_multiple , postfix = postfix ,
validation = True )
# remember best best_PSNR and best_SSIM and save checkpoint
#print("best_PSNR : {:.3f}, testPSNR : {:.3f}".format(best_PSNR, testPSNR))
best_PSNR_flag = testPSNR > best_PSNR
best_PSNR = max ( testPSNR , best_PSNR )
# save checkpoint.
combined_state_dict = {
' net_type ' : args . net_type ,
' last_epoch ' : epoch ,
' batch_size ' : args . batch_size ,
' trainLoss ' : losses . avg ,
' testLoss ' : testLoss ,
' testPSNR ' : testPSNR ,
' best_PSNR ' : best_PSNR ,
' state_dict_Model ' : model_net . state_dict ( ) ,
' state_dict_Optimizer ' : optimizer . state_dict ( ) ,
' state_dict_Scheduler ' : scheduler . state_dict ( ) }
SM . save_best_model ( combined_state_dict , best_PSNR_flag )
if ( epoch + 1 ) % 10 == 0 :
SM . save_epc_model ( combined_state_dict , epoch )
SM . write_info ( ' {} \t {:.4} \t {:.4} \t {:.4} \n ' . format ( epoch , losses . avg , testPSNR , best_PSNR ) )
print ( " ------------------------- Training has been ended. ------------------------- \n " )
print ( " information of model: " , args . model_dir )
print ( " best_PSNR of model: " , best_PSNR )
2021-08-24 22:02:41 +02:00
def write_src_frame ( src_path , target_path , args ) :
filename , file_ext = os . path . splitext ( src_path )
if file_ext == f " . { args . img_format } " :
shutil . copy ( src_path , target_path )
else :
cv2 . imwrite ( target_path , cv2 . imread ( src_path ) )
2021-08-15 14:32:22 +02:00
def test ( test_loader , model_net , criterion , epoch , args , device , multiple , postfix , validation ) :
#os.chdir(interp_output_path)
2021-08-24 22:02:41 +02:00
#batch_time = AverageClass('Time:', ':6.3f')
#losses = AverageClass('testLoss:', ':.4e')
#PSNRs = AverageClass('testPSNR:', ':.4e')
#SSIMs = AverageClass('testSSIM:', ':.4e')
2021-08-15 14:32:22 +02:00
args . divide = 2 * * ( args . S_tst ) * args . module_scale_factor * 4
# progress = ProgressMeter(len(test_loader), batch_time, accm_time, losses, PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
2021-08-24 22:02:41 +02:00
#progress = ProgressMeter(len(test_loader), PSNRs, SSIMs, prefix='Test after Epoch[{}]: '.format(epoch))
2021-08-15 14:32:22 +02:00
2021-08-24 22:02:41 +02:00
#multi_scale_recon_loss = criterion[0]
2021-08-15 14:32:22 +02:00
# switch to evaluate mode
model_net . eval ( )
counter = 1
copied_src_frames = list ( )
last_frame = " "
print ( " ------------------------------------------- Test ---------------------------------------------- " )
with torch . no_grad ( ) :
start_time = time . time ( )
for testIndex , ( frames , t_value , scene_name , frameRange ) in enumerate ( test_loader ) :
# Shape of 'frames' : [1,C,T+1,H,W]
frameT = frames [ : , : , - 1 , : , : ] # [1,C,H,W]
It_Path , I0_Path , I1_Path = frameRange
#print(I0_Path)
#print(I1_Path)
input_filename = str ( I0_Path ) . split ( " ' " ) [ 1 ] ;
input_filename_next = str ( I1_Path ) . split ( " ' " ) [ 1 ] ;
last_frame = input_filename_next
frameT = Variable ( frameT . to ( device ) ) # ground truth for frameT
t_value = Variable ( t_value . to ( device ) )
if ( testIndex % ( multiple - 1 ) ) == 0 :
input_frames = frames [ : , : , : - 1 , : , : ] # [1,C,T,H,W]
input_frames = Variable ( input_frames . to ( device ) )
B , C , T , H , W = input_frames . size ( )
H_padding = ( args . divide - H % args . divide ) % args . divide
W_padding = ( args . divide - W % args . divide ) % args . divide
if H_padding != 0 or W_padding != 0 :
input_frames = F . pad ( input_frames , ( 0 , W_padding , 0 , H_padding ) , " constant " )
pred_frameT = model_net ( input_frames , t_value , is_training = False )
if H_padding != 0 or W_padding != 0 :
pred_frameT = pred_frameT [ : , : , : H , : W ]
epoch_save_path = args . custom_path
scene_save_path = os . path . join ( epoch_save_path , scene_name [ 0 ] )
pred_frameT = np . squeeze ( pred_frameT . detach ( ) . cpu ( ) . numpy ( ) )
test = np . squeeze ( frameT . detach ( ) . cpu ( ) . numpy ( ) )
output_img = np . around ( denorm255_np ( np . transpose ( pred_frameT , [ 1 , 2 , 0 ] ) ) ) # [h,w,c] and [-1,1] to [0,255]
#print(os.path.join(scene_save_path, It_Path[0]))
frame_src_path = os . path . join ( args . custom_path , args . output , ' {:0>8d} . {} ' . format ( counter , args . img_format ) )
src_frame_path = os . path . join ( args . custom_path , args . input , input_filename )
if os . path . isfile ( src_frame_path ) :
if src_frame_path in copied_src_frames :
#print(f"Not copying source frame '{src_frame_path}' because it has already been copied before! - {len(copied_src_frames)}")
pass
else :
print ( f " S => { os . path . basename ( src_frame_path ) } => { os . path . basename ( frame_src_path ) } " )
2021-08-24 22:02:41 +02:00
write_src_frame ( src_frame_path , frame_src_path , args )
2021-08-15 14:32:22 +02:00
copied_src_frames . append ( src_frame_path )
counter + = 1
frame_interp_path = os . path . join ( args . custom_path , args . output , ' {:0>8d} . {} ' . format ( counter , args . img_format ) )
print ( f " I => { os . path . basename ( frame_interp_path ) } " )
cv2 . imwrite ( frame_interp_path , output_img . astype ( np . uint8 ) )
counter + = 1
#losses.update(0.0, 1)
#PSNRs.update(0.0, 1)
#SSIMs.update(0.0, 1)
print ( " ----------------------------------------------------------------------------------------------- " )
frame_src_path = os . path . join ( args . custom_path , args . output , ' {:0>8d} . {} ' . format ( counter , args . img_format ) )
print ( f " LAST S => { frame_src_path } " )
src_frame_path = os . path . join ( args . custom_path , args . input , last_frame )
2021-08-24 22:02:41 +02:00
write_src_frame ( src_frame_path , frame_src_path , args )
2021-08-15 14:32:22 +02:00
return epoch_save_path
if __name__ == ' __main__ ' :
main ( )