mirror of
https://github.com/n00mkrad/flowframes.git
synced 2025-12-16 16:37:48 +01:00
rife-cuda v2: added output arg, updated model to 1.4
This commit is contained in:
@@ -6,7 +6,7 @@ namespace Flowframes.Data
|
||||
{
|
||||
class Networks
|
||||
{
|
||||
public static AI rifeCuda = new AI("RIFE_CUDA", "RIFE", "CUDA/Pytorch Implementation of RIFE", Packages.rifeCuda, 1, false);
|
||||
public static AI rifeCuda = new AI("RIFE_CUDA", "RIFE", "CUDA/Pytorch Implementation of RIFE", Packages.rifeCuda, 2, false);
|
||||
public static AI rifeNcnn = new AI("RIFE_NCNN", "RIFE (NCNN)", "Vulkan/NCNN Implementation of RIFE", Packages.rifeNcnn, 1, true);
|
||||
public static AI dainNcnn = new AI("DAIN_NCNN", "DAIN (NCNN)", "Vulkan/NCNN Implementation of DAIN", Packages.dainNcnn, 0, true);
|
||||
public static AI cainNcnn = new AI("CAIN_NCNN", "CAIN (NCNN)", "Vulkan/NCNN Implementation of CAIN", Packages.cainNcnn, 0, true);
|
||||
|
||||
@@ -9,6 +9,6 @@ namespace Flowframes.Data
|
||||
class Padding
|
||||
{
|
||||
public const int inputFrames = 9;
|
||||
public const int interpFrames = 8;
|
||||
public const int interpFrames = 8; // TODO: Maybe modify NCNN to accept padding as arg?
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,6 +39,7 @@ RIFE_model.device()
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--input', required=True)
|
||||
parser.add_argument('--output', required=False, default='frames-interpolated')
|
||||
parser.add_argument('--times', default=2, type=int)
|
||||
parser.add_argument('--imgformat', default="png")
|
||||
args = parser.parse_args()
|
||||
@@ -49,7 +50,7 @@ path = args.input
|
||||
name = os.path.basename(path)
|
||||
length = len(glob(path + '/*.png'))
|
||||
#interp_output_path = path.replace(name, name+'-interpolated')
|
||||
interp_output_path = (name+'-interpolated').join(path.rsplit(name, 1))
|
||||
interp_output_path = (args.output).join(path.rsplit(name, 1))
|
||||
os.makedirs(interp_output_path, exist_ok = True)
|
||||
#output_path = path.replace('tmp', 'output')
|
||||
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import torch
|
||||
import argparse
|
||||
import numpy as np
|
||||
#from tqdm import tqdm
|
||||
from torch.nn import functional as F
|
||||
import warnings
|
||||
import _thread
|
||||
|
||||
abspath = os.path.abspath(__file__)
|
||||
dname = os.path.dirname(abspath)
|
||||
print("Changing working dir to {0}".format(dname))
|
||||
os.chdir(os.path.dirname(dname))
|
||||
print("Added {0} to PATH".format(dname))
|
||||
sys.path.append(dname)
|
||||
|
||||
warnings.filterwarnings("ignore")
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.set_grad_enabled(False)
|
||||
torch.backends.cudnn.enabled = True
|
||||
torch.backends.cudnn.benchmark = True
|
||||
else:
|
||||
print("WARNING: CUDA is not available, RIFE is running on CPU! [ff:nocuda-cpu]")
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
||||
parser.add_argument('--input', required=True)
|
||||
parser.add_argument('--imgformat', default="png")
|
||||
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
|
||||
parser.add_argument('--png', dest='png', default=True, action='store_true', help='whether to output png format outputs')
|
||||
parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension')
|
||||
parser.add_argument('--times', dest='exp', type=int, default=1, help='interpolation exponent (default: 1)')
|
||||
args = parser.parse_args()
|
||||
#assert (args.exp in [1, 2, 3])
|
||||
#args.times = 2 ** args.exp
|
||||
|
||||
from model.RIFE import Model
|
||||
model = Model()
|
||||
model.load_model(os.path.join(dname, "models"))
|
||||
model.eval()
|
||||
model.device()
|
||||
|
||||
|
||||
videoCapture = cv2.VideoCapture("{}/%08d.png".format(args.input),cv2.CAP_IMAGES)
|
||||
success, frame = videoCapture.read()
|
||||
|
||||
if not (success):
|
||||
print("fuck")
|
||||
|
||||
h, w, _ = frame.shape
|
||||
|
||||
path = args.input
|
||||
name = os.path.basename(path)
|
||||
interp_output_path = (name+'-interpolated').join(path.rsplit(name, 1))
|
||||
|
||||
if not os.path.exists(interp_output_path):
|
||||
os.mkdir(interp_output_path)
|
||||
vid_out = None
|
||||
|
||||
cnt = 0
|
||||
skip_frame = 1
|
||||
buffer = []
|
||||
|
||||
def write_frame(i0, infs, i1, p, user_args):
|
||||
global skip_frame, cnt
|
||||
for i in range(i0.shape[0]):
|
||||
# A video transition occurs.
|
||||
#if p[i] > 0.2:
|
||||
# if user_args.exp > 1:
|
||||
# infs = [i0[i] for _ in range(len(infs) - 1)]
|
||||
# infs[-1] = i1[-1]
|
||||
# else:
|
||||
# infs = [i0[i] for _ in range(len(infs))]
|
||||
|
||||
# Result was too similar to previous frame, skip if given.
|
||||
#if p[i] < 5e-3 and user_args.skip:
|
||||
# if skip_frame % 100 == 0:
|
||||
# print("Warning: Your video has {} static frames, "
|
||||
# "skipping them may change the duration of the generated video.".format(skip_frame))
|
||||
# skip_frame += 1
|
||||
# continue
|
||||
|
||||
# Write results.
|
||||
buffer.append(i0[i])
|
||||
for inf in infs:
|
||||
buffer.append(inf[i])
|
||||
|
||||
def clear_buffer(user_args, buffer):
|
||||
global cnt
|
||||
for i in buffer:
|
||||
print("Writing {}/{:0>7d}.{}".format(interp_output_path, cnt, args.imgformat))
|
||||
cv2.imwrite('{}/{:0>7d}.{}'.format(interp_output_path, cnt, args.imgformat), i)
|
||||
cnt += 1
|
||||
|
||||
def make_inference(model, I0, I1, exp):
|
||||
middle = model.inference(I0, I1)
|
||||
if exp == 1:
|
||||
return [middle]
|
||||
first_half = make_inference(model, I0, middle, exp=exp - 1)
|
||||
second_half = make_inference(model, middle, I1, exp=exp - 1)
|
||||
return [*first_half, middle, *second_half]
|
||||
|
||||
|
||||
ph = ((h - 1) // 32 + 1) * 32
|
||||
pw = ((w - 1) // 32 + 1) * 32
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
tot_frame = videoCapture.get(cv2.CAP_PROP_FRAME_COUNT)
|
||||
print('{} frames in total'.format(tot_frame))
|
||||
#pbar = tqdm(total=tot_frame)
|
||||
img_list = [frame]
|
||||
while success:
|
||||
success, frame = videoCapture.read()
|
||||
if success:
|
||||
img_list.append(frame)
|
||||
if len(img_list) == 5 or (not success and len(img_list) > 1):
|
||||
imgs = torch.from_numpy(np.transpose(img_list, (0, 3, 1, 2))).to(device, non_blocking=True).float() / 255.
|
||||
I0 = imgs[:-1]
|
||||
I1 = imgs[1:]
|
||||
p = (F.interpolate(I0, (16, 16), mode='bilinear', align_corners=False)
|
||||
- F.interpolate(I1, (16, 16), mode='bilinear', align_corners=False)).abs()
|
||||
I0 = F.pad(I0, padding)
|
||||
I1 = F.pad(I1, padding)
|
||||
inferences = make_inference(model, I0, I1, exp=args.exp)
|
||||
|
||||
I0 = np.array(img_list[:-1])
|
||||
I1 = np.array(img_list[1:])
|
||||
inferences = list(map(lambda x: ((x[:, :, :h, :w] * 255.).byte().cpu().detach().numpy().transpose(0, 2, 3, 1)), inferences))
|
||||
|
||||
write_frame(I0, inferences, I1, p.mean(3).mean(2).mean(1), args)
|
||||
#pbar.update(4)
|
||||
img_list = img_list[-1:]
|
||||
if len(buffer) > 100:
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
buffer = []
|
||||
_thread.start_new_thread(clear_buffer, (args, buffer))
|
||||
#pbar.close()
|
||||
@@ -30,8 +30,10 @@ else:
|
||||
|
||||
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
|
||||
parser.add_argument('--input', required=True)
|
||||
parser.add_argument('--output', required=False, default='frames-interpolated')
|
||||
parser.add_argument('--imgformat', default="png")
|
||||
parser.add_argument('--skip', dest='skip', action='store_true', help='whether to remove static frames before processing')
|
||||
#parser.add_argument('--scn', dest='scn', default=False, help='enable scene detection')
|
||||
#parser.add_argument('--fps', dest='fps', type=int, default=None)
|
||||
parser.add_argument('--png', dest='png', default=True, help='whether to output png format outputs')
|
||||
#parser.add_argument('--ext', dest='ext', type=str, default='mp4', help='output video extension')
|
||||
@@ -54,7 +56,9 @@ h, w, _ = frame.shape
|
||||
|
||||
path = args.input
|
||||
name = os.path.basename(path)
|
||||
interp_output_path = (name+'-interpolated').join(path.rsplit(name, 1))
|
||||
print('name: ' + name)
|
||||
interp_output_path = (args.output).join(path.rsplit(name, 1))
|
||||
print('interp_output_path: ' + interp_output_path)
|
||||
|
||||
#if args.fps is None:
|
||||
# args.fps = fps * args.exptimes
|
||||
@@ -77,9 +81,10 @@ def write_frame(i0, infs, i1, p, user_args):
|
||||
l = len(infs)
|
||||
# A video transition occurs.
|
||||
#if p[i] > 0.2:
|
||||
# print('Transition! Duplicting frame instead of interpolating.')
|
||||
# for j in range(len(infs)):
|
||||
# infs[j][i] = i0[i]
|
||||
#
|
||||
|
||||
# Result was too similar to previous frame, skip if given.
|
||||
#if p[i] < 5e-3 and user_args.skip:
|
||||
# if skip_frame % 100 == 0:
|
||||
@@ -100,7 +105,7 @@ def clear_buffer(user_args):
|
||||
if item is None:
|
||||
break
|
||||
if user_args.png:
|
||||
print('Writing {}/{:0>8d}.png'.format(interp_output_path, cnt))
|
||||
print('=> {:0>8d}.png'.format(cnt))
|
||||
cv2.imwrite('{}/{:0>8d}.png'.format(interp_output_path, cnt), item[:, :, ::1])
|
||||
cnt += 1
|
||||
else:
|
||||
@@ -147,6 +152,7 @@ buffer.put(img_list[0])
|
||||
import time
|
||||
while(not buffer.empty()):
|
||||
time.sleep(0.1)
|
||||
time.sleep(0.5)
|
||||
#pbar.close()
|
||||
#if not vid_out is None:
|
||||
# vid_out.release()
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,59 +0,0 @@
|
||||
import sys
|
||||
import cv2
|
||||
import os
|
||||
import numpy as np
|
||||
import shutil
|
||||
import torch
|
||||
import torchvision
|
||||
from torchvision import transforms
|
||||
from torch.nn import functional as F
|
||||
from PIL import Image
|
||||
from model.RIFE import Model
|
||||
from glob import glob
|
||||
from imageio import imread, imsave
|
||||
from torch.autograd import Variable
|
||||
|
||||
RIFE_model = Model()
|
||||
RIFE_model.load_model('./models')
|
||||
RIFE_model.eval()
|
||||
RIFE_model.device()
|
||||
|
||||
#print("Input Path: {0}".format(sys.argv[1]))
|
||||
path = sys.argv[1]
|
||||
|
||||
name = os.path.basename(path)
|
||||
length = len(glob(path + '/*.png'))
|
||||
interp_output_path = path.replace(name, name+'-interpolated')
|
||||
os.makedirs(interp_output_path, exist_ok = True)
|
||||
output_path = path.replace('tmp', 'output')
|
||||
if os.path.isfile(output_path):
|
||||
exit
|
||||
|
||||
with torch.no_grad():
|
||||
if not os.path.isfile('{:s}/00000001.png'.format(interp_output_path)):
|
||||
output_frame_number = 1
|
||||
for input_frame_number in range(1, length):
|
||||
print("Interpolating frame {0} of {1}...".format(input_frame_number, length))
|
||||
frame_0_path = '{:s}/{:08d}.png'.format(path, input_frame_number)
|
||||
frame_1_path = '{:s}/{:08d}.png'.format(path, input_frame_number + 1)
|
||||
frame_0 = cv2.imread(frame_0_path)
|
||||
frame_1 = cv2.imread(frame_1_path)
|
||||
|
||||
h, w, _ = frame_0.shape
|
||||
ph = h // 32 * 32+32
|
||||
pw = w // 32 * 32+32
|
||||
padding = (0, pw - w, 0, ph - h)
|
||||
frame_0 = torch.tensor(frame_0.transpose(2, 0, 1)).cuda() / 255.
|
||||
frame_1 = torch.tensor(frame_1.transpose(2, 0, 1)).cuda() / 255.
|
||||
imgs = F.pad(torch.cat((frame_0, frame_1), 0).float(), padding)
|
||||
res = RIFE_model.inference(imgs.unsqueeze(0)) * 255
|
||||
|
||||
shutil.copyfile(frame_0_path, '{:s}/{:08d}.png'.format(interp_output_path, output_frame_number))
|
||||
output_frame_number += 1
|
||||
cv2.imwrite('{:s}/{:08d}.png'.format(interp_output_path, output_frame_number), res[0].cpu().numpy().transpose(1, 2, 0)[:h, :w])
|
||||
output_frame_number += 1
|
||||
|
||||
if output_frame_number == length*2 - 1:
|
||||
shutil.copyfile(frame_1_path, '{:s}/{:08d}.png'.format(interp_output_path, output_frame_number))
|
||||
output_frame_number += 1
|
||||
print("Done!")
|
||||
@@ -1 +1,2 @@
|
||||
2 # added --output arg
|
||||
1 # initial - increased output zero-padding to 8
|
||||
Reference in New Issue
Block a user