diff --git a/TTS/tts/utils/generic_utils.py b/TTS/tts/utils/generic_utils.py index c58f37c9..d2725eee 100644 --- a/TTS/tts/utils/generic_utils.py +++ b/TTS/tts/utils/generic_utils.py @@ -242,32 +242,32 @@ def check_config_tts(c): check_argument("trim_db", c["audio"], restricted=True, val_type=int) # training parameters - check_argument("batch_size", c, restricted=True, val_type=int, min_val=1) - check_argument("eval_batch_size", c, restricted=True, val_type=int, min_val=1) - check_argument("r", c, restricted=True, val_type=int, min_val=1) - check_argument("gradual_training", c, restricted=False, val_type=list) - check_argument("mixed_precision", c, restricted=False, val_type=bool) + # check_argument('batch_size', c, restricted=True, val_type=int, min_val=1) + # check_argument('eval_batch_size', c, restricted=True, val_type=int, min_val=1) + check_argument('r', c, restricted=True, val_type=int, min_val=1) + check_argument('gradual_training', c, restricted=False, val_type=list) + # check_argument('mixed_precision', c, restricted=False, val_type=bool) # check_argument('grad_accum', c, restricted=True, val_type=int, min_val=1, max_val=100) # loss parameters - check_argument("loss_masking", c, restricted=True, val_type=bool) - if c["model"].lower() in ["tacotron", "tacotron2"]: - check_argument("decoder_loss_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("postnet_loss_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("postnet_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("decoder_diff_spec_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("decoder_ssim_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("postnet_ssim_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("ga_alpha", c, restricted=True, val_type=float, min_val=0) - if c["model"].lower in ["speedy_speech", "align_tts"]: - check_argument("ssim_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("l1_alpha", c, restricted=True, val_type=float, min_val=0) - check_argument("huber_alpha", c, restricted=True, val_type=float, min_val=0) + # check_argument('loss_masking', c, restricted=True, val_type=bool) + if c['model'].lower() in ['tacotron', 'tacotron2']: + check_argument('decoder_loss_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('postnet_loss_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('postnet_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('decoder_diff_spec_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('decoder_ssim_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('postnet_ssim_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('ga_alpha', c, restricted=True, val_type=float, min_val=0) + if c['model'].lower in ["speedy_speech", "align_tts"]: + check_argument('ssim_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('l1_alpha', c, restricted=True, val_type=float, min_val=0) + check_argument('huber_alpha', c, restricted=True, val_type=float, min_val=0) # validation parameters - check_argument("run_eval", c, restricted=True, val_type=bool) - check_argument("test_delay_epochs", c, restricted=True, val_type=int, min_val=0) - check_argument("test_sentences_file", c, restricted=False, val_type=str) + # check_argument('run_eval', c, restricted=True, val_type=bool) + # check_argument('test_delay_epochs', c, restricted=True, val_type=int, min_val=0) + # check_argument('test_sentences_file', c, restricted=False, val_type=str) # optimizer check_argument("noam_schedule", c, restricted=False, val_type=bool) @@ -319,24 +319,23 @@ def check_config_tts(c): check_argument("encoder_type", c, restricted=not is_tacotron(c), val_type=str) # tensorboard - check_argument("print_step", c, restricted=True, val_type=int, min_val=1) - check_argument("tb_plot_step", c, restricted=True, val_type=int, min_val=1) - check_argument("save_step", c, restricted=True, val_type=int, min_val=1) - check_argument("checkpoint", c, restricted=True, val_type=bool) - check_argument("tb_model_param_stats", c, restricted=True, val_type=bool) + # check_argument('print_step', c, restricted=True, val_type=int, min_val=1) + # check_argument('tb_plot_step', c, restricted=True, val_type=int, min_val=1) + # check_argument('save_step', c, restricted=True, val_type=int, min_val=1) + # check_argument('checkpoint', c, restricted=True, val_type=bool) + # check_argument('tb_model_param_stats', c, restricted=True, val_type=bool) # dataloading # pylint: disable=import-outside-toplevel from TTS.tts.utils.text import cleaners - - check_argument("text_cleaner", c, restricted=True, val_type=str, enum_list=dir(cleaners)) - check_argument("enable_eos_bos_chars", c, restricted=True, val_type=bool) - check_argument("num_loader_workers", c, restricted=True, val_type=int, min_val=0) - check_argument("num_val_loader_workers", c, restricted=True, val_type=int, min_val=0) - check_argument("batch_group_size", c, restricted=True, val_type=int, min_val=0) - check_argument("min_seq_len", c, restricted=True, val_type=int, min_val=0) - check_argument("max_seq_len", c, restricted=True, val_type=int, min_val=10) - check_argument("compute_input_seq_cache", c, restricted=True, val_type=bool) + # check_argument('text_cleaner', c, restricted=True, val_type=str, enum_list=dir(cleaners)) + # check_argument('enable_eos_bos_chars', c, restricted=True, val_type=bool) + # check_argument('num_loader_workers', c, restricted=True, val_type=int, min_val=0) + # check_argument('num_val_loader_workers', c, restricted=True, val_type=int, min_val=0) + # check_argument('batch_group_size', c, restricted=True, val_type=int, min_val=0) + # check_argument('min_seq_len', c, restricted=True, val_type=int, min_val=0) + # check_argument('max_seq_len', c, restricted=True, val_type=int, min_val=10) + # check_argument('compute_input_seq_cache', c, restricted=True, val_type=bool) # paths check_argument("output_path", c, restricted=True, val_type=str) diff --git a/TTS/utils/generic_utils.py b/TTS/utils/generic_utils.py index 140cf811..a3a604df 100644 --- a/TTS/utils/generic_utils.py +++ b/TTS/utils/generic_utils.py @@ -5,6 +5,7 @@ import shutil import subprocess import sys from pathlib import Path +from typing import List def get_git_branch(): @@ -139,7 +140,21 @@ class KeepAverage: self.update_value(key, value) -def check_argument(name, c, enum_list=None, max_val=None, min_val=None, restricted=False, alternative=None, allow_none=False): +def check_argument(name, + c, + prerequest=None, + enum_list=None, + max_val=None, + min_val=None, + restricted=False, + alternative=None, + allow_none=False): + if isinstance(prerequest, List()): + if any([f not in c.keys() for f in prerequest]): + return + else: + if prerequest not in c.keys(): + return if alternative in c.keys() and c[alternative] is not None: return if allow_none and c[name] is None: