diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 99b8bba5..0af49c1f 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -476,10 +476,13 @@ def main(args): # pylint: disable=redefined-outer-name model_disc = setup_discriminator(c) # setup optimizers - optimizer_gen = RAdam(model_gen.parameters(), lr=c.lr_gen, weight_decay=0) - optimizer_disc = RAdam(model_disc.parameters(), - lr=c.lr_disc, - weight_decay=0) + # TODO: allow loading custom optimizers + optimizer_gen = None + optimizer_disc = None + optimizer_gen = getattr(torch.optim, c.optimizer) + optimizer_gen = optimizer_gen(lr=c.lr_gen, **c.optimizer_params) + optimizer_disc = getattr(torch.optim, c.optimizer) + optimizer_disc= optimizer_disc(lr=c.lr_gen, **c.optimizer_params) # schedulers scheduler_gen = None