diff --git a/TTS/bin/train_vocoder_gan.py b/TTS/bin/train_vocoder_gan.py index 000a545d..95cf612a 100644 --- a/TTS/bin/train_vocoder_gan.py +++ b/TTS/bin/train_vocoder_gan.py @@ -479,7 +479,7 @@ def main(args): # pylint: disable=redefined-outer-name optimizer_gen = getattr(torch.optim, c.optimizer) optimizer_gen = optimizer_gen(lr=c.lr_gen, **c.optimizer_params) optimizer_disc = getattr(torch.optim, c.optimizer) - optimizer_disc= optimizer_disc(lr=c.lr_gen, **c.optimizer_params) + optimizer_disc = optimizer_disc(lr=c.lr_gen, **c.optimizer_params) # schedulers scheduler_gen = None diff --git a/TTS/vocoder/datasets/gan_dataset.py b/TTS/vocoder/datasets/gan_dataset.py index 41744fd9..455ea95c 100644 --- a/TTS/vocoder/datasets/gan_dataset.py +++ b/TTS/vocoder/datasets/gan_dataset.py @@ -120,7 +120,7 @@ class GANDataset(Dataset): else: audio = self.ap.load_wav(wavpath) mel = np.load(feat_path) - audio, mel= self._pad_short_samples(audio, mel) + audio, mel = self._pad_short_samples(audio, mel) # correct the audio length wrt padding applied in stft audio = np.pad(audio, (0, self.hop_len), mode="edge") diff --git a/TTS/vocoder/layers/hifigan.py b/TTS/vocoder/layers/hifigan.py index 942c045d..ffd40588 100644 --- a/TTS/vocoder/layers/hifigan.py +++ b/TTS/vocoder/layers/hifigan.py @@ -56,4 +56,4 @@ class MRF(nn.Module): def remove_weight_norm(self): self.resblock1.remove_weight_norm() self.resblock2.remove_weight_norm() - self.resblock3.remove_weight_norm() \ No newline at end of file + self.resblock3.remove_weight_norm() diff --git a/TTS/vocoder/models/melgan_discriminator.py b/TTS/vocoder/models/melgan_discriminator.py index 8443a3b9..5e32d569 100644 --- a/TTS/vocoder/models/melgan_discriminator.py +++ b/TTS/vocoder/models/melgan_discriminator.py @@ -11,8 +11,7 @@ class MelganDiscriminator(nn.Module): base_channels=16, max_channels=1024, downsample_factors=(4, 4, 4, 4), - groups_denominator=4, - max_groups=256): + groups_denominator=4): super(MelganDiscriminator, self).__init__() self.layers = nn.ModuleList() diff --git a/tests/test_vocoder_losses.py b/tests/test_vocoder_losses.py index 765f67b3..2f38dd5a 100644 --- a/tests/test_vocoder_losses.py +++ b/tests/test_vocoder_losses.py @@ -89,4 +89,3 @@ def test_melgan_feature_loss(): loss_func = MelganFeatureLoss() loss = loss_func(feats_fake, feats_real) assert loss.item() == 0 -