AI Class/Constructor revamp

This commit is contained in:
n00mkrad
2022-07-21 10:08:53 +02:00
parent 9be08deb4c
commit 760a187dfe
17 changed files with 68 additions and 54 deletions

View File

@@ -11,11 +11,11 @@ namespace Flowframes.Data
{
public enum AiBackend { Pytorch, Ncnn, Tensorflow, Other }
public AiBackend Backend { get; set; } = AiBackend.Pytorch;
public string AiName { get; set; } = "";
public string AiNameShort { get; set; } = "";
public string FriendlyName { get; set; } = "";
public string Description { get; set; } = "";
public string PkgDir { get; set; } = "";
public string NameInternal { get; set; } = "";
public string NameShort { get { return NameInternal.Split(' ')[0].Split('_')[0]; } }
public string FriendlyName { get { return $"{NameShort} ({GetFrameworkString()})"; } }
public string Description { get { return $"{GetImplemString()} of {NameShort}{(Backend == AiBackend.Pytorch ? " (Nvidia Only!)" : "")}"; } }
public string PkgDir { get { return NameInternal.Replace("_", "-").ToLower(); } }
public enum InterpFactorSupport { Fixed, AnyPowerOfTwo, AnyInteger, AnyFloat }
public InterpFactorSupport FactorSupport { get; set; } = InterpFactorSupport.Fixed;
public int[] SupportedFactors { get; set; } = new int[0];
@@ -23,19 +23,40 @@ namespace Flowframes.Data
public string LogFilename { get { return PkgDir + "-log"; } }
public AI(AiBackend backend, string aiName, string friendlyName, string desc, string pkgDir, InterpFactorSupport factorSupport = InterpFactorSupport.Fixed, int[] supportedFactors = null)
public AI(AiBackend backend, string aiName, InterpFactorSupport factorSupport = InterpFactorSupport.Fixed, int[] supportedFactors = null)
{
Backend = backend;
AiName = aiName;
AiNameShort = aiName.Split(' ')[0].Split('_')[0];
FriendlyName = friendlyName;
Description = desc;
PkgDir = pkgDir;
NameInternal = aiName;
SupportedFactors = supportedFactors;
FactorSupport = factorSupport;
}
if (backend == AiBackend.Pytorch)
Description += " (Nvidia Only!)";
private string GetImplemString ()
{
if (Backend == AiBackend.Pytorch)
return $"CUDA/Pytorch Implementation";
if(Backend == AiBackend.Ncnn)
return $"Vulkan/NCNN{(Piped ? "/VapourSynth" : "")} Implementation";
if (Backend == AiBackend.Tensorflow)
return $"Tensorflow Implementation";
return "";
}
private string GetFrameworkString()
{
if (Backend == AiBackend.Pytorch)
return $"CUDA";
if (Backend == AiBackend.Ncnn)
return $"NCNN{(Piped ? "/VS" : "")}";
if (Backend == AiBackend.Tensorflow)
return $"TF";
return "Custom";
}
}
}

View File

