Added support for running different models

This commit is contained in:
N00MKRAD
2021-01-14 00:03:01 +01:00
parent 0fe0b89fa5
commit 9aa0b14f3c
4 changed files with 20 additions and 19 deletions

View File

@@ -55,7 +55,7 @@ namespace Flowframes
await PostProcessFrames();
if (canceled) return;
Program.mainForm.SetStatus("Running AI...");
await RunAi(current.interpFolder, current.ai);
await RunAi(current.interpFolder, current.ai, current.model);
if (canceled) return;
Program.mainForm.SetProgress(100);
if(!currentlyUsingAutoEnc)
@@ -140,7 +140,7 @@ namespace Flowframes
AiProcess.filenameMap = IOUtils.RenameCounterDirReversible(current.framesFolder, "png", 1, 8);
}
public static async Task RunAi(string outpath, AI ai, bool stepByStep = false)
public static async Task RunAi(string outpath, AI ai, string model, bool stepByStep = false)
{
currentlyUsingAutoEnc = Utils.UseAutoEnc(stepByStep, current);
@@ -149,10 +149,10 @@ namespace Flowframes
List<Task> tasks = new List<Task>();
if (ai.aiName == Networks.rifeCuda.aiName)
tasks.Add(AiProcess.RunRifeCuda(current.framesFolder, current.interpFactor));
tasks.Add(AiProcess.RunRifeCuda(current.framesFolder, current.interpFactor, current.model));
if (ai.aiName == Networks.rifeNcnn.aiName)
tasks.Add(AiProcess.RunRifeNcnnMulti(current.framesFolder, outpath, current.interpFactor));
tasks.Add(AiProcess.RunRifeNcnnMulti(current.framesFolder, outpath, current.interpFactor, current.model));
if (currentlyUsingAutoEnc)
{

View File

@@ -119,7 +119,7 @@ namespace Flowframes.Main
int targetFrameCount = frames * current.interpFactor;
if (canceled) return;
Program.mainForm.SetStatus("Running AI...");
await RunAi(current.interpFolder, current.ai, true);
await RunAi(current.interpFolder, current.ai, current.model, true);
Program.mainForm.SetProgress(0);
}

View File

@@ -68,14 +68,14 @@ namespace Flowframes
}
}
public static async Task RunRifeCuda(string framesPath, int interpFactor)
public static async Task RunRifeCuda(string framesPath, int interpFactor, string mdl)
{
InterpolateUtils.GetProgressByFrameAmount(Interpolate.current.interpFolder, Interpolate.current.GetTargetFrameCount(framesPath, interpFactor));
string rifeDir = Path.Combine(Paths.GetPkgPath(), Path.GetFileNameWithoutExtension(Packages.rifeCuda.fileName));
string script = "inference_video.py";
string uhdStr = InterpolateUtils.UseUHD() ? "--UHD" : "";
string args = $" --input {framesPath.Wrap()} --exp {(int)Math.Log(interpFactor, 2)} {uhdStr} --imgformat {InterpolateUtils.GetOutExt()} --output {Paths.interpDir}";
string args = $" --input {framesPath.Wrap()} --model {mdl} --exp {(int)Math.Log(interpFactor, 2)} {uhdStr} --imgformat {InterpolateUtils.GetOutExt()} --output {Paths.interpDir}";
if (!File.Exists(Path.Combine(rifeDir, script)))
{
@@ -104,18 +104,18 @@ namespace Flowframes
AiFinished("RIFE");
}
public static async Task RunRifeNcnnMulti(string framesPath, string outPath, int times)
public static async Task RunRifeNcnnMulti(string framesPath, string outPath, int factor, string mdl)
{
processTimeMulti.Restart();
Logger.Log($"Running RIFE{(InterpolateUtils.UseUHD() ? " (UHD Mode)" : "")}...", false);
bool useAutoEnc = Interpolate.currentlyUsingAutoEnc;
if(times > 2)
if(factor > 2)
AutoEncode.paused = true; // Disable autoenc until the last iteration
await RunRifePartial(framesPath, outPath);
await RunRifePartial(framesPath, outPath, mdl);
if (times == 4 || times == 8) // #2
if (factor == 4 || factor == 8) // #2
{
if (Interpolate.canceled) return;
Logger.Log("Re-Running RIFE for 4x interpolation...", false);
@@ -123,13 +123,13 @@ namespace Flowframes
IOUtils.TryDeleteIfExists(run1ResultsPath);
Directory.Move(outPath, run1ResultsPath);
Directory.CreateDirectory(outPath);
if (useAutoEnc && times == 4)
if (useAutoEnc && factor == 4)
AutoEncode.paused = false;
await RunRifePartial(run1ResultsPath, outPath);
await RunRifePartial(run1ResultsPath, outPath, mdl);
IOUtils.TryDeleteIfExists(run1ResultsPath);
}
if (times == 8) // #3
if (factor == 8) // #3
{
if (Interpolate.canceled) return;
Logger.Log("Re-Running RIFE for 8x interpolation...", false);
@@ -137,9 +137,9 @@ namespace Flowframes
IOUtils.TryDeleteIfExists(run2ResultsPath);
Directory.Move(outPath, run2ResultsPath);
Directory.CreateDirectory(outPath);
if (useAutoEnc && times == 8)
if (useAutoEnc && factor == 8)
AutoEncode.paused = false;
await RunRifePartial(run2ResultsPath, outPath);
await RunRifePartial(run2ResultsPath, outPath, mdl);
IOUtils.TryDeleteIfExists(run2ResultsPath);
}
@@ -151,7 +151,7 @@ namespace Flowframes
AiFinished("RIFE");
}
static async Task RunRifePartial(string inPath, string outPath)
static async Task RunRifePartial(string inPath, string outPath, string mdl)
{
InterpolateUtils.GetProgressByFrameAmount(Interpolate.current.interpFolder, Interpolate.current.GetTargetFrameCount(inPath, 2));
@@ -161,7 +161,7 @@ namespace Flowframes
string uhdStr = InterpolateUtils.UseUHD() ? "-u" : "";
rifeNcnn.StartInfo.Arguments = $"{OSUtils.GetCmdArg()} cd /D {PkgUtils.GetPkgFolder(Packages.rifeNcnn).Wrap()} & rife-ncnn-vulkan.exe " +
$" -v -i {inPath.Wrap()} -o {outPath.Wrap()} -m rife1.7 {uhdStr} -g {Config.Get("ncnnGpus")} -f {InterpolateUtils.GetOutExt()} -j {GetNcnnThreads()}".TrimWhitespacesSafe();
$" -v -i {inPath.Wrap()} -o {outPath.Wrap()} -m {mdl.ToLower()} {uhdStr} -g {Config.Get("ncnnGpus")} -f {InterpolateUtils.GetOutExt()} -j {GetNcnnThreads()}".TrimWhitespacesSafe();
Logger.Log("cmd.exe " + rifeNcnn.StartInfo.Arguments, true);

View File

@@ -37,6 +37,7 @@ except:
parser = argparse.ArgumentParser(description='Interpolation for a pair of images')
parser.add_argument('--input', dest='input', type=str, default=None)
parser.add_argument('--output', required=False, default='frames-interpolated')
parser.add_argument('--model', required=False, default='models')
parser.add_argument('--imgformat', default="png")
parser.add_argument('--UHD', dest='UHD', action='store_true', help='support 4k video')
parser.add_argument('--exp', dest='exp', type=int, default=1)
@@ -45,7 +46,7 @@ assert (not args.input is None)
from model.RIFE_HD import Model
model = Model()
model.load_model(os.path.join(dname, "models"), -1)
model.load_model(os.path.join(dname, args.model), -1)
model.eval()
model.device()