From 760a187dfece709eb5ff3d91678ca51329cc0580 Mon Sep 17 00:00:00 2001 From: n00mkrad Date: Thu, 21 Jul 2022 10:08:53 +0200 Subject: [PATCH] AI Class/Constructor revamp --- Code/Data/AI.cs | 47 +++++++++++++++++------- Code/Data/Implementations.cs | 23 ++++-------- Code/Data/InterpSettings.cs | 2 +- Code/Form1.cs | 8 ++-- Code/Forms/BatchForm.cs | 2 +- Code/IO/IoUtils.cs | 2 +- Code/IO/ModelDownloader.cs | 2 +- Code/Main/AiModels.cs | 2 +- Code/Main/BatchProcessing.cs | 4 +- Code/Main/Interpolate.cs | 14 +++---- Code/Main/InterpolateUtils.cs | 2 +- Code/Media/FfmpegCommands.cs | 2 +- Code/MiscUtils/ModelDownloadFormUtils.cs | 2 +- Code/Os/AiProcess.cs | 4 +- Code/Ui/InterpolationProgress.cs | 2 +- Code/Ui/MainUiFunctions.cs | 2 +- Code/Ui/UiUtils.cs | 2 +- 17 files changed, 68 insertions(+), 54 deletions(-) diff --git a/Code/Data/AI.cs b/Code/Data/AI.cs index cdfcf51..6155e33 100644 --- a/Code/Data/AI.cs +++ b/Code/Data/AI.cs @@ -11,11 +11,11 @@ namespace Flowframes.Data { public enum AiBackend { Pytorch, Ncnn, Tensorflow, Other } public AiBackend Backend { get; set; } = AiBackend.Pytorch; - public string AiName { get; set; } = ""; - public string AiNameShort { get; set; } = ""; - public string FriendlyName { get; set; } = ""; - public string Description { get; set; } = ""; - public string PkgDir { get; set; } = ""; + public string NameInternal { get; set; } = ""; + public string NameShort { get { return NameInternal.Split(' ')[0].Split('_')[0]; } } + public string FriendlyName { get { return $"{NameShort} ({GetFrameworkString()})"; } } + public string Description { get { return $"{GetImplemString()} of {NameShort}{(Backend == AiBackend.Pytorch ? " (Nvidia Only!)" : "")}"; } } + public string PkgDir { get { return NameInternal.Replace("_", "-").ToLower(); } } public enum InterpFactorSupport { Fixed, AnyPowerOfTwo, AnyInteger, AnyFloat } public InterpFactorSupport FactorSupport { get; set; } = InterpFactorSupport.Fixed; public int[] SupportedFactors { get; set; } = new int[0]; @@ -23,19 +23,40 @@ namespace Flowframes.Data public string LogFilename { get { return PkgDir + "-log"; } } - public AI(AiBackend backend, string aiName, string friendlyName, string desc, string pkgDir, InterpFactorSupport factorSupport = InterpFactorSupport.Fixed, int[] supportedFactors = null) + public AI(AiBackend backend, string aiName, InterpFactorSupport factorSupport = InterpFactorSupport.Fixed, int[] supportedFactors = null) { Backend = backend; - AiName = aiName; - AiNameShort = aiName.Split(' ')[0].Split('_')[0]; - FriendlyName = friendlyName; - Description = desc; - PkgDir = pkgDir; + NameInternal = aiName; SupportedFactors = supportedFactors; FactorSupport = factorSupport; + } - if (backend == AiBackend.Pytorch) - Description += " (Nvidia Only!)"; + private string GetImplemString () + { + if (Backend == AiBackend.Pytorch) + return $"CUDA/Pytorch Implementation"; + + if(Backend == AiBackend.Ncnn) + return $"Vulkan/NCNN{(Piped ? "/VapourSynth" : "")} Implementation"; + + if (Backend == AiBackend.Tensorflow) + return $"Tensorflow Implementation"; + + return ""; + } + + private string GetFrameworkString() + { + if (Backend == AiBackend.Pytorch) + return $"CUDA"; + + if (Backend == AiBackend.Ncnn) + return $"NCNN{(Piped ? "/VS" : "")}"; + + if (Backend == AiBackend.Tensorflow) + return $"TF"; + + return "Custom"; } } } diff --git a/Code/Data/Implementations.cs b/Code/Data/Implementations.cs index 3897db1..6dbc360 100644 --- a/Code/Data/Implementations.cs +++ b/Code/Data/Implementations.cs @@ -6,27 +6,20 @@ namespace Flowframes.Data { class Implementations { - public static AI rifeCuda = new AI(AI.AiBackend.Pytorch, "RIFE_CUDA", "RIFE", - "CUDA/Pytorch Implementation of RIFE", "rife-cuda", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + public static AI rifeCuda = new AI(AI.AiBackend.Pytorch, "RIFE_CUDA", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }); - public static AI rifeNcnnVs = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN_VS", "RIFE (NCNN/VS)", - "Vulkan/NCNN/VapourSynth Implementation of RIFE", "rife-ncnn-vs", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }) + public static AI rifeNcnnVs = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN_VS", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }) { Piped = true }; - public static AI rifeNcnn = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN", "RIFE (NCNN)", - "Vulkan/NCNN Implementation of RIFE", "rife-ncnn", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + public static AI rifeNcnn = new AI(AI.AiBackend.Ncnn, "RIFE_NCNN", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }); - public static AI flavrCuda = new AI(AI.AiBackend.Pytorch, "FLAVR_CUDA", "FLAVR", - "Experimental Pytorch Implementation of FLAVR", "flavr-cuda", AI.InterpFactorSupport.Fixed, new int[] { 2, 4, 8 }); + public static AI flavrCuda = new AI(AI.AiBackend.Pytorch, "FLAVR_CUDA", AI.InterpFactorSupport.Fixed, new int[] { 2, 4, 8 }); - public static AI dainNcnn = new AI(AI.AiBackend.Ncnn, "DAIN_NCNN", "DAIN (NCNN)", - "Vulkan/NCNN Implementation of DAIN", "dain-ncnn", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8 }); + public static AI dainNcnn = new AI(AI.AiBackend.Ncnn, "DAIN_NCNN", AI.InterpFactorSupport.AnyFloat, new int[] { 2, 3, 4, 5, 6, 7, 8 }); - public static AI xvfiCuda = new AI(AI.AiBackend.Pytorch, "XVFI_CUDA", "XVFI", - "CUDA/Pytorch Implementation of XVFI", "xvfi-cuda", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }); + public static AI xvfiCuda = new AI(AI.AiBackend.Pytorch, "XVFI_CUDA", AI.InterpFactorSupport.AnyInteger, new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }); - public static AI ifrnetNcnn = new AI(AI.AiBackend.Ncnn, "IFRNet_NCNN", "IFRNet (NCNN)", - "Vulkan/NCNN Implementation of IFRNet", "ifrnet-ncnn", AI.InterpFactorSupport.Fixed, new int[] { 2 }); + public static AI ifrnetNcnn = new AI(AI.AiBackend.Ncnn, "IFRNet_NCNN", AI.InterpFactorSupport.Fixed, new int[] { 2 }); public static List NetworksAll { @@ -53,7 +46,7 @@ namespace Flowframes.Data { foreach (AI ai in NetworksAll) { - if (ai.AiName == aiName) + if (ai.NameInternal == aiName) return ai; } diff --git a/Code/Data/InterpSettings.cs b/Code/Data/InterpSettings.cs index aef5e54..01c143c 100644 --- a/Code/Data/InterpSettings.cs +++ b/Code/Data/InterpSettings.cs @@ -234,7 +234,7 @@ namespace Flowframes { string s = $"INPATH|{inPath}\n"; s += $"OUTPATH|{outPath}\n"; - s += $"AI|{ai.AiName}\n"; + s += $"AI|{ai.NameInternal}\n"; s += $"INFPSDETECTED|{inFpsDetected}\n"; s += $"INFPS|{inFps}\n"; s += $"OUTFPS|{outFps}\n"; diff --git a/Code/Form1.cs b/Code/Form1.cs index 98ef616..628956d 100644 --- a/Code/Form1.cs +++ b/Code/Form1.cs @@ -192,7 +192,7 @@ namespace Flowframes inputTbox.Text = entry.inPath; MainUiFunctions.SetOutPath(outputTbox, entry.outPath); interpFactorCombox.Text = entry.interpFactor.ToString(); - aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.AiName == entry.ai.AiName).FirstOrDefault()); + aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.NameInternal == entry.ai.NameInternal).FirstOrDefault()); SetOutMode(entry.outMode); } @@ -276,9 +276,9 @@ namespace Flowframes aiCombox.Items.Add(GetAiComboboxName(ai)); string lastUsedAiName = Config.Get(Config.Key.lastUsedAiName); - aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.AiName == lastUsedAiName).FirstOrDefault()); + aiCombox.SelectedIndex = Implementations.NetworksAvailable.IndexOf(Implementations.NetworksAvailable.Where(x => x.NameInternal == lastUsedAiName).FirstOrDefault()); if (aiCombox.SelectedIndex < 0) aiCombox.SelectedIndex = 0; - Config.Set(Config.Key.lastUsedAiName, GetAi().AiName); + Config.Set(Config.Key.lastUsedAiName, GetAi().NameInternal); ConfigParser.LoadComboxIndex(outModeCombox); } @@ -475,7 +475,7 @@ namespace Flowframes interpFactorCombox.SelectedIndex = 0; if (initialized) - Config.Set(Config.Key.lastUsedAiName, GetAi().AiName); + Config.Set(Config.Key.lastUsedAiName, GetAi().NameInternal); interpFactorCombox_SelectedIndexChanged(null, null); fpsOutTbox.ReadOnly = GetAi().FactorSupport != AI.InterpFactorSupport.AnyFloat; diff --git a/Code/Forms/BatchForm.cs b/Code/Forms/BatchForm.cs index 6669562..7dbc6a1 100644 --- a/Code/Forms/BatchForm.cs +++ b/Code/Forms/BatchForm.cs @@ -34,7 +34,7 @@ namespace Flowframes.Forms InterpSettings entry = Program.batchQueue.ElementAt(i); string niceOutMode = entry.outMode.ToString().ToUpper().Remove("VID").Remove("IMG"); string str = $"#{i+1}: {Path.GetFileName(entry.inPath).Trunc(40)} - {entry.inFps.GetFloat()} FPS => " + - $"{entry.interpFactor}x {entry.ai.AiNameShort} ({entry.model.name}) => {niceOutMode}"; + $"{entry.interpFactor}x {entry.ai.NameShort} ({entry.model.name}) => {niceOutMode}"; taskList.Items.Add(str); } } diff --git a/Code/IO/IoUtils.cs b/Code/IO/IoUtils.cs index 8a74e73..88be981 100644 --- a/Code/IO/IoUtils.cs +++ b/Code/IO/IoUtils.cs @@ -578,7 +578,7 @@ namespace Flowframes.IO filename = filename.Replace("[NAME]", inName); filename = filename.Replace("[FULLNAME]", Path.GetFileName(curr.inPath)); filename = filename.Replace("[FACTOR]", curr.interpFactor.ToStringDot()); - filename = filename.Replace("[AI]", curr.ai.AiNameShort.ToUpper()); + filename = filename.Replace("[AI]", curr.ai.NameShort.ToUpper()); filename = filename.Replace("[MODEL]", curr.model.name.Remove(" ")); filename = filename.Replace("[FPS]", fps.ToStringDot()); filename = filename.Replace("[ROUNDFPS]", fps.RoundToInt().ToString()); diff --git a/Code/IO/ModelDownloader.cs b/Code/IO/ModelDownloader.cs index eb2e6a4..27f00c2 100644 --- a/Code/IO/ModelDownloader.cs +++ b/Code/IO/ModelDownloader.cs @@ -135,7 +135,7 @@ namespace Flowframes.IO public static async Task DownloadModelFiles (AI ai, string modelDir, bool log = true) { string aiDir = ai.PkgDir; - Logger.Log($"DownloadModelFiles(string ai = {ai.AiName}, string model = {modelDir}, bool log = {log})", true); + Logger.Log($"DownloadModelFiles(string ai = {ai.NameInternal}, string model = {modelDir}, bool log = {log})", true); try { diff --git a/Code/Main/AiModels.cs b/Code/Main/AiModels.cs index 84025e9..e562bfa 100644 --- a/Code/Main/AiModels.cs +++ b/Code/Main/AiModels.cs @@ -19,7 +19,7 @@ namespace Flowframes.Main if (!File.Exists(modelsFile)) { - Logger.Log($"Error: File models.json is missing for {ai.AiName}, can't load AI models for this implementation!"); + Logger.Log($"Error: File models.json is missing for {ai.NameInternal}, can't load AI models for this implementation!"); return new ModelCollection(ai); } diff --git a/Code/Main/BatchProcessing.cs b/Code/Main/BatchProcessing.cs index 7ec5c44..3ce9242 100644 --- a/Code/Main/BatchProcessing.cs +++ b/Code/Main/BatchProcessing.cs @@ -82,7 +82,7 @@ namespace Flowframes.Main string fname = Path.GetFileName(entry.inPath); if (IoUtils.IsPathDirectory(entry.inPath)) fname = Path.GetDirectoryName(entry.inPath); - Logger.Log($"Queue: Processing {fname} ({entry.interpFactor}x {entry.ai.AiNameShort})."); + Logger.Log($"Queue: Processing {fname} ({entry.interpFactor}x {entry.ai.NameShort})."); MediaFile mf = new MediaFile(entry.inPath); await mf.Initialize(); @@ -98,7 +98,7 @@ namespace Flowframes.Main Program.batchQueue.Dequeue(); Program.mainForm.SetWorking(false); - Logger.Log($"Queue: Done processing {fname} ({entry.interpFactor}x {entry.ai.AiNameShort})."); + Logger.Log($"Queue: Done processing {fname} ({entry.interpFactor}x {entry.ai.NameShort})."); } static void SetBusy(bool state) diff --git a/Code/Main/Interpolate.cs b/Code/Main/Interpolate.cs index 2b012ae..07cceb0 100644 --- a/Code/Main/Interpolate.cs +++ b/Code/Main/Interpolate.cs @@ -196,25 +196,25 @@ namespace Flowframes List tasks = new List(); - if (ai.AiName == Implementations.rifeCuda.AiName) + if (ai.NameInternal == Implementations.rifeCuda.NameInternal) tasks.Add(AiProcess.RunRifeCuda(currentSettings.framesFolder, currentSettings.interpFactor, currentSettings.model.dir)); - if (ai.AiName == Implementations.rifeNcnn.AiName) + if (ai.NameInternal == Implementations.rifeNcnn.NameInternal) tasks.Add(AiProcess.RunRifeNcnn(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir)); - if (ai.AiName == Implementations.rifeNcnnVs.AiName) + if (ai.NameInternal == Implementations.rifeNcnnVs.NameInternal) tasks.Add(AiProcess.RunRifeNcnnVs(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir)); - if (ai.AiName == Implementations.flavrCuda.AiName) + if (ai.NameInternal == Implementations.flavrCuda.NameInternal) tasks.Add(AiProcess.RunFlavrCuda(currentSettings.framesFolder, currentSettings.interpFactor, currentSettings.model.dir)); - if (ai.AiName == Implementations.dainNcnn.AiName) + if (ai.NameInternal == Implementations.dainNcnn.NameInternal) tasks.Add(AiProcess.RunDainNcnn(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir, Config.GetInt(Config.Key.dainNcnnTilesize, 512))); - if (ai.AiName == Implementations.xvfiCuda.AiName) + if (ai.NameInternal == Implementations.xvfiCuda.NameInternal) tasks.Add(AiProcess.RunXvfiCuda(currentSettings.framesFolder, currentSettings.interpFactor, currentSettings.model.dir)); - if(ai.AiName == Implementations.ifrnetNcnn.AiName) + if(ai.NameInternal == Implementations.ifrnetNcnn.NameInternal) tasks.Add(AiProcess.RunIfrnetNcnn(currentSettings.framesFolder, outpath, currentSettings.interpFactor, currentSettings.model.dir)); if (currentlyUsingAutoEnc) diff --git a/Code/Main/InterpolateUtils.cs b/Code/Main/InterpolateUtils.cs index 928d181..7d773d6 100644 --- a/Code/Main/InterpolateUtils.cs +++ b/Code/Main/InterpolateUtils.cs @@ -161,7 +161,7 @@ namespace Flowframes.Main return false; } - if (I.currentSettings.ai.AiName.ToUpper().Contains("CUDA") && NvApi.gpuList.Count < 1) + if (I.currentSettings.ai.NameInternal.ToUpper().Contains("CUDA") && NvApi.gpuList.Count < 1) { UiUtils.ShowMessageBox("Warning: No Nvidia GPU was detected. CUDA might fall back to CPU!\n\nTry an NCNN implementation instead if you don't have an Nvidia GPU.", UiUtils.MessageType.Error); diff --git a/Code/Media/FfmpegCommands.cs b/Code/Media/FfmpegCommands.cs index 2241d03..a7fc465 100644 --- a/Code/Media/FfmpegCommands.cs +++ b/Code/Media/FfmpegCommands.cs @@ -25,7 +25,7 @@ namespace Flowframes public static int GetPadding () { - return (Interpolate.currentSettings.ai.AiName == Implementations.flavrCuda.AiName) ? 8 : 2; // FLAVR input needs to be divisible by 8 + return (Interpolate.currentSettings.ai.NameInternal == Implementations.flavrCuda.NameInternal) ? 8 : 2; // FLAVR input needs to be divisible by 8 } public static string GetPadFilter () diff --git a/Code/MiscUtils/ModelDownloadFormUtils.cs b/Code/MiscUtils/ModelDownloadFormUtils.cs index 052272b..7869c43 100644 --- a/Code/MiscUtils/ModelDownloadFormUtils.cs +++ b/Code/MiscUtils/ModelDownloadFormUtils.cs @@ -54,7 +54,7 @@ namespace Flowframes.MiscUtils return; ModelCollection.ModelInfo modelInfo = modelCollection.models[i]; - form.SetStatus($"Downloading files for {modelInfo.ai.AiName.Replace("_", "-")}..."); + form.SetStatus($"Downloading files for {modelInfo.ai.NameInternal.Replace("_", "-")}..."); await ModelDownloader.DownloadModelFiles(ai, modelInfo.dir, false); taskCounter++; UpdateProgressBar(); diff --git a/Code/Os/AiProcess.cs b/Code/Os/AiProcess.cs index 10490f1..583506a 100644 --- a/Code/Os/AiProcess.cs +++ b/Code/Os/AiProcess.cs @@ -617,7 +617,7 @@ namespace Flowframes.Os if (!hasShownError && line.ToLower().Contains("error(s) in loading state_dict")) { hasShownError = true; - string msg = (Interpolate.currentSettings.ai.AiName == Implementations.flavrCuda.AiName) ? "\n\nFor FLAVR, you need to select the correct model for each scale!" : ""; + string msg = (Interpolate.currentSettings.ai.NameInternal == Implementations.flavrCuda.NameInternal) ? "\n\nFor FLAVR, you need to select the correct model for each scale!" : ""; UiUtils.ShowMessageBox($"Error loading the AI model!\n\n{line}{msg}", UiUtils.MessageType.Error); } @@ -645,7 +645,7 @@ namespace Flowframes.Os if (!hasShownError && err && line.Contains("vkAllocateMemory failed")) { hasShownError = true; - bool usingDain = (Interpolate.currentSettings.ai.AiName == Implementations.dainNcnn.AiName); + bool usingDain = (Interpolate.currentSettings.ai.NameInternal == Implementations.dainNcnn.NameInternal); string msg = usingDain ? "\n\nTry reducing the tile size in the AI settings." : "\n\nTry a lower resolution (Settings -> Max Video Size)."; UiUtils.ShowMessageBox($"Vulkan ran out of memory!\n\n{line}{msg}", UiUtils.MessageType.Error); } diff --git a/Code/Ui/InterpolationProgress.cs b/Code/Ui/InterpolationProgress.cs index 35af99a..dd5f370 100644 --- a/Code/Ui/InterpolationProgress.cs +++ b/Code/Ui/InterpolationProgress.cs @@ -79,7 +79,7 @@ namespace Flowframes.Ui { try { - string ncnnStr = I.currentSettings.ai.AiName.Contains("NCNN") ? " done" : ""; + string ncnnStr = I.currentSettings.ai.NameInternal.Contains("NCNN") ? " done" : ""; Regex frameRegex = new Regex($@"(?<=.)\d*(?={I.currentSettings.interpExt}{ncnnStr})"); if (!frameRegex.IsMatch(output)) return; lastFrame = Math.Max(int.Parse(frameRegex.Match(output).Value), lastFrame); diff --git a/Code/Ui/MainUiFunctions.cs b/Code/Ui/MainUiFunctions.cs index c2c28f7..3b0e2f7 100644 --- a/Code/Ui/MainUiFunctions.cs +++ b/Code/Ui/MainUiFunctions.cs @@ -159,7 +159,7 @@ namespace Flowframes.Ui { AI ai = Program.mainForm.GetAi(); - if (ai.AiName == Implementations.rifeNcnn.AiName && !Program.mainForm.GetModel(ai).dir.Contains("v4")) + if (ai.NameInternal == Implementations.rifeNcnn.NameInternal && !Program.mainForm.GetModel(ai).dir.Contains("v4")) { if (factor != 2) Logger.Log($"{ai.FriendlyName} models before 4.0 only support 2x interpolation!"); diff --git a/Code/Ui/UiUtils.cs b/Code/Ui/UiUtils.cs index 9d3c37c..3ac5a21 100644 --- a/Code/Ui/UiUtils.cs +++ b/Code/Ui/UiUtils.cs @@ -60,7 +60,7 @@ namespace Flowframes.Ui } catch (Exception e) { - Logger.Log($"Failed to load available AI models for {ai.AiName}! {e.Message}"); + Logger.Log($"Failed to load available AI models for {ai.NameInternal}! {e.Message}"); Logger.Log($"Stack Trace: {e.StackTrace}", true); }