Completely revamped model downloading and file checking

txt -> json
md5 -> crc32
re-check validity after downloading
This commit is contained in:
N00MKRAD
2021-06-20 18:30:59 +02:00
parent eef5f0354a
commit 54c69e4e19

View File

@@ -1,6 +1,7 @@
using Flowframes.Data;
using Flowframes.Main;
using Flowframes.MiscUtils;
using Newtonsoft.Json;
using System;
using System.Collections.Generic;
using System.Diagnostics;
@@ -15,23 +16,15 @@ namespace Flowframes.IO
class ModelDownloader
{
public static async Task<Dictionary<string, string>> GetFilelist (string ai, string model)
{
var client = new WebClient();
string[] fileLines = client.DownloadString(GetMdlFileUrl(ai, model, "md5.txt")).SplitIntoLines();
Dictionary<string, string> filesDict = GetDict(fileLines);
return filesDict;
}
static string GetMdlUrl (string ai, string model)
static string GetMdlUrl (string ai, string relPath)
{
string baseUrl = Config.Get(Config.Key.mdlBaseUrl);
return Path.Combine(baseUrl, ai.ToLower(), model);
return Path.Combine(baseUrl, ai.ToLower(), relPath);
}
static string GetMdlFileUrl(string ai, string model, string file)
static string GetMdlFileUrl(string ai, string model, string relPath)
{
return Path.Combine(GetMdlUrl(ai, model), file);
return Path.Combine(GetMdlUrl(ai, model), relPath);
}
static string GetLocalPath(string ai, string model)
@@ -39,16 +32,22 @@ namespace Flowframes.IO
return Path.Combine(Paths.GetPkgPath(), ai, model);
}
static async Task DownloadTo (string url, string saveDir, int retries = 3)
static async Task DownloadTo (string url, string saveDirOrPath, int retries = 3)
{
string savePath = Path.Combine(saveDir, Path.GetFileName(url));
string savePath = saveDirOrPath;
if (IOUtils.IsPathDirectory(saveDirOrPath))
savePath = Path.Combine(saveDirOrPath, Path.GetFileName(url));
IOUtils.TryDeleteIfExists(savePath);
Directory.CreateDirectory(Path.GetDirectoryName(savePath));
Logger.Log($"Downloading '{url}' to '{savePath}'", true);
Stopwatch sw = new Stopwatch();
sw.Restart();
bool completed = false;
int lastProgPercentage = -1;
var client = new WebClient();
client.DownloadProgressChanged += (sender, args) =>
{
if (sw.ElapsedMilliseconds > 200 && args.ProgressPercentage != lastProgPercentage)
@@ -64,7 +63,9 @@ namespace Flowframes.IO
Logger.Log("Download failed: " + args.Error.Message);
completed = true;
};
client.DownloadFileTaskAsync(url, savePath).ConfigureAwait(false);
while (!completed)
{
if (Interpolate.canceled)
@@ -78,7 +79,7 @@ namespace Flowframes.IO
client.CancelAsync();
if(retries > 0)
{
await DownloadTo(url, saveDir, retries--);
await DownloadTo(url, saveDirOrPath, retries--);
}
else
{
@@ -91,6 +92,38 @@ namespace Flowframes.IO
Logger.Log($"Downloaded '{Path.GetFileName(url)}' ({IOUtils.GetFilesize(savePath) / 1024} KB)", true);
}
class ModelFile
{
public string filename;
public string dir;
public long size;
public string crc32;
}
static List<ModelFile> GetModelFilesFromJson (string json)
{
List<ModelFile> modelFiles = new List<ModelFile>();
try
{
dynamic data = JsonConvert.DeserializeObject(json);
foreach (var item in data)
{
string dirString = ((string)item.dir).Replace(@"\", @"/");
if (dirString.Length > 0 && dirString[0] == '/') dirString = dirString.Remove(0, 1);
long sizeLong = long.Parse((string)item.size);
modelFiles.Add(new ModelFile { filename = item.filename, dir = dirString, size = sizeLong, crc32 = item.crc32 });
}
}
catch (Exception e)
{
Logger.Log($"Failed to parse model file list from JSON: {e.Message}", true);
}
return modelFiles;
}
public static async Task DownloadModelFiles (AI ai, string modelDir)
{
string aiDir = ai.pkgDir;
@@ -105,13 +138,27 @@ namespace Flowframes.IO
Logger.Log($"Downloading '{modelDir}' model files...");
Directory.CreateDirectory(mdlDir);
await DownloadTo(GetMdlFileUrl(aiDir, modelDir, "md5.txt"), mdlDir);
Dictionary<string, string> fileList = await GetFilelist(aiDir, modelDir);
foreach (KeyValuePair<string, string> modelFile in fileList)
await DownloadTo(GetMdlFileUrl(aiDir, modelDir, modelFile.Key), mdlDir);
await DownloadTo(GetMdlFileUrl(aiDir, modelDir, "files.json"), mdlDir);
List<ModelFile> modelFiles = GetModelFilesFromJson(File.ReadAllText(Path.Combine(mdlDir, "files.json")));
if (modelFiles.Count < 1)
{
Interpolate.Cancel($"Error: Can't download model files because no entries were loaded from files.json. Please try again.");
return;
}
foreach (ModelFile mf in modelFiles)
{
string relPath = Path.Combine(mf.dir, mf.filename).Replace("\\", "/");
await DownloadTo(GetMdlFileUrl(aiDir, modelDir, relPath), Path.Combine(mdlDir, relPath));
}
Logger.Log($"Downloaded \"{modelDir}\" model files.", false, true);
if (!AreFilesValid(aiDir, modelDir))
Interpolate.Cancel($"Model files are invalid! Please try again.");
}
catch (Exception e)
{
@@ -161,10 +208,11 @@ namespace Flowframes.IO
return false;
}
if (Debugger.IsAttached) // Disable MD5 check in dev environment
return true;
// TODO UNCOMMENT
//if (Debugger.IsAttached) // Disable MD5 check in dev environment
// return true;
string md5FilePath = Path.Combine(mdlDir, "md5.txt");
string md5FilePath = Path.Combine(mdlDir, "files.json");
if (!File.Exists(md5FilePath) || IOUtils.GetFilesize(md5FilePath) < 32)
{
@@ -172,15 +220,21 @@ namespace Flowframes.IO
return false;
}
string[] md5Lines = IOUtils.ReadLines(md5FilePath);
Dictionary<string, string> filesDict = GetDict(md5Lines);
List<ModelFile> modelFiles = GetModelFilesFromJson(File.ReadAllText(Path.Combine(mdlDir, "files.json")));
foreach(KeyValuePair<string, string> file in filesDict)
if (modelFiles.Count < 1)
{
string md5 = IOUtils.GetHash(Path.Combine(mdlDir, file.Key), IOUtils.Hash.MD5);
if (md5.Trim() != file.Value.Trim())
Logger.Log($"Files for model {model} not valid: JSON contains {modelFiles.Count} entries.", true);
return false;
}
foreach (ModelFile mf in modelFiles)
{
string crc = IOUtils.GetHash(Path.Combine(mdlDir, mf.dir, mf.filename), IOUtils.Hash.CRC32);
if (crc.Trim() != mf.crc32.Trim())
{
Logger.Log($"Files for model {model} not valid: MD5 of {file.Key} ({md5.Trim()}) does not equal validation MD5 ({file.Value.Trim()}).", true);
Logger.Log($"Files for model {model} not valid: CRC32 of {mf.filename} ({crc.Trim()}) does not equal validation CRC32 ({mf.crc32.Trim()}).", true);
return false;
}
}