rife-cuda v2: added output arg, updated model to 1.4

This commit is contained in:
N00MKRAD
2020-12-03 00:00:31 +01:00
parent 7219a3e5ad
commit b87a1724d6
10 changed files with 14 additions and 206 deletions

@@ -6,7 +6,7 @@ namespace Flowframes.Data
{
    class Networks
    {
        public static AI rifeCuda = new AI("RIFE_CUDA", "RIFE", "CUDA/Pytorch Implementation of RIFE", Packages.rifeCuda, 1, false);
        public static AI rifeCuda = new AI("RIFE_CUDA", "RIFE", "CUDA/Pytorch Implementation of RIFE", Packages.rifeCuda, 2, false);
        public static AI rifeNcnn = new AI("RIFE_NCNN", "RIFE (NCNN)", "Vulkan/NCNN Implementation of RIFE", Packages.rifeNcnn, 1, true);
        public static AI dainNcnn = new AI("DAIN_NCNN", "DAIN (NCNN)", "Vulkan/NCNN Implementation of DAIN", Packages.dainNcnn, 0, true);
        public static AI cainNcnn = new AI("CAIN_NCNN", "CAIN (NCNN)", "Vulkan/NCNN Implementation of CAIN", Packages.cainNcnn, 0, true);

@@ -9,6 +9,6 @@ namespace Flowframes.Data
    class Padding
    {
        public const int inputFrames = 9;
        public const int interpFrames = 8;
        public const int interpFrames = 8; // TODO: Maybe modify NCNN to accept padding as arg?
    }
}

@@ -39,6 +39,7 @@ RIFE_model.device()
parser = argparse.ArgumentParser()
parser.add_argument('--input', required=True)
parser.add_argument('--output', required=False, default='frames-interpolated')
parser.add_argument('--times', default=2, type=int)
parser.add_argument('--imgformat', default="png")
args = parser.parse_args()
@@ -49,7 +50,7 @@ path = args.input
name = os.path.basename(path)
length = len(glob(path + '/*.png'))
#interp_output_path = path.replace(name, name+'-interpolated')
interp_output_path = (name+'-interpolated').join(path.rsplit(name, 1))
interp_output_path = (args.output).join(path.rsplit(name, 1))
os.makedirs(interp_output_path, exist_ok = True)
#output_path = path.replace('tmp', 'output')
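
The new output handling above only swaps the last occurrence of the input folder name for the value of --output. A minimal standalone sketch of that rsplit/join substitution, using a made-up path that is not part of this commit:

import os

path = "C:/tmp/session1/frames"    # stand-in for args.input
name = os.path.basename(path)      # "frames"
output = "frames-interpolated"     # stand-in for args.output (the default)

# rsplit(name, 1) splits off only the last occurrence of the folder name,
# and join() stitches the pieces back together around the output name.
interp_output_path = output.join(path.rsplit(name, 1))
print(interp_output_path)          # C:/tmp/session1/frames-interpolated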

@@ -1,141 +0,0 @@
import os
import sys
import cv2
import torch
import argparse
import numpy as np
#from tqdm import tqdm
from torch.nn import functional as F
import warnings
import _thread
abspath = os.path.abspath(__file__)
dname = os.path.dirname(abspath)
print("Changing working dir to {0}".format(dname))
os.chdir(os.path.dirname(dname))
print("Added {0} to PATH".format(dname))
sys.path.append(dname)
warnings.filterwarnings("ignore")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.set_grad_enabled(False)
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True
else:
    print("WARNING: CUDA is not available, RIFE is running on CPU! [ff:nocuda-cpu]")
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--input', required=True)
parser.add_argument('--imgformat', default="png")
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
parser.add_argument('--png', dest='png', default=True, action='store_true', help='whether to output png format outputs')
parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension')
parser.add_argument('--times', dest='exp', type=int, default=1, help='interpolation exponent (default: 1)')
args = parser.parse_args()
#assert (args.exp in [1, 2, 3])
#args.times = 2 ** args.exp
from model.RIFE import Model
model = Model()
model.load_model(os.path.join(dname, "models"))
model.eval()
model.device()
videoCapture = cv2.VideoCapture("{}/%08d.png".format(args.input),cv2.CAP_IMAGES)
success, frame = videoCapture.read()
if not (success):
    print("fuck")
h, w, _ = frame.shape
path = args.input
name = os.path.basename(path)
interp_output_path = (name+'-interpolated').join(path.rsplit(name, 1))
if not os.path.exists(interp_output_path):
    os.mkdir(interp_output_path)
vid_out = None
cnt = 0
skip_frame = 1
buffer = []
def write_frame(i0, infs, i1, p, user_args):
    global skip_frame, cnt
    for i in range(i0.shape[0]):
        # A video transition occurs.
        #if p[i] > 0.2:
        #    if user_args.exp > 1:
        #        infs = [i0[i] for _ in range(len(infs) - 1)]
        #        infs[-1] = i1[-1]
        #    else:
        #        infs = [i0[i] for _ in range(len(infs))]
        # Result was too similar to previous frame, skip if given.
        #if p[i] < 5e-3 and user_args.skip:
        #    if skip_frame % 100 == 0:
        #        print("Warning: Your video has {} static frames, "
        #              "skipping them may change the duration of the generated video.".format(skip_frame))
        #    skip_frame += 1
        #    continue
        # Write results.
        buffer.append(i0[i])
        for inf in infs:
            buffer.append(inf[i])
def clear_buffer(user_args, buffer):
    global cnt
    for i in buffer:
        print("Writing {}/{:0>7d}.{}".format(interp_output_path, cnt, args.imgformat))
        cv2.imwrite('{}/{:0>7d}.{}'.format(interp_output_path, cnt, args.imgformat), i)
        cnt += 1
def make_inference(model, I0, I1, exp):
    middle = model.inference(I0, I1)
    if exp == 1:
        return [middle]
    first_half = make_inference(model, I0, middle, exp=exp - 1)
    second_half = make_inference(model, middle, I1, exp=exp - 1)
    return [*first_half, middle, *second_half]
ph = ((h - 1) // 32 + 1) * 32
pw = ((w - 1) // 32 + 1) * 32
padding = (0, pw - w, 0, ph - h)
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
print('{} frames in total'.format(tot_frame))
#pbar = tqdm(total=tot_frame)
img_list = [frame]
while success:
    success, frame = videoCapture.read()
    if success:
        img_list.append(frame)
    if len(img_list) == 5 or (not success and len(img_list) > 1):
        imgs = torch.from_numpy(np.transpose(img_list, (0, 3, 1, 2))).to(device, non_blocking=True).float() / 255.
        I0 = imgs[:-1]
        I1 = imgs[1:]
        p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False)
             - F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs()
        I0 = F.pad(I0, padding)
        I1 = F.pad(I1, padding)
        inferences = make_inference(model, I0, I1, exp=args.exp)
        I0 = np.array(img_list[:-1])
        I1 = np.array(img_list[1:])
        inferences = list(map(lambda x: ((x[:, :, :h, :w] * 255.).byte().cpu().detach().numpy().transpose(0, 2, 3, 1)), inferences))
        write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
        #pbar.update(4)
        img_list = img_list[-1:]
        if len(buffer) > 100:
            _thread.start_new_thread(clear_buffer, (args, buffer))
            buffer = []
_thread.start_new_thread(clear_buffer, (args, buffer))
#pbar.close()
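
In the script deleted above, make_inference() recursively interpolates midpoints, so an exponent of n yields 2**n - 1 in-between frames per input pair. A small self-contained sketch of that recursion with strings standing in for frame tensors (illustrative only, not repository code):

# Fake "model" that just labels the midpoint of its two inputs.
def fake_inference(a, b):
    return "mid({},{})".format(a, b)

def make_inference_sketch(i0, i1, exp):
    middle = fake_inference(i0, i1)
    if exp == 1:
        return [middle]
    first_half = make_inference_sketch(i0, middle, exp=exp - 1)
    second_half = make_inference_sketch(middle, i1, exp=exp - 1)
    return [*first_half, middle, *second_half]

print(len(make_inference_sketch("A", "B", exp=1)))  # 1 intermediate frame
print(len(make_inference_sketch("A", "B", exp=2)))  # 3 intermediate frames
print(len(make_inference_sketch("A", "B", exp=3)))  # 7 intermediate frames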

@@ -30,8 +30,10 @@ else:
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--input', required=True)
parser.add_argument('--output', required=False, default='frames-interpolated')
parser.add_argument('--imgformat', default="png")
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
#parser.add_argument('--scn', dest='scn', default=False, help='enable scene detection')
#parser.add_argument('--fps', dest='fps', type=int, default=None)
parser.add_argument('--png', dest='png', default=True, help='whether to output png format outputs')
#parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension')
@@ -54,7 +56,9 @@ h, w, _ = frame.shape
path = args.input
name = os.path.basename(path)
interp_output_path = (name+'-interpolated').join(path.rsplit(name, 1))
print('name: ' + name)
interp_output_path = (args.output).join(path.rsplit(name, 1))
print('interp_output_path: ' + interp_output_path)
#if args.fps is None:
# args.fps = fps * args.exptimes
@@ -77,9 +81,10 @@ def write_frame(i0, infs, i1, p, user_args):
        l = len(infs)
        # A video transition occurs.
        #if p[i] > 0.2:
        #    print('Transition! Duplicating frame instead of interpolating.')
        #    for j in range(len(infs)):
        #        infs[j][i] = i0[i]
        #
        # Result was too similar to previous frame, skip if given.
        #if p[i] < 5e-3 and user_args.skip:
        #    if skip_frame % 100 == 0:
@@ -100,7 +105,7 @@ def clear_buffer(user_args):
        if item is None:
            break
        if user_args.png:
            print('Writing {}/{:0>8d}.png'.format(interp_output_path, cnt))
            print('=> {:0>8d}.png'.format(cnt))
            cv2.imwrite('{}/{:0>8d}.png'.format(interp_output_path, cnt), item[:, :, ::1])
            cnt += 1
        else:
@@ -147,6 +152,7 @@ buffer.put(img_list[0])
import time
while(not buffer.empty()):
    time.sleep(0.1)
time.sleep(0.5)
#pbar.close()
#if not vid_out is None:
#    vid_out.release()
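
The added time.sleep(0.5) gives the writer thread a grace period: Queue.empty() becomes True as soon as the last frame is taken off the queue, not when it has actually been written to disk. A rough sketch of that shutdown pattern with hypothetical names (not taken from the repository):

import threading
import time
from queue import Queue

buffer = Queue(maxsize=100)

def writer():
    while True:
        item = buffer.get()
        if item is None:
            break
        time.sleep(0.01)  # stand-in for cv2.imwrite()

threading.Thread(target=writer, daemon=True).start()

for i in range(50):
    buffer.put(i)
buffer.put(None)  # sentinel telling the writer to stop

while not buffer.empty():
    time.sleep(0.1)
time.sleep(0.5)  # grace period for the frame currently being written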

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -1,59 +0,0 @@
import sys
import cv2
import os
import numpy as np
import shutil
import torch
import torchvision
from torchvision import transforms
from torch.nn import functional as F
from PIL import Image
from model.RIFE import Model
from glob import glob
from imageio import imread, imsave
from torch.autograd import Variable
RIFE_model = Model()
RIFE_model.load_model('./models')
RIFE_model.eval()
RIFE_model.device()
#print("Input Path: {0}".format(sys.argv[1]))
path = sys.argv[1]
name = os.path.basename(path)
length = len(glob(path + '/*.png'))
interp_output_path = path.replace(name, name+'-interpolated')
os.makedirs(interp_output_path, exist_ok = True)
output_path = path.replace('tmp', 'output')
if os.path.isfile(output_path):
    exit
with torch.no_grad():
    if not os.path.isfile('{:s}/00000001.png'.format(interp_output_path)):
        output_frame_number = 1
        for input_frame_number in range(1, length):
            print("Interpolating frame {0} of {1}...".format(input_frame_number, length))
            frame_0_path = '{:s}/{:08d}.png'.format(path, input_frame_number)
            frame_1_path = '{:s}/{:08d}.png'.format(path, input_frame_number + 1)
            frame_0 = cv2.imread(frame_0_path)
            frame_1 = cv2.imread(frame_1_path)
            h, w, _ = frame_0.shape
            ph = h // 32 * 32+32
            pw = w // 32 * 32+32
            padding = (0, pw - w, 0, ph - h)
            frame_0 = torch.tensor(frame_0.transpose(2, 0, 1)).cuda() / 255.
            frame_1 = torch.tensor(frame_1.transpose(2, 0, 1)).cuda() / 255.
            imgs = F.pad(torch.cat((frame_0, frame_1), 0).float(), padding)
            res = RIFE_model.inference(imgs.unsqueeze(0)) * 255
            shutil.copyfile(frame_0_path, '{:s}/{:08d}.png'.format(interp_output_path, output_frame_number))
            output_frame_number += 1
            cv2.imwrite('{:s}/{:08d}.png'.format(interp_output_path, output_frame_number), res[0].cpu().numpy().transpose(1, 2, 0)[:h, :w])
            output_frame_number += 1
            if output_frame_number == length*2 - 1:
                shutil.copyfile(frame_1_path, '{:s}/{:08d}.png'.format(interp_output_path, output_frame_number))
                output_frame_number += 1
print("Done!")

@@ -1 +1,2 @@
2 # added --output arg
1 # initial - increased output zero-padding to 8