diff --git a/infer-web.py b/infer-web.py index ff88efa..faa15ed 100644 --- a/infer-web.py +++ b/infer-web.py @@ -2,6 +2,7 @@ import os import shutil import sys import json # Mangio fork using json for preset saving +import math import signal @@ -943,6 +944,23 @@ def change_f0(if_f0_3, sr2, version19): # f0method8,pretrained_G14,pretrained_D ) +global log_interval + + +def set_log_interval(exp_dir, batch_size12): + log_interval = 1 + + folder_path = os.path.join(exp_dir, "1_16k_wavs") + + if os.path.exists(folder_path) and os.path.isdir(folder_path): + wav_files = [f for f in os.listdir(folder_path) if f.endswith(".wav")] + if wav_files: + sample_size = len(wav_files) + log_interval = math.ceil(sample_size / batch_size12) + + return log_interval + + # but3.click(click_train,[exp_dir1,sr2,if_f0_3,save_epoch10,total_epoch11,batch_size12,if_save_latest13,pretrained_G14,pretrained_D15,gpus16]) def click_train( exp_dir1, @@ -969,6 +987,9 @@ def click_train( if version19 == "v1" else "%s/3_feature768" % (exp_dir) ) + + log_interval = set_log_interval(exp_dir, batch_size12) + if if_f0_3: f0_dir = "%s/2a_f0" % (exp_dir) f0nsf_dir = "%s/2b-f0nsf" % (exp_dir) @@ -1038,7 +1059,7 @@ def click_train( #### cmd = ( config.python_cmd - + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s" + + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -g %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s -li %s" % ( exp_dir1, sr2, @@ -1053,12 +1074,13 @@ def click_train( 1 if if_cache_gpu17 == i18n("是") else 0, 1 if if_save_every_weights18 == i18n("是") else 0, version19, + log_interval, ) ) else: cmd = ( config.python_cmd - + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s" + + " train_nsf_sim_cache_sid_load_pretrain.py -e %s -sr %s -f0 %s -bs %s -te %s -se %s %s %s -l %s -c %s -sw %s -v %s -li %s" % ( exp_dir1, sr2, @@ -1072,6 +1094,7 @@ def click_train( 1 if if_cache_gpu17 == i18n("是") else 0, 1 if if_save_every_weights18 == i18n("是") else 0, version19, + log_interval, ) ) print(cmd) diff --git a/train/utils.py b/train/utils.py index 8884e43..783f251 100644 --- a/train/utils.py +++ b/train/utils.py @@ -352,6 +352,9 @@ def get_hparams(init=True): required=True, help="if caching the dataset in GPU memory, 1 or 0", ) + parser.add_argument( + "-li", "--log_interval", type=int, required=True, help="log interval" + ) args = parser.parse_args() name = args.experiment_dir @@ -391,6 +394,16 @@ def get_hparams(init=True): hparams.save_every_weights = args.save_every_weights hparams.if_cache_data_in_gpu = args.if_cache_data_in_gpu hparams.data.training_files = "%s/filelist.txt" % experiment_dir + + hparams.train.log_interval = args.log_interval + + # Update log_interval in the 'train' section of the config dictionary + config["train"]["log_interval"] = args.log_interval + + # Save the updated config back to the config_save_path + with open(config_save_path, "w") as f: + json.dump(config, f, indent=4) + return hparams