From fc0c4600bdf6e69e090bfcb9befe811df18985bb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Eren=20G=C3=B6lge?= Date: Tue, 20 Jul 2021 17:34:42 +0200 Subject: [PATCH] Fix stopnet training --- TTS/tts/layers/losses.py | 15 ++++++++------- TTS/tts/models/base_tts.py | 4 +++- TTS/tts/models/tacotron.py | 2 ++ TTS/tts/models/tacotron2.py | 2 ++ TTS/tts/utils/data.py | 15 ++++++++++++--- tests/data_tests/test_loader.py | 2 +- 6 files changed, 28 insertions(+), 12 deletions(-) diff --git a/TTS/tts/layers/losses.py b/TTS/tts/layers/losses.py index 86d34c30..07b58974 100644 --- a/TTS/tts/layers/losses.py +++ b/TTS/tts/layers/losses.py @@ -246,9 +246,9 @@ class Huber(nn.Module): class TacotronLoss(torch.nn.Module): """Collection of Tacotron set-up based on provided config.""" - def __init__(self, c, stopnet_pos_weight=10, ga_sigma=0.4): + def __init__(self, c, ga_sigma=0.4): super().__init__() - self.stopnet_pos_weight = stopnet_pos_weight + self.stopnet_pos_weight = c.stopnet_pos_weight self.ga_alpha = c.ga_alpha self.decoder_diff_spec_alpha = c.decoder_diff_spec_alpha self.postnet_diff_spec_alpha = c.postnet_diff_spec_alpha @@ -274,7 +274,7 @@ class TacotronLoss(torch.nn.Module): self.criterion_ssim = SSIMLoss() # stopnet loss # pylint: disable=not-callable - self.criterion_st = BCELossMasked(pos_weight=torch.tensor(stopnet_pos_weight)) if c.stopnet else None + self.criterion_st = BCELossMasked(pos_weight=torch.tensor(self.stopnet_pos_weight)) if c.stopnet else None def forward( self, @@ -284,6 +284,7 @@ class TacotronLoss(torch.nn.Module): linear_input, stopnet_output, stopnet_target, + stop_target_length, output_lens, decoder_b_output, alignments, @@ -315,12 +316,12 @@ class TacotronLoss(torch.nn.Module): return_dict["decoder_loss"] = decoder_loss return_dict["postnet_loss"] = postnet_loss - # stopnet loss stop_loss = ( - self.criterion_st(stopnet_output, stopnet_target, output_lens) if self.config.stopnet else torch.zeros(1) + self.criterion_st(stopnet_output, stopnet_target, stop_target_length) + if self.config.stopnet + else torch.zeros(1) ) - if not self.config.separate_stopnet and self.config.stopnet: - loss += stop_loss + loss += stop_loss return_dict["stopnet_loss"] = stop_loss # backward decoder loss (if enabled) diff --git a/TTS/tts/models/base_tts.py b/TTS/tts/models/base_tts.py index 561b76fb..b36ed106 100644 --- a/TTS/tts/models/base_tts.py +++ b/TTS/tts/models/base_tts.py @@ -119,9 +119,10 @@ class BaseTTS(BaseModel): ), f" [!] total duration {dur.sum()} vs spectrogram length {mel_lengths[idx]}" durations[idx, : text_lengths[idx]] = dur - # set stop targets view, we predict a single stop token per iteration. + # set stop targets wrt reduction factor stop_targets = stop_targets.view(text_input.shape[0], stop_targets.size(1) // self.config.r, -1) stop_targets = (stop_targets.sum(2) > 0.0).unsqueeze(2).float().squeeze(2) + stop_target_lengths = torch.divide(mel_lengths, self.config.r).ceil_() return { "text_input": text_input, @@ -131,6 +132,7 @@ class BaseTTS(BaseModel): "mel_lengths": mel_lengths, "linear_input": linear_input, "stop_targets": stop_targets, + "stop_target_lengths": stop_target_lengths, "attn_mask": attn_mask, "durations": durations, "speaker_ids": speaker_ids, diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py index 95b4a358..7949ddf9 100644 --- a/TTS/tts/models/tacotron.py +++ b/TTS/tts/models/tacotron.py @@ -219,6 +219,7 @@ class Tacotron(BaseTacotron): mel_lengths = batch["mel_lengths"] linear_input = batch["linear_input"] stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] speaker_ids = batch["speaker_ids"] d_vectors = batch["d_vectors"] @@ -250,6 +251,7 @@ class Tacotron(BaseTacotron): linear_input, outputs["stop_tokens"], stop_targets, + stop_target_lengths, mel_lengths, outputs["decoder_outputs_backward"], outputs["alignments"], diff --git a/TTS/tts/models/tacotron2.py b/TTS/tts/models/tacotron2.py index eaca3ff8..19619662 100644 --- a/TTS/tts/models/tacotron2.py +++ b/TTS/tts/models/tacotron2.py @@ -224,6 +224,7 @@ class Tacotron2(BaseTacotron): mel_lengths = batch["mel_lengths"] linear_input = batch["linear_input"] stop_targets = batch["stop_targets"] + stop_target_lengths = batch["stop_target_lengths"] speaker_ids = batch["speaker_ids"] d_vectors = batch["d_vectors"] @@ -255,6 +256,7 @@ class Tacotron2(BaseTacotron): linear_input, outputs["stop_tokens"], stop_targets, + stop_target_lengths, mel_lengths, outputs["decoder_outputs_backward"], outputs["alignments"], diff --git a/TTS/tts/utils/data.py b/TTS/tts/utils/data.py index 3ff52195..887f4376 100644 --- a/TTS/tts/utils/data.py +++ b/TTS/tts/utils/data.py @@ -27,10 +27,19 @@ def prepare_tensor(inputs, out_steps): return np.stack([_pad_tensor(x, pad_len) for x in inputs]) -def _pad_stop_target(x, length): - _pad = 0.0 +def _pad_stop_target(x: np.ndarray, length: int, pad_val=1) -> np.ndarray: + """Pad stop target array. + + Args: + x (np.ndarray): Stop target array. + length (int): Length after padding. + pad_val (int, optional): Padding value. Defaults to 1. + + Returns: + np.ndarray: Padded stop target array. + """ assert x.ndim == 1 - return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=_pad) + return np.pad(x, (0, length - x.shape[0]), mode="constant", constant_values=pad_val) def prepare_stop_target(inputs, out_steps): diff --git a/tests/data_tests/test_loader.py b/tests/data_tests/test_loader.py index 9bc70ddd..3fd3eaef 100644 --- a/tests/data_tests/test_loader.py +++ b/tests/data_tests/test_loader.py @@ -207,7 +207,7 @@ class TestTTSDataset(unittest.TestCase): assert linear_input[1 - idx, -1].sum() == 0 assert mel_input[1 - idx, -1].sum() == 0 assert stop_target[1, mel_lengths[1] - 1] == 1 - assert stop_target[1, mel_lengths[1] :].sum() == 0 + assert stop_target[1, mel_lengths[1] :].sum() == stop_target.shape[1] - mel_lengths[1] assert len(mel_lengths.shape) == 1 # check batch zero-frame conditions (zero-frame disabled)