Refactor Implementations.cs to use enums, backwards compat to GetAi(string name)

This commit is contained in:
N00MKRAD
2024-08-21 14:45:46 +02:00
parent 0b500563ca
commit 713c64e6ec
18 changed files with 120 additions and 92 deletions

View File

@@ -9,7 +9,7 @@ using System.Threading.Tasks;
namespace Flowframes.Data
{
public class AI
public class AiInfo
{
public enum AiBackend { Pytorch, Ncnn, Tensorflow, Other }
public AiBackend Backend { get; set; } = AiBackend.Pytorch;
@@ -26,9 +26,9 @@ namespace Flowframes.Data
public string LogFilename { get { return PkgDir + "-log"; } }
public AI () { }
public AiInfo () { }
public AI(AiBackend backend, string aiName, string longName, InterpFactorSupport factorSupport = InterpFactorSupport.Fixed, int[] supportedFactors = null)
public AiInfo(AiBackend backend, string aiName, string longName, InterpFactorSupport factorSupport = InterpFactorSupport.Fixed, int[] supportedFactors = null)
{
Backend = backend;
NameInternal = aiName;

View File

@@ -6,79 +6,96 @@ namespace Flowframes.Data
{
class Implementations
{
public static AI rifeCuda = new AI()
public enum Ai
{
Backend = AI.AiBackend.Pytorch,
RifeCuda,
RifeNcnn,
RifeNcnnVs,
FlavrCuda,
DainNcnn,
XvfiCuda,
IfrnetNcnn,
}
public static AiInfo rifeCuda = new AiInfo()
{
Backend = AiInfo.AiBackend.Pytorch,
NameInternal = "RIFE_CUDA",
NameLong = "Real-Time Intermediate Flow Estimation",
FactorSupport = AI.InterpFactorSupport.AnyInteger,
SupportedFactors = new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 }
FactorSupport = AiInfo.InterpFactorSupport.AnyInteger,
SupportedFactors = Enumerable.Range(2, 15).ToArray(), // Generates numbers from 2 to 16
};
public static AI rifeNcnn = new AI()
public static AiInfo rifeNcnn = new AiInfo()
{
Backend = AI.AiBackend.Ncnn,
Backend = AiInfo.AiBackend.Ncnn,
NameInternal = "RIFE_NCNN",
NameLong = "Real-Time Intermediate Flow Estimation",
FactorSupport = AI.InterpFactorSupport.AnyFloat,
SupportedFactors = new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 },
FactorSupport = AiInfo.InterpFactorSupport.AnyFloat,
SupportedFactors = Enumerable.Range(2, 15).ToArray(), // Generates numbers from 2 to 16
};
public static AI rifeNcnnVs = new AI()
public static AiInfo rifeNcnnVs = new AiInfo()
{
Backend = AI.AiBackend.Ncnn,
Backend = AiInfo.AiBackend.Ncnn,
NameInternal = "RIFE_NCNN_VS",
NameLong = "Real-Time Intermediate Flow Estimation",
FactorSupport = AI.InterpFactorSupport.AnyFloat,
SupportedFactors = new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 },
FactorSupport = AiInfo.InterpFactorSupport.AnyFloat,
SupportedFactors = Enumerable.Range(2, 15).ToArray(), // Generates numbers from 2 to 16
Piped = true
};
public static AI flavrCuda = new AI()
public static AiInfo flavrCuda = new AiInfo()
{
Backend = AI.AiBackend.Pytorch,
Backend = AiInfo.AiBackend.Pytorch,
NameInternal = "FLAVR_CUDA",
NameLong = "Flow-Agnostic Video Representations",
FactorSupport = AI.InterpFactorSupport.Fixed,
FactorSupport = AiInfo.InterpFactorSupport.Fixed,
SupportedFactors = new int[] { 2, 4, 8 },
};
public static AI dainNcnn = new AI()
public static AiInfo dainNcnn = new AiInfo()
{
Backend = AI.AiBackend.Ncnn,
Backend = AiInfo.AiBackend.Ncnn,
NameInternal = "DAIN_NCNN",
NameLong = "Depth-Aware Video Frame Interpolation",
FactorSupport = AI.InterpFactorSupport.AnyFloat,
SupportedFactors = new int[] { 2, 3, 4, 5, 6, 7, 8 },
FactorSupport = AiInfo.InterpFactorSupport.AnyFloat,
SupportedFactors = Enumerable.Range(2, 7).ToArray(), // Generates numbers from 2 to 8
};
public static AI xvfiCuda = new AI()
public static AiInfo xvfiCuda = new AiInfo()
{
Backend = AI.AiBackend.Pytorch,
Backend = AiInfo.AiBackend.Pytorch,
NameInternal = "XVFI_CUDA",
NameLong = "eXtreme Video Frame Interpolation",
FactorSupport = AI.InterpFactorSupport.AnyInteger,
SupportedFactors = new int[] { 2, 3, 4, 5, 6, 7, 8, 9, 10 },
FactorSupport = AiInfo.InterpFactorSupport.AnyInteger,
SupportedFactors = Enumerable.Range(2, 9).ToArray(), // Generates numbers from 2 to 10
};
public static AI ifrnetNcnn = new AI()
public static AiInfo ifrnetNcnn = new AiInfo()
{
Backend = AI.AiBackend.Ncnn,
Backend = AiInfo.AiBackend.Ncnn,
NameInternal = "IFRNet_NCNN",
NameLong = "Intermediate Feature Refine Network",
FactorSupport = AI.InterpFactorSupport.Fixed,
FactorSupport = AiInfo.InterpFactorSupport.Fixed,
SupportedFactors = new int[] { 2 },
};
public static List<AI> NetworksAll
// Lookup table
private static readonly Dictionary<Ai, AiInfo> AiLookup = new Dictionary<Ai, AiInfo>
{
get
{
return new List<AI> { rifeNcnnVs, rifeNcnn, rifeCuda, flavrCuda, dainNcnn, xvfiCuda, /* ifrnetNcnn */ };
}
}
{ Ai.RifeCuda, rifeCuda },
{ Ai.RifeNcnn, rifeNcnn },
{ Ai.RifeNcnnVs, rifeNcnnVs },
{ Ai.FlavrCuda, flavrCuda },
{ Ai.DainNcnn, dainNcnn },
{ Ai.XvfiCuda, xvfiCuda },
{ Ai.IfrnetNcnn, ifrnetNcnn }
};
public static List<AI> NetworksAvailable
public static List<AiInfo> NetworksAll => AiLookup.Values.ToList();
public static List<AiInfo> NetworksAvailable
{
get
{
@@ -86,21 +103,32 @@ namespace Flowframes.Data
if (pytorchAvailable)
return NetworksAll;
return NetworksAll.Where(x => x.Backend != AI.AiBackend.Pytorch).ToList();
return NetworksAll.Where(x => x.Backend != AiInfo.AiBackend.Pytorch).ToList();
}
}
public static AI GetAi(string aiName)
// Legacy: Get by name
public static AiInfo GetAi(string aiName)
{
foreach (AI ai in NetworksAll)
foreach (var ai in NetworksAll)
{
if (ai.NameInternal == aiName)
return ai;
}
Logger.Log($"AI implementation lookup failed! This should not happen! Please tell the developer! (Implementations.cs)");
Logger.Log($"AI implementation lookup failed for '{aiName}'! This should not happen! Please tell the developer!");
return NetworksAll[0];
}
// New: Use enums
public static AiInfo GetAi(Ai ai)
{
if (AiLookup.TryGetValue(ai, out AiInfo aiObj))
return aiObj;
Logger.Log($"AI implementation lookup failed for '{ai}'! This should not happen! Please tell the developer!");
return NetworksAll[0];
}
}
}
}