@@ -6,27 +6,20 @@ namespace Flowframes.Data
{
class Implementations
{
public static AI rifeCuda = new AI(AI.AiBackend.Pytorch, "RIFE_CUDA", "RIFE",
"CUDA/Pytorch Implementation of RIFE", "rife-cuda", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 });
public static AI rifeCuda = new AI(AI.AiBackend.Pytorch, "RIFE_CUDA", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 });
public static AI rifeNcnnVs = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN_VS", "RIFE (NCNN/VS)",
"Vulkan/NCNN/VapourSynth Implementation of RIFE", "rife-ncnn-vs", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 })
public static AI rifeNcnnVs = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN_VS", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 })
{ Piped = true };
public static AI rifeNcnn = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN", "RIFE (NCNN)",
"Vulkan/NCNN Implementation of RIFE", "rife-ncnn", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 });
public static AI rifeNcnn = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 });
public static AI flavrCuda = new AI(AI.AiBackend.Pytorch, "FLAVR_CUDA", "FLAVR",
"Experimental Pytorch Implementation of FLAVR", "flavr-cuda", AI.InterpFactorSupport.Fixed, new int[] { 2, 4, 8 });
public static AI flavrCuda = new AI(AI.AiBackend.Pytorch, "FLAVR_CUDA", AI.InterpFactorSupport.Fixed, new int[] { 2, 4, 8 });
public static AI dainNcnn = new AI(AI.AiBackend.Ncnn, "DAIN_NCNN", "DAIN (NCNN)",
"Vulkan/NCNN Implementation of DAIN", "dain-ncnn", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8 });
public static AI dainNcnn = new AI(AI.AiBackend.Ncnn, "DAIN_NCNN", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8 });
public static AI xvfiCuda = new AI(AI.AiBackend.Pytorch, "XVFI_CUDA", "XVFI",
"CUDA/Pytorch Implementation of XVFI", "xvfi-cuda", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 });
public static AI xvfiCuda = new AI(AI.AiBackend.Pytorch, "XVFI_CUDA", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 });
public static AI ifrnetNcnn = new AI(AI.AiBackend.Ncnn, "IFRNet_NCNN", "IFRNet (NCNN)",
"Vulkan/NCNN Implementation of IFRNet", "ifrnet-ncnn", AI.InterpFactorSupport.Fixed, new int[] { 2 });
public static AI ifrnetNcnn = new AI(AI.AiBackend.Ncnn, "IFRNet_NCNN", AI.InterpFactorSupport.Fixed, new int[] { 2 });
public static List<AI> NetworksAll
{
@@ -53,7 +46,7 @@ namespace Flowframes.Data
{
foreach (AI ai in NetworksAll)
{
if (ai.AiName == aiName)
if (ai.NameInternal == aiName)
return ai;
}

View File

@@ -234,7 +234,7 @@ namespace Flowframes
{
string s = $"INPATH|{inPath}\n";
s += $"OUTPATH|{outPath}\n";
s += $"AI|{ai.AiName}\n";
s += $"AI|{ai.NameInternal}\n";
s += $"INFPSDETECTED|{inFpsDetected}\n";
s += $"INFPS|{inFps}\n";
s += $"OUTFPS|{outFps}\n";

View File

@@ -192,7 +192,7 @@ namespace Flowframes
inputTbox.Text = entry.inPath;
MainUiFunctions.SetOutPath(outputTbox, entry.outPath);
interpFactorCombox.Text = entry.interpFactor.ToString();
aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.AiName == entry.ai.AiName).FirstOrDefault());
aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.NameInternal == entry.ai.NameInternal).FirstOrDefault());
SetOutMode(entry.outMode);
}
@@ -276,9 +276,9 @@ namespace Flowframes
aiCombox.Items.Add(GetAiComboboxName(ai));
string lastUsedAiName = Config.Get(Config.Key.lastUsedAiName);
aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.AiName == lastUsedAiName).FirstOrDefault());
aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.NameInternal == lastUsedAiName).FirstOrDefault());
if (aiCombox.SelectedIndex < 0) aiCombox.SelectedIndex = 0;
Config.Set(Config.Key.lastUsedAiName, GetAi().AiName);
Config.Set(Config.Key.lastUsedAiName, GetAi().NameInternal);
ConfigParser.LoadComboxIndex(outModeCombox);
}
@@ -475,7 +475,7 @@ namespace Flowframes
interpFactorCombox.SelectedIndex = 0;
if (initialized)
Config.Set(Config.Key.lastUsedAiName, GetAi().AiName);
Config.Set(Config.Key.lastUsedAiName, GetAi().NameInternal);
interpFactorCombox_SelectedIndexChanged(null, null);
fpsOutTbox.ReadOnly = GetAi().FactorSupport != AI.InterpFactorSupport.AnyFloat;

View File

@@ -34,7 +34,7 @@ namespace Flowframes.Forms
InterpSettings entry = Program.batchQueue.ElementAt(i);
string niceOutMode = entry.outMode.ToString().ToUpper().Remove("VID").Remove("IMG");
string str = $"#{i+1}: {Path.GetFileName(entry.inPath).Trunc(40)} - {entry.inFps.GetFloat()} FPS => " +
$"{entry.interpFactor}x {entry.ai.AiNameShort} ({entry.model.name}) => {niceOutMode}";
$"{entry.interpFactor}x {entry.ai.NameShort} ({entry.model.name}) => {niceOutMode}";
taskList.Items.Add(str);
}
}

View File

@@ -578,7 +578,7 @@ namespace Flowframes.IO
filename = filename.Replace("[NAME]", inName);
filename = filename.Replace("[FULLNAME]", Path.GetFileName(curr.inPath));
filename = filename.Replace("[FACTOR]", curr.interpFactor.ToStringDot());
filename = filename.Replace("[AI]", curr.ai.AiNameShort.ToUpper());
filename = filename.Replace("[AI]", curr.ai.NameShort.ToUpper());
filename = filename.Replace("[MODEL]", curr.model.name.Remove(" "));
filename = filename.Replace("[FPS]", fps.ToStringDot());
filename = filename.Replace("[ROUNDFPS]", fps.RoundToInt().ToString());

View File

@@ -135,7 +135,7 @@ namespace Flowframes.IO
public static async Task DownloadModelFiles (AI ai, string modelDir, bool log = true)
{
string aiDir = ai.PkgDir;
Logger.Log($"DownloadModelFiles(string ai = {ai.AiName}, string model = {modelDir}, bool log = {log})", true);
Logger.Log($"DownloadModelFiles(string ai = {ai.NameInternal}, string model = {modelDir}, bool log = {log})", true);
try
{

View File

@@ -19,7 +19,7 @@ namespace Flowframes.Main
if (!File.Exists(modelsFile))
{
Logger.Log($"Error: File models.json is missing for {ai.AiName}, can't load AI models for this implementation!");
Logger.Log($"Error: File models.json is missing for {ai.NameInternal}, can't load AI models for this implementation!");
return new ModelCollection(ai);
}

View File

@@ -82,7 +82,7 @@ namespace Flowframes.Main
string fname = Path.GetFileName(entry.inPath);
if (IoUtils.IsPathDirectory(entry.inPath)) fname = Path.GetDirectoryName(entry.inPath);
Logger.Log($"Queue: Processing {fname} ({entry.interpFactor}x {entry.ai.AiNameShort}).");
Logger.Log($"Queue: Processing {fname} ({entry.interpFactor}x {entry.ai.NameShort}).");
MediaFile mf = new MediaFile(entry.inPath);
await mf.Initialize();
@@ -98,7 +98,7 @@ namespace Flowframes.Main
Program.batchQueue.Dequeue();
Program.mainForm.SetWorking(false);
Logger.Log($"Queue: Done processing {fname} ({entry.interpFactor}x {entry.ai.AiNameShort}).");
Logger.Log($"Queue: Done processing {fname} ({entry.interpFactor}x {entry.ai.NameShort}).");
}
static void SetBusy(bool state)

View File

@@ -196,25 +196,25 @@ namespace Flowframes
List<Task> tasks = new List<Task>();
if (ai.AiName == Implementations.rifeCuda.AiName)
if (ai.NameInternal == Implementations.rifeCuda.NameInternal)
tasks.Add(AiProcess.RunRifeCuda(currentSettings.framesFolder, currentSettings.interpFactor, currentSettings.model.dir));
if (ai.AiName == Implementations.rifeNcnn.AiName)
if (ai.NameInternal == Implementations.rifeNcnn.NameInternal)
tasks.Add(AiProcess.RunRifeNcnn(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir));
if (ai.AiName == Implementations.rifeNcnnVs.AiName)
if (ai.NameInternal == Implementations.rifeNcnnVs.NameInternal)
tasks.Add(AiProcess.RunRifeNcnnVs(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir));
if (ai.AiName == Implementations.flavrCuda.AiName)
if (ai.NameInternal == Implementations.flavrCuda.NameInternal)
tasks.Add(AiProcess.RunFlavrCuda(currentSettings.framesFolder, currentSettings.interpFactor, currentSettings.model.dir));
if (ai.AiName == Implementations.dainNcnn.AiName)
if (ai.NameInternal == Implementations.dainNcnn.NameInternal)
tasks.Add(AiProcess.RunDainNcnn(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir, Config.GetInt(Config.Key.dainNcnnTilesize, 512)));
if (ai.AiName == Implementations.xvfiCuda.AiName)
if (ai.NameInternal == Implementations.xvfiCuda.NameInternal)
tasks.Add(AiProcess.RunXvfiCuda(currentSettings.framesFolder, currentSettings.interpFactor, currentSettings.model.dir));
if(ai.AiName == Implementations.ifrnetNcnn.AiName)
if(ai.NameInternal == Implementations.ifrnetNcnn.NameInternal)
tasks.Add(AiProcess.RunIfrnetNcnn(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir));
if (currentlyUsingAutoEnc)

View File

@@ -161,7 +161,7 @@ namespace Flowframes.Main
return false;
}
if (I.currentSettings.ai.AiName.ToUpper().Contains("CUDA") && NvApi.gpuList.Count < 1)
if (I.currentSettings.ai.NameInternal.ToUpper().Contains("CUDA") && NvApi.gpuList.Count < 1)
{
UiUtils.ShowMessageBox("Warning: No Nvidia GPU was detected. CUDA might fall back to CPU!\n\nTry an NCNN implementation instead if you don't have an Nvidia GPU.", UiUtils.MessageType.Error);

View File

@@ -25,7 +25,7 @@ namespace Flowframes
public static int GetPadding ()
{
return (Interpolate.currentSettings.ai.AiName == Implementations.flavrCuda.AiName) ? 8 : 2; // FLAVR input needs to be divisible by 8
return (Interpolate.currentSettings.ai.NameInternal == Implementations.flavrCuda.NameInternal) ? 8 : 2; // FLAVR input needs to be divisible by 8
}
public static string GetPadFilter ()

View File

@@ -54,7 +54,7 @@ namespace Flowframes.MiscUtils
return;
ModelCollection.ModelInfo modelInfo = modelCollection.models[i];
form.SetStatus($"Downloading files for {modelInfo.ai.AiName.Replace("_", "-")}...");
form.SetStatus($"Downloading files for {modelInfo.ai.NameInternal.Replace("_", "-")}...");
await ModelDownloader.DownloadModelFiles(ai, modelInfo.dir, false);
taskCounter++;
UpdateProgressBar();

View File

@@ -617,7 +617,7 @@ namespace Flowframes.Os
if (!hasShownError && line.ToLower().Contains("error(s) in loading state_dict"))
{
hasShownError = true;
string msg = (Interpolate.currentSettings.ai.AiName == Implementations.flavrCuda.AiName) ? "\n\nFor FLAVR, you need to select the correct model for each scale!" : "";
string msg = (Interpolate.currentSettings.ai.NameInternal == Implementations.flavrCuda.NameInternal) ? "\n\nFor FLAVR, you need to select the correct model for each scale!" : "";
UiUtils.ShowMessageBox($"Error loading the AI model!\n\n{line}{msg}", UiUtils.MessageType.Error);
}
@@ -645,7 +645,7 @@ namespace Flowframes.Os
if (!hasShownError && err && line.Contains("vkAllocateMemory failed"))
{
hasShownError = true;
bool usingDain = (Interpolate.currentSettings.ai.AiName == Implementations.dainNcnn.AiName);
bool usingDain = (Interpolate.currentSettings.ai.NameInternal == Implementations.dainNcnn.NameInternal);
string msg = usingDain ? "\n\nTry reducing the tile size in the AI settings." : "\n\nTry a lower resolution (Settings -> Max Video Size).";
UiUtils.ShowMessageBox($"Vulkan ran out of memory!\n\n{line}{msg}", UiUtils.MessageType.Error);
}

View File

@@ -79,7 +79,7 @@ namespace Flowframes.Ui
{
try
{
string ncnnStr = I.currentSettings.ai.AiName.Contains("NCNN") ? " done" : "";
string ncnnStr = I.currentSettings.ai.NameInternal.Contains("NCNN") ? " done" : "";
Regex frameRegex = new Regex($@"(?<=.)\d*(?={I.currentSettings.interpExt}{ncnnStr})");
if (!frameRegex.IsMatch(output)) return;
lastFrame = Math.Max(int.Parse(frameRegex.Match(output).Value), lastFrame);

View File

@@ -159,7 +159,7 @@ namespace Flowframes.Ui
{
AI ai = Program.mainForm.GetAi();
if (ai.AiName == Implementations.rifeNcnn.AiName && !Program.mainForm.GetModel(ai).dir.Contains("v4"))
if (ai.NameInternal == Implementations.rifeNcnn.NameInternal && !Program.mainForm.GetModel(ai).dir.Contains("v4"))
{
if (factor != 2)
Logger.Log($"{ai.FriendlyName} models before 4.0 only support 2x interpolation!");

View File

@@ -60,7 +60,7 @@ namespace Flowframes.Ui
}
catch (Exception e)
{
Logger.Log($"Failed to load available AI models for {ai.AiName}! {e.Message}");
Logger.Log($"Failed to load available AI models for {ai.NameInternal}! {e.Message}");
Logger.Log($"Stack Trace: {e.StackTrace}", true);
}