Files
Mangio-RVC-Fork/tensorlowest.py

129 lines
3.7 KiB
Python
Raw Permalink Normal View History

2023-07-30 16:00:56 +07:00
from tensorboard.backend.event_processing import event_accumulator
import os
from shutil import copy2
from re import search as RSearch
import pandas as pd
from ast import literal_eval as LEval
2023-08-02 00:13:59 +00:00
# Folder (relative to the current working directory) that the command-line
# entry point passes to selectweights() as the place to look for
# "<model>_..._s<step>.pth" weight files.
weights_dir = "weights/"
2023-07-30 16:00:56 +07:00
def find_biggest_tensorboard(tensordir):
    """Return the name of the largest non-empty ``*.0`` file in *tensordir*.

    Only the top level of *tensordir* is searched; sub-directories whose names
    end in ``.0`` are ignored.

    Args:
        tensordir: Directory to search.

    Returns:
        The bare file name (not the full path) of the largest regular file
        whose name ends in ``.0`` and whose size is greater than zero. Ties go
        to the first such file in ``os.listdir`` order. Returns ``None`` (after
        printing a message) when the directory cannot be listed or holds no
        usable ``.0`` file.
    """
    try:
        names = os.listdir(tensordir)
    except (FileNotFoundError, NotADirectoryError):
        print("Couldn't find your model!")
        return None

    candidates = [name for name in names if name.endswith(".0")]
    if not candidates:
        print("No files with the '.0' extension found!")
        return None

    # Insertion order follows listdir order, so max() below keeps the same
    # first-wins tie-breaking as a strict ">" scan.
    sizes = {}
    for name in candidates:
        path = os.path.join(tensordir, name)
        if not os.path.isfile(path):
            continue  # skip sub-directories and other non-regular entries
        try:
            size = os.path.getsize(path)
        except OSError:
            continue  # entry vanished or became unreadable after listing
        if size > 0:
            sizes[name] = size

    if not sizes:
        # Return None (never "") so callers testing `is None` see the failure.
        print("No non-empty files with the '.0' extension found!")
        return None
    return max(sizes, key=sizes.get)
def main(model_name, save_freq, lastmdls=None):
    """Pick the checkpoint steps with the lowest ``loss/g/total`` values.

    Loads the event file chosen by ``find_biggest_tensorboard`` (the largest
    ``.0`` file in ``logs/<model_name>``). Each ``loss/g/total`` scalar's
    value is recorded against its step rounded down to a multiple of
    *save_freq*, but only when that rounded step also appears among the logged
    steps.

    Side effects: sets the module globals ``lowestval_weight_dir``
    (``logs/<model_name>/lowestvals``, set first, even on early return) and
    ``scl`` (the raw scalar events).

    Args:
        model_name: Name of the model folder under ``logs``.
        save_freq: Rounding interval for steps; coerced to ``int`` and
            required to be positive.
        lastmdls: How many of the lowest loss values to keep. ``None`` (the
            default) keeps all of them.

    Returns:
        ``{loss_value: step}`` for the *lastmdls* lowest loss values, ordered
        from lowest to highest loss. Returns ``None`` (after printing a
        message) when *save_freq* is invalid, no event file is found, or the
        file has no ``loss/g/total`` scalars.
    """
    global lowestval_weight_dir, scl

    tensordir = os.path.join("logs", model_name)
    lowestval_weight_dir = os.path.join(tensordir, "lowestvals")

    # UI widgets may hand over floats/strings; a float step would never match
    # the integer "_s<step>.pth" suffix that selectweights() searches for.
    save_freq = int(save_freq)
    if save_freq <= 0:
        print("Save frequency must be a positive integer!")
        return None
    if lastmdls is not None:
        lastmdls = int(lastmdls)

    latest_file = find_biggest_tensorboard(tensordir)
    if not latest_file:  # covers both None and an empty name
        print("Couldn't find a valid tensorboard file!")
        return None

    tfile = os.path.join(tensordir, latest_file)
    ea = event_accumulator.EventAccumulator(
        tfile,
        size_guidance={
            event_accumulator.COMPRESSED_HISTOGRAMS: 500,
            event_accumulator.IMAGES: 4,
            event_accumulator.AUDIO: 4,
            event_accumulator.SCALARS: 0,  # 0 = keep every scalar event
            event_accumulator.HISTOGRAMS: 1,
        },
    )
    ea.Reload()

    tag = "loss/g/total"
    if tag not in ea.Tags().get(event_accumulator.SCALARS, []):
        print(f"No '{tag}' scalars found in {tfile}!")
        return None
    scl = ea.Scalars(tag)

    # Build the set of logged steps once instead of rebuilding a list for every
    # event (which made this loop quadratic).
    logged_steps = {event.step for event in scl}
    listwstep = {}
    for event in scl:
        rounded_step = (event.step // save_freq) * save_freq
        if rounded_step in logged_steps:
            # Keyed by loss value: a later event with an identical value
            # overwrites an earlier one.
            listwstep[float(event.value)] = rounded_step

    lowest_vals = sorted(listwstep)[:lastmdls]
    sorted_dict = {value: listwstep[value] for value in lowest_vals}
    return sorted_dict
def selectweights(model_name, file_dict, weights_dir, lowestval_weight_dir):
    """Copy the weight files for the given checkpoint steps into a folder.

    This is a generator. It yields exactly one ``(log_text, files, table)``
    tuple and then finishes; the same tuple (error text on failure) is also its
    ``return`` value.

    Args:
        model_name: Prefix of the weight file names; matched literally.
        file_dict: ``{loss_value: step}`` mapping, either as a dict or as its
            literal text (parsed with ``ast.literal_eval``).
        weights_dir: Folder searched for ``<model_name>_..._s<step>.pth``.
        lowestval_weight_dir: Destination folder; created if missing.

    Yields:
        ``(log_text, files, table)``. ``log_text`` has one line per copied
        file. ``files`` lists the copied files' paths (with
        *lowestval_weight_dir* resolved against the current working
        directory). ``table`` is a DataFrame indexed by file name with
        ``Values`` and ``Names`` columns. If *file_dict* cannot be parsed into
        a dict, the text starts with ``"Couldn't load tensorboard file!"`` and
        nothing is copied.
    """
    import re  # the module only imports re.search (as RSearch)

    os.makedirs(lowestval_weight_dir, exist_ok=True)
    logdir = []
    files = []
    lbldict = {"Values": {}, "Names": {}}

    low_val_path = os.path.join(os.getcwd(), os.path.join(lowestval_weight_dir, ""))

    # Accept a dict directly; otherwise parse its literal text.
    error = None
    if not isinstance(file_dict, dict):
        try:
            file_dict = LEval(file_dict)
        except Exception as e:  # literal_eval raises several exception types
            error = e
        else:
            if not isinstance(file_dict, dict):
                error = TypeError(f"expected a dict, got {type(file_dict).__name__}")
    if error is not None:
        print(f"Error! {error}")
        message = f"Couldn't load tensorboard file! {error}"
        # Yield (not just return) so generator consumers actually see it.
        yield (message, files, pd.DataFrame(lbldict))
        return message

    # Close the scandir iterator deterministically.
    with os.scandir(weights_dir) as entries:
        weight_names = [entry.name for entry in entries if entry.is_file()]

    escaped_model = re.escape(model_name)
    for key, value in file_dict.items():
        # Escape user-supplied parts so names like "voice(v2)" match literally.
        pattern = rf"^{escaped_model}_.*_s{re.escape(str(value))}\.pth$"
        matching_weights = [name for name in weight_names if RSearch(pattern, name)]
        for weight in matching_weights:
            source_path = os.path.join(weights_dir, weight)
            destination_path = os.path.join(lowestval_weight_dir, weight)

            copy2(source_path, destination_path)
            logdir.append(f"File = {weight} Value: {key}, Step: {value}")

            lbldict["Names"][weight] = weight
            lbldict["Values"][weight] = key
            files.append(low_val_path + weight)
            print(f"File = {weight} Value: {key}, Step: {value}")

    result = ("\n".join(logdir), files, pd.DataFrame(lbldict))
    yield result
    return result
if __name__ == "__main__":
    # Command-line entry point: ask for the model details, pick the lowest-loss
    # checkpoint steps, then copy the matching weight files.
    model = input("Enter the name of the model: ")
    sav_freq = int(input("Enter save frequency of the model: "))
    num_models = int(input("Enter how many lowest-loss models to keep: "))
    ds = main(model, sav_freq, num_models)
    if ds:
        # selectweights() is a generator and does nothing until iterated; it
        # parses the mapping from its literal text, so pass str(ds). It prints
        # each copied file itself.
        for _ in selectweights(model, str(ds), weights_dir, lowestval_weight_dir):
            pass
    elif ds is not None:
        print("No matching checkpoints found.")