View File

@@ -18,7 +18,7 @@ namespace Flowframes
public string inPath;
public string outPath;
public string FullOutPath { get; set; } = "";
public AI ai;
public AiInfo ai;
public string inPixFmt = "yuv420p";
public Fraction inFps;
public Fraction inFpsDetected;
@@ -46,7 +46,7 @@ namespace Flowframes
public InterpSettings() { }
public InterpSettings(string inPathArg, string outPathArg, AI aiArg, Fraction inFpsDetectedArg, Fraction inFpsArg, float interpFactorArg, float itsScale, OutputSettings outSettingsArg, ModelCollection.ModelInfo modelArg)
public InterpSettings(string inPathArg, string outPathArg, AiInfo aiArg, Fraction inFpsDetectedArg, Fraction inFpsArg, float interpFactorArg, float itsScale, OutputSettings outSettingsArg, ModelCollection.ModelInfo modelArg)
{
inPath = inPathArg;
outPath = outPathArg;
@@ -197,7 +197,7 @@ namespace Flowframes
public enum FrameType { Import, Interp, Both };
public void RefreshExtensions(FrameType type = FrameType.Both, AI ai = null)
public void RefreshExtensions(FrameType type = FrameType.Both, AiInfo ai = null)
{
if(ai == null)
{

View File

@@ -9,12 +9,12 @@ namespace Flowframes.Data
{
public class ModelCollection
{
public AI Ai { get; set; } = null;
public AiInfo Ai { get; set; } = null;
public List<ModelInfo> Models { get; set; } = new List<ModelInfo>();
public class ModelInfo
{
public AI Ai { get; set; } = null;
public AiInfo Ai { get; set; } = null;
public string Name { get; set; } = "";
public string Desc { get; set; } = "";
public string Dir { get; set; } = "";
@@ -36,12 +36,12 @@ namespace Flowframes.Data
}
}
public ModelCollection(AI ai)
public ModelCollection(AiInfo ai)
{
Ai = ai;
}
public ModelCollection(AI ai, string jsonContentOrPath)
public ModelCollection(AiInfo ai, string jsonContentOrPath)
{
Ai = ai;

View File

@@ -341,7 +341,7 @@
<Reference Include="System.Management.Automation" />
</ItemGroup>
<ItemGroup>
<Compile Include="Data\AI.cs" />
<Compile Include="Data\AiInfo.cs" />
<Compile Include="Data\AudioTrack.cs" />
<Compile Include="Data\EncoderInfo.cs" />
<Compile Include="Data\EncoderInfoVideo.cs" />

View File

@@ -35,7 +35,7 @@ namespace Flowframes.Forms.Main
}
}
public ModelCollection.ModelInfo GetModel(AI currentAi)
public ModelCollection.ModelInfo GetModel(AiInfo currentAi)
{
try
{
@@ -47,11 +47,11 @@ namespace Flowframes.Forms.Main
}
}
public AI GetAi()
public AiInfo GetAi()
{
try
{
foreach (AI ai in Implementations.NetworksAll)
foreach (AiInfo ai in Implementations.NetworksAll)
{
if (GetAiComboboxName(ai) == aiCombox.Text)
return ai;

View File

@@ -229,7 +229,7 @@ namespace Flowframes.Forms.Main
public InterpSettings GetCurrentSettings()
{
SetTab(interpOptsTab.Name);
AI ai = GetAi();
AiInfo ai = GetAi();
var s = new InterpSettings()
{
@@ -357,7 +357,7 @@ namespace Flowframes.Forms.Main
{
bool pytorchAvailable = Python.IsPytorchReady();
foreach (AI ai in Implementations.NetworksAvailable)
foreach (AiInfo ai in Implementations.NetworksAvailable)
aiCombox.Items.Add(GetAiComboboxName(ai));
string lastUsedAiName = Config.Get(Config.Key.lastUsedAiName);
@@ -366,7 +366,7 @@ namespace Flowframes.Forms.Main
Config.Set(Config.Key.lastUsedAiName, GetAi().NameInternal);
}
private string GetAiComboboxName(AI ai)
private string GetAiComboboxName(AiInfo ai)
{
return ai.FriendlyName + " - " + ai.Description;
}
@@ -556,7 +556,7 @@ namespace Flowframes.Forms.Main
Config.Set(Config.Key.lastUsedAiName, GetAi().NameInternal);
interpFactorCombox_SelectedIndexChanged(null, null);
fpsOutTbox.ReadOnly = GetAi().FactorSupport != AI.InterpFactorSupport.AnyFloat;
fpsOutTbox.ReadOnly = GetAi().FactorSupport != AiInfo.InterpFactorSupport.AnyFloat;
}
public void UpdateAiModelCombox()

View File

@@ -132,7 +132,7 @@ namespace Flowframes.IO
return modelFiles;
}
public static async Task DownloadModelFiles (AI ai, string modelDir, bool log = true)
public static async Task DownloadModelFiles (AiInfo ai, string modelDir, bool log = true)
{
string aiDir = ai.PkgDir;
@@ -198,7 +198,7 @@ namespace Flowframes.IO
{
List<string> modelPaths = new List<string>();
foreach (AI ai in Implementations.NetworksAll)
foreach (AiInfo ai in Implementations.NetworksAll)
{
string aiPkgFolder = Path.Combine(Paths.GetPkgPath(), ai.PkgDir);
ModelCollection aiModels = AiModels.GetModels(ai);

View File

@@ -12,7 +12,7 @@ namespace Flowframes.Main
{
class AiModels
{
public static ModelCollection GetModels (AI ai)
public static ModelCollection GetModels (AiInfo ai)
{
string pkgPath = Path.Combine(Paths.GetPkgPath(), ai.PkgDir);
string modelsFile = Path.Combine(pkgPath, "models.json");
@@ -35,7 +35,7 @@ namespace Flowframes.Main
return modelCollection;
}
public static List<string> GetCustomModels(AI ai)
public static List<string> GetCustomModels(AiInfo ai)
{
string pkgPath = Path.Combine(Paths.GetPkgPath(), ai.PkgDir);
List<string> custModels = new List<string>();
@@ -49,7 +49,7 @@ namespace Flowframes.Main
return custModels;
}
public static ModelCollection.ModelInfo GetModelByName(AI ai, string modelName)
public static ModelCollection.ModelInfo GetModelByName(AiInfo ai, string modelName)
{
ModelCollection modelCollection = GetModels(ai);
@@ -62,7 +62,7 @@ namespace Flowframes.Main
return null;
}
public static ModelCollection.ModelInfo GetModelByDir(AI ai, string dirName)
public static ModelCollection.ModelInfo GetModelByDir(AiInfo ai, string dirName)
{
ModelCollection modelCollection = GetModels(ai);

View File

@@ -36,10 +36,10 @@ namespace Flowframes.Main
safetyBufferFrames = 90;
if (Interpolate.currentSettings.ai.Backend == AI.AiBackend.Ncnn)
if (Interpolate.currentSettings.ai.Backend == AiInfo.AiBackend.Ncnn)
safetyBufferFrames = Config.GetInt(Config.Key.autoEncSafeBufferNcnn, 150);
if (Interpolate.currentSettings.ai.Backend == AI.AiBackend.Pytorch)
if (Interpolate.currentSettings.ai.Backend == AiInfo.AiBackend.Pytorch)
safetyBufferFrames = Config.GetInt(Config.Key.autoEncSafeBufferCuda, 90);
}

View File

@@ -194,7 +194,7 @@ namespace Flowframes
}
}
public static async Task RunAi(string outpath, AI ai, bool stepByStep = false)
public static async Task RunAi(string outpath, AiInfo ai, bool stepByStep = false)
{
if (canceled) return;

View File

@@ -154,7 +154,7 @@ namespace Flowframes.Main
}
}
public static bool CheckAiAvailable(AI ai, ModelCollection.ModelInfo model)
public static bool CheckAiAvailable(AiInfo ai, ModelCollection.ModelInfo model)
{
if (IoUtils.GetAmountOfFiles(Path.Combine(Paths.GetPkgPath(), ai.PkgDir), true) < 1)
{
@@ -195,7 +195,7 @@ namespace Flowframes.Main
return true;
}
public static void ShowWarnings(float factor, AI ai)
public static void ShowWarnings(float factor, AiInfo ai)
{
if (Config.GetInt(Config.Key.cmdDebugMode) > 0)
Logger.Log($"Warning: The CMD window for interpolation is enabled. This will disable Auto-Encode and the progress bar!");

View File

@@ -19,7 +19,7 @@ namespace Flowframes.MiscUtils
{
form.SetDownloadBtnEnabled(true);
canceled = false;
List<AI> ais = new List<AI>();
List<AiInfo> ais = new List<AiInfo>();
if (rifeC) ais.Add(Implementations.rifeCuda);
if (rifeN) ais.Add(Implementations.rifeNcnn);
@@ -36,7 +36,7 @@ namespace Flowframes.MiscUtils
await Task.Delay(10);
UpdateProgressBar();
foreach (AI ai in ais)
foreach (AiInfo ai in ais)
await DownloadForAi(ai);
form.SetWorking(false);
@@ -44,7 +44,7 @@ namespace Flowframes.MiscUtils
form.SetDownloadBtnEnabled(false);
}
public static async Task DownloadForAi(AI ai)
public static async Task DownloadForAi(AiInfo ai)
{
ModelCollection modelCollection = AiModels.GetModels(ai);
@@ -72,11 +72,11 @@ namespace Flowframes.MiscUtils
ModelDownloader.canceled = true;
}
public static int GetTaskCount (List<AI> ais)
public static int GetTaskCount (List<AiInfo> ais)
{
int count = 0;
foreach(AI ai in ais)
foreach(AiInfo ai in ais)
{
ModelCollection modelCollection = AiModels.GetModels(ai);
count += modelCollection.Models.Count;

View File

@@ -157,7 +157,7 @@ namespace Flowframes.Os
public static async Task RunRifeCuda(string framesPath, float interpFactor, string mdl)
{
AI ai = Implementations.rifeCuda;
AiInfo ai = Implementations.rifeCuda;
if (Interpolate.currentlyUsingAutoEnc) // Ensure AutoEnc is not paused
AutoEncode.paused = false;
@@ -232,7 +232,7 @@ namespace Flowframes.Os
public static async Task RunFlavrCuda(string framesPath, float interpFactor, string mdl)
{
AI ai = Implementations.flavrCuda;
AiInfo ai = Implementations.flavrCuda;
if (Interpolate.currentlyUsingAutoEnc) // Ensure AutoEnc is not paused
AutoEncode.paused = false;
@@ -292,7 +292,7 @@ namespace Flowframes.Os
public static async Task RunRifeNcnn(string framesPath, string outPath, float factor, string mdl)
{
AI ai = Implementations.rifeNcnn;
AiInfo ai = Implementations.rifeNcnn;
processTimeMulti.Restart();
try
@@ -350,7 +350,7 @@ namespace Flowframes.Os
{
if (Interpolate.canceled) return;
AI ai = Implementations.rifeNcnnVs;
AiInfo ai = Implementations.rifeNcnnVs;
processTimeMulti.Restart();
try
@@ -429,7 +429,7 @@ namespace Flowframes.Os
public static async Task RunDainNcnn(string framesPath, string outPath, float factor, string mdl, int tilesize)
{
AI ai = Implementations.dainNcnn;
AiInfo ai = Implementations.dainNcnn;
if (Interpolate.currentlyUsingAutoEnc) // Ensure AutoEnc is not paused
AutoEncode.paused = false;
@@ -484,7 +484,7 @@ namespace Flowframes.Os
public static async Task RunXvfiCuda(string framesPath, float interpFactor, string mdl)
{
AI ai = Implementations.xvfiCuda;
AiInfo ai = Implementations.xvfiCuda;
if (Interpolate.currentlyUsingAutoEnc) // Ensure AutoEnc is not paused
AutoEncode.paused = false;
@@ -548,7 +548,7 @@ namespace Flowframes.Os
public static async Task RunIfrnetNcnn(string framesPath, string outPath, float factor, string mdl)
{
AI ai = Implementations.ifrnetNcnn;
AiInfo ai = Implementations.ifrnetNcnn;
processTimeMulti.Restart();
@@ -601,7 +601,7 @@ namespace Flowframes.Os
while (!ifrnetNcnn.HasExited) await Task.Delay(1);
}
static void LogOutput(string line, AI ai, bool err = false)
static void LogOutput(string line, AiInfo ai, bool err = false)
{
if (string.IsNullOrWhiteSpace(line) || line.Length < 6)
return;
@@ -614,7 +614,7 @@ namespace Flowframes.Os
string lastLogLines = string.Join("\n", Logger.GetSessionLogLastLines(lastLogName, 6).Select(x => $"[{x.Split("]: [").Skip(1).FirstOrDefault()}"));
if (ai.Backend == AI.AiBackend.Pytorch) // Pytorch specific
if (ai.Backend == AiInfo.AiBackend.Pytorch) // Pytorch specific
{
if (line.Contains("ff:nocuda-cpu"))
Logger.Log("WARNING: CUDA-capable GPU device is not available, running on CPU instead!");
@@ -651,7 +651,7 @@ namespace Flowframes.Os
}
}
if (ai.Backend == AI.AiBackend.Ncnn) // NCNN specific
if (ai.Backend == AiInfo.AiBackend.Ncnn) // NCNN specific
{
if (!hasShownError && err && line.MatchesWildcard("vk*Instance* failed"))
{

View File

@@ -158,7 +158,7 @@ namespace Flowframes.Os
if (!Config.GetBool("fetchModelsFromRepo", false))
return;
foreach (AI ai in Implementations.NetworksAll)
foreach (AiInfo ai in Implementations.NetworksAll)
{
try
{

View File

@@ -155,7 +155,7 @@ namespace Flowframes.Ui
public static float ValidateInterpFactor (float factor)
{
AI ai = Program.mainForm.GetAi();
AiInfo ai = Program.mainForm.GetAi();
if (ai.NameInternal == Implementations.rifeNcnn.NameInternal && !Program.mainForm.GetModel(ai).Dir.Contains("v4"))
{
@@ -165,23 +165,23 @@ namespace Flowframes.Ui
return 2;
}
if (ai.FactorSupport == AI.InterpFactorSupport.Fixed)
if (ai.FactorSupport == AiInfo.InterpFactorSupport.Fixed)
{
int closest = ai.SupportedFactors.Min(i => (Math.Abs(factor.RoundToInt() - i), i)).i;
return (float)closest;
}
if(ai.FactorSupport == AI.InterpFactorSupport.AnyPowerOfTwo)
if(ai.FactorSupport == AiInfo.InterpFactorSupport.AnyPowerOfTwo)
{
return ToNearestPow2(factor.RoundToInt()).Clamp(2, 128);
}
if(ai.FactorSupport == AI.InterpFactorSupport.AnyInteger)
if(ai.FactorSupport == AiInfo.InterpFactorSupport.AnyInteger)
{
return factor.RoundToInt().Clamp(2, 128);
}
if(ai.FactorSupport == AI.InterpFactorSupport.AnyFloat)
if(ai.FactorSupport == AiInfo.InterpFactorSupport.AnyFloat)
{
return factor.Clamp(2, 128);
}

View File

@@ -30,7 +30,7 @@ namespace Flowframes.Ui
return true;
}
public static ComboBox LoadAiModelsIntoGui(ComboBox combox, AI ai)
public static ComboBox LoadAiModelsIntoGui(ComboBox combox, AiInfo ai)
{
combox.Items.Clear();

View File

@@ -14,7 +14,7 @@ namespace Flowframes.Utilities
{
class NcnnUtils
{
public static int GetRifeNcnnGpuThreads(Size res, int gpuId, AI ai)
public static int GetRifeNcnnGpuThreads(Size res, int gpuId, AiInfo ai)
{
int threads = Config.GetInt(Config.Key.ncnnThreads);
int maxThreads = VulkanUtils.GetMaxNcnnThreads(gpuId);
@@ -41,7 +41,7 @@ namespace Flowframes.Utilities
return tilesizeStr;
}
public static string GetNcnnThreads(AI ai)
public static string GetNcnnThreads(AiInfo ai)
{
List<int> enabledGpuIds = Config.Get(Config.Key.ncnnGpus).Split(',').Select(s => s.GetInt()).ToList(); // Get GPU IDs
List<int> gpuThreadCounts = enabledGpuIds.Select(g => GetRifeNcnnGpuThreads(new Size(), g, ai)).ToList(); // Get max thread count for each GPU