diff --git a/TTS/tts/layers/glow_tts/transformer.py b/TTS/tts/layers/glow_tts/transformer.py
index 81160006..4feadc80 100644
--- a/TTS/tts/layers/glow_tts/transformer.py
+++ b/TTS/tts/layers/glow_tts/transformer.py
@@ -391,6 +391,10 @@ class RelativePositionTransformer(nn.Module):
             y = self.ffn_layers[i](x, x_mask)
             y = self.dropout(y)
+
+            if (i + 1) == self.num_layers and hasattr(self, 'proj'):
+                x = self.proj(x)
+
             x = self.norm_layers_2[i](x + y)
         x = x * x_mask
         return x
diff --git a/TTS/tts/layers/speedy_speech/decoder.py b/TTS/tts/layers/speedy_speech/decoder.py
index 9bbb047b..5ffb3339 100644
--- a/TTS/tts/layers/speedy_speech/decoder.py
+++ b/TTS/tts/layers/speedy_speech/decoder.py
@@ -1,9 +1,136 @@
+import torch
 from torch import nn
 
-from TTS.tts.layers.generic.res_conv_bn import ConvBNBlock, ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import Conv1dBNBlock, ResidualConv1dBNBlock, Conv1dBN
 from TTS.tts.layers.generic.wavenet import WNBlocks
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
 
 
+class WaveNetDecoder(nn.Module):
+    """WaveNet based decoder with a prenet and a postnet.
+
+    prenet: conv1d_1x1
+    postnet: 3 x [conv1d_1x1 -> relu] -> conv1d_1x1
+
+    TODO: Integrate speaker conditioning vector.
+
+    Note:
+        Default WaveNet parameters:
+            params = {
+                "num_blocks": 12,
+                "hidden_channels": 192,
+                "kernel_size": 5,
+                "dilation_rate": 1,
+                "num_layers": 4,
+                "dropout_p": 0.05
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels for the prenet and postnet.
+        c_in_channels (int): number of channels of the conditioning vector.
+        params (dict): dictionary of the WaveNet block parameters.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, c_in_channels, params):
+        super().__init__()
+        # prenet
+        self.prenet = torch.nn.Conv1d(in_channels, params['hidden_channels'], 1)
+        # wavenet layers
+        self.wn = WNBlocks(params['hidden_channels'], c_in_channels=c_in_channels, **params)
+        # postnet
+        self.postnet = [
+            torch.nn.Conv1d(params['hidden_channels'], hidden_channels, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(hidden_channels, hidden_channels, 1),
+            torch.nn.ReLU(),
+            torch.nn.Conv1d(hidden_channels, out_channels, 1),
+        ]
+        self.postnet = nn.Sequential(*self.postnet)
+
+    def forward(self, x, x_mask=None, g=None):
+        x = self.prenet(x) * x_mask
+        x = self.wn(x, x_mask, g)
+        o = self.postnet(x) * x_mask
+        return o
+
+
+class RelativePositionTransformerDecoder(nn.Module):
+    """Decoder with Relative Positional Transformer.
+
+    Note:
+        Default params:
+            params = {
+                'hidden_channels_ffn': 128,
+                'num_heads': 2,
+                "kernel_size": 3,
+                "dropout_p": 0.1,
+                "num_layers": 8,
+                "rel_attn_window_size": 4,
+                "input_length": None
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels, also used by the Transformer layers.
+        params (dict): dictionary of the transformer parameters.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.prenet = Conv1dBN(in_channels, hidden_channels, 1, 1)
+        self.rel_pos_transformer = RelativePositionTransformer(
+            in_channels, out_channels, hidden_channels, **params)
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        o = self.prenet(x) * x_mask
+        o = self.rel_pos_transformer(o, x_mask)
+        return o
+
+
+class ResidualConv1dBNDecoder(nn.Module):
+    """Residual Convolutional Decoder as in the original Speedy Speech paper.
+
+    TODO: Integrate speaker conditioning vector.
+
+    Note:
+        Default params:
+            params = {
+                "kernel_size": 4,
+                "dilations": 4 * [1, 2, 4, 8] + [1],
+                "num_conv_blocks": 2,
+                "num_res_blocks": 17
+            }
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels used by the ResidualConv1dBNBlock layers.
+        params (dict): dictionary for the residual convolutional blocks.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.res_conv_block = ResidualConv1dBNBlock(in_channels,
+                                                    hidden_channels,
+                                                    hidden_channels, **params)
+        self.post_conv = nn.Conv1d(hidden_channels, hidden_channels, 1)
+        self.postnet = nn.Sequential(
+            Conv1dBNBlock(hidden_channels,
+                          hidden_channels,
+                          hidden_channels,
+                          params['kernel_size'],
+                          1,
+                          num_conv_blocks=2),
+            nn.Conv1d(hidden_channels, out_channels, 1),
+        )
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        o = self.res_conv_block(x, x_mask)
+        o = self.post_conv(o) + x
+        return self.postnet(o) * x_mask
+
+
 class Decoder(nn.Module):
     """Decodes the expanded phoneme encoding into spectrograms
     Args:
@@ -15,39 +142,8 @@ class Decoder(nn.Module):
     Shapes:
         - input: (B, C, T)
-
-    Note:
-        Default decoder_params...
-
-        for 'transformer'
-            decoder_params={
-                'hidden_channels_ffn': 128,
-                'num_heads': 2,
-                "kernel_size": 3,
-                "dropout_p": 0.1,
-                "num_layers": 8,
-                "rel_attn_window_size": 4,
-                "input_length": None
-            },
-
-        for 'residual_conv_bn'
-            decoder_params = {
-                "kernel_size": 4,
-                "dilations": 4 * [1, 2, 4, 8] + [1],
-                "num_conv_blocks": 2,
-                "num_res_blocks": 17
-            }
-
-        for 'wavenet'
-            decoder_params = {
-                "num_blocks": 12,
-                "hidden_channels":192,
-                "kernel_size": 5,
-                "dilation_rate": 1,
-                "num_layers": 4,
-                "dropout_p": 0.05
-            }
     """
+    # pylint: disable=dangerous-default-value
     def __init__(
         self,
@@ -62,28 +158,35 @@ class Decoder(nn.Module):
         },
         c_in_channels=0):
         super().__init__()
-        self.in_channels = in_hidden_channels
-        self.hidden_channels = in_hidden_channels
-        self.out_channels = out_channels
 
         if decoder_type == 'transformer':
-            self.decoder = RelativePositionTransformer(self.hidden_channels, **decoder_params)
+            self.decoder = RelativePositionTransformerDecoder(
+                in_channels=in_hidden_channels,
+                out_channels=out_channels,
+                hidden_channels=in_hidden_channels,
+                params=decoder_params)
         elif decoder_type == 'residual_conv_bn':
-            self.decoder = ResidualConvBNBlock(self.hidden_channels,
-                                               **decoder_params)
+            self.decoder = ResidualConv1dBNDecoder(
+                in_channels=in_hidden_channels,
+                out_channels=out_channels,
+                hidden_channels=in_hidden_channels,
+                params=decoder_params)
         elif decoder_type == 'wavenet':
-            self.decoder = WNBlocks(in_channels=self.in_channels, hidden_channels=self.hidden_channels, **decoder_params)
+            self.decoder = WaveNetDecoder(in_channels=in_hidden_channels,
+                                          out_channels=out_channels,
+                                          hidden_channels=in_hidden_channels,
+                                          c_in_channels=c_in_channels,
+                                          params=decoder_params)
         else:
             raise ValueError(f'[!] Unknown decoder type - {decoder_type}')
 
-        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels, 1)
-        self.post_net = nn.Sequential(
-            ConvBNBlock(self.hidden_channels, decoder_params['kernel_size'], 1, num_conv_blocks=2),
-            nn.Conv1d(self.hidden_channels, out_channels, 1),
-        )
-
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
+        """
+        Args:
+            x: [B, C, T]
+            x_mask: [B, 1, T]
+            g: [B, C_g, 1]
+        """
         # TODO: implement multi-speaker
-        o = self.decoder(x, x_mask)
-        o = self.post_conv(o) + x
-        return self.post_net(o) * x_mask
+        o = self.decoder(x, x_mask, g)
+        return o
\ No newline at end of file
diff --git a/TTS/tts/layers/speedy_speech/duration_predictor.py b/TTS/tts/layers/speedy_speech/duration_predictor.py
index 153a6a49..5c5c4f3a 100644
--- a/TTS/tts/layers/speedy_speech/duration_predictor.py
+++ b/TTS/tts/layers/speedy_speech/duration_predictor.py
@@ -1,6 +1,6 @@
 from torch import nn
 
-from TTS.tts.layers.generic.res_conv_bn import ConvBN
+from TTS.tts.layers.generic.res_conv_bn import Conv1dBN
 
 
 class DurationPredictor(nn.Module):
@@ -21,9 +21,9 @@ class DurationPredictor(nn.Module):
         super().__init__()
 
         self.layers = nn.ModuleList([
-            ConvBN(hidden_channels, 4, 1),
-            ConvBN(hidden_channels, 3, 1),
-            ConvBN(hidden_channels, 1, 1),
+            Conv1dBN(hidden_channels, hidden_channels, 4, 1),
+            Conv1dBN(hidden_channels, hidden_channels, 3, 1),
+            Conv1dBN(hidden_channels, hidden_channels, 1, 1),
             nn.Conv1d(hidden_channels, 1, 1)
         ])
diff --git a/TTS/tts/layers/speedy_speech/encoder.py b/TTS/tts/layers/speedy_speech/encoder.py
index 02468626..d26b306c 100644
--- a/TTS/tts/layers/speedy_speech/encoder.py
+++ b/TTS/tts/layers/speedy_speech/encoder.py
@@ -1,11 +1,10 @@
 import math
 import torch
 from torch import nn
-from torch.nn import functional as F
 
 from TTS.tts.layers.glow_tts.transformer import RelativePositionTransformer
-from TTS.tts.layers.glow_tts.glow import ConvLayerNorm
-from TTS.tts.layers.generic.res_conv_bn import ResidualConvBNBlock
+from TTS.tts.layers.generic.res_conv_bn import ResidualConv1dBNBlock
+
 
 
 class PositionalEncoding(nn.Module):
@@ -18,12 +17,13 @@ class PositionalEncoding(nn.Module):
     def __init__(self, channels, dropout=0.0, max_len=5000):
         super().__init__()
         if channels % 2 != 0:
-            raise ValueError("Cannot use sin/cos positional encoding with "
-                             "odd channels (got channels={:d})".format(channels))
+            raise ValueError(
+                "Cannot use sin/cos positional encoding with "
+                "odd channels (got channels={:d})".format(channels))
         pe = torch.zeros(max_len, channels)
         position = torch.arange(0, max_len).unsqueeze(1)
         div_term = torch.exp((torch.arange(0, channels, 2, dtype=torch.float) *
-                              -(math.log(10000.0) / channels)))
+                             -(math.log(10000.0) / channels)))
         pe[:, 0::2] = torch.sin(position.float() * div_term)
         pe[:, 1::2] = torch.cos(position.float() * div_term)
         pe = pe.unsqueeze(0).transpose(1, 2)
@@ -59,9 +59,77 @@ class PositionalEncoding(nn.Module):
         return x
 
 
+class RelativePositionTransformerEncoder(nn.Module):
+    """Speedy Speech encoder built on a Transformer with relative position encoding.
+
+    TODO: Integrate speaker conditioning vector.
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels.
+        params (dict): dictionary of the transformer parameters.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.prenet = ResidualConv1dBNBlock(in_channels,
+                                            hidden_channels,
+                                            hidden_channels,
+                                            kernel_size=5,
+                                            num_res_blocks=3,
+                                            num_conv_blocks=1,
+                                            dilations=[1, 1, 1])
+        self.rel_pos_transformer = RelativePositionTransformer(
+            hidden_channels, out_channels, hidden_channels, **params)
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        if x_mask is None:
+            x_mask = 1
+        o = self.prenet(x) * x_mask
+        o = self.rel_pos_transformer(o, x_mask)
+        return o
+
+
+class ResidualConv1dBNEncoder(nn.Module):
+    """Residual Convolutional Encoder as in the original Speedy Speech paper.
+
+    TODO: Integrate speaker conditioning vector.
+
+    Args:
+        in_channels (int): number of input channels.
+        out_channels (int): number of output channels.
+        hidden_channels (int): number of hidden channels.
+        params (dict): dictionary for the residual convolutional blocks.
+    """
+    def __init__(self, in_channels, out_channels, hidden_channels, params):
+        super().__init__()
+        self.prenet = nn.Sequential(
+            nn.Conv1d(in_channels, hidden_channels, 1),
+            nn.ReLU())
+        self.res_conv_block = ResidualConv1dBNBlock(hidden_channels,
+                                                    hidden_channels,
+                                                    hidden_channels, **params)
+
+        self.postnet = nn.Sequential(*[
+            nn.Conv1d(hidden_channels, hidden_channels, 1),
+            nn.ReLU(),
+            nn.BatchNorm1d(hidden_channels),
+            nn.Conv1d(hidden_channels, out_channels, 1)
+        ])
+
+    def forward(self, x, x_mask=None, g=None):  # pylint: disable=unused-argument
+        if x_mask is None:
+            x_mask = 1
+        o = self.prenet(x) * x_mask
+        o = self.res_conv_block(o, x_mask)
+        o = self.postnet(o + x) * x_mask
+        return o * x_mask
+
+
 class Encoder(nn.Module):
     # pylint: disable=dangerous-default-value
-    """Speedy-Speech encoder using Transformers or Residual BN Convs internally.
+    """Factory class for the Speedy Speech encoder that enables different encoder types internally.
 
     Args:
         num_chars (int): number of characters.
@@ -114,29 +182,21 @@ class Encoder(nn.Module):
 
         # init encoder
         if encoder_type.lower() == "transformer":
-            # optional convolutional prenet
-            self.pre = ConvLayerNorm(self.in_channels,
-                                     self.hidden_channels,
-                                     self.hidden_channels,
-                                     kernel_size=5,
-                                     num_layers=3,
-                                     dropout_p=0.5)
             # text encoder
-            self.encoder = RelativePositionTransformer(self.hidden_channels, **encoder_params)  # pylint: disable=unexpected-keyword-arg
+            self.encoder = RelativePositionTransformerEncoder(in_hidden_channels,
+                                                              out_channels,
+                                                              in_hidden_channels,
+                                                              encoder_params)  # pylint: disable=unexpected-keyword-arg
         elif encoder_type.lower() == 'residual_conv_bn':
-            self.pre = nn.Sequential(
-                nn.Conv1d(self.in_channels, self.hidden_channels, 1),
-                nn.ReLU())
-            self.encoder = ResidualConvBNBlock(self.hidden_channels,
-                                               **encoder_params)
+            self.encoder = ResidualConv1dBNEncoder(in_hidden_channels,
+                                                   out_channels,
+                                                   in_hidden_channels,
+                                                   encoder_params)
         else:
-            raise NotImplementedError(' [!] encoder type not implemented.')
+            raise NotImplementedError(' [!] unknown encoder type.')
 
-        # final projection layers
-        self.post_conv = nn.Conv1d(self.hidden_channels, self.hidden_channels,
-                                   1)
-        self.post_bn = nn.BatchNorm1d(self.hidden_channels)
-        self.post_conv2 = nn.Conv1d(self.hidden_channels, self.out_channels, 1)
+
 
     def forward(self, x, x_mask, g=None):  # pylint: disable=unused-argument
         """
@@ -145,15 +205,5 @@ class Encoder(nn.Module):
         x_mask: [B, 1, T]
         g: [B, C, 1]
         """
-        # TODO: implement multi-speaker
-        if self.encoder_type == 'transformer':
-            o = self.pre(x, x_mask)
-        else:
-            o = self.pre(x) * x_mask
-        o = self.encoder(o, x_mask)
-        o = self.post_conv(o + x)
-        o = F.relu(o)
-        o = self.post_bn(o)
-        o = self.post_conv2(o)
-        # [B, C, T]
+        o = self.encoder(x, x_mask)
         return o * x_mask
diff --git a/TTS/tts/layers/tacotron.py b/TTS/tts/layers/tacotron.py
index 807282b3..c79edcc3 100644
--- a/TTS/tts/layers/tacotron.py
+++ b/TTS/tts/layers/tacotron.py
@@ -1,7 +1,8 @@
 # coding: utf-8
 import torch
 from torch import nn
-from .common_layers import Prenet, init_attn
+from .common_layers import Prenet
+from .attentions import init_attn
 
 
 class BatchNormConv1d(nn.Module):
diff --git a/TTS/tts/layers/tacotron2.py b/TTS/tts/layers/tacotron2.py
index a02db784..8e6dbc15 100644
--- a/TTS/tts/layers/tacotron2.py
+++ b/TTS/tts/layers/tacotron2.py
@@ -1,7 +1,8 @@
 import torch
 from torch import nn
 from torch.nn import functional as F
-from .common_layers import init_attn, Prenet, Linear
+from .common_layers import Prenet, Linear
+from .attentions import init_attn
 
 # NOTE: linter has a problem with the current TF release
 #pylint: disable=no-value-for-parameter
diff --git a/TTS/tts/models/tacotron.py b/TTS/tts/models/tacotron.py
index 61eea893..0b68a96c 100644
--- a/TTS/tts/models/tacotron.py
+++ b/TTS/tts/models/tacotron.py
@@ -18,7 +18,7 @@ class Tacotron(TacotronAbstract):
         r (int): initial model reduction rate.
         postnet_output_dim (int, optional): postnet output channels. Defaults to 80.
         decoder_output_dim (int, optional): decoder output channels. Defaults to 80.
-        attn_type (str, optional): attention type. Check ```TTS.tts.layers.common_layers.init_attn```. Defaults to 'original'.
+        attn_type (str, optional): attention type. Check ```TTS.tts.layers.attentions.init_attn```. Defaults to 'original'.
         attn_win (bool, optional): enable/disable attention windowing.
             It especially useful at inference to keep attention alignment diagonal. Defaults to False.
         attn_norm (str, optional): Attention normalization method. "sigmoid" or "softmax".
             Defaults to "softmax".
diff --git a/TTS/tts/tf/layers/tacotron2.py b/TTS/tts/tf/layers/tacotron2.py
index 20d5f9a4..50a766a9 100644
--- a/TTS/tts/tf/layers/tacotron2.py
+++ b/TTS/tts/tf/layers/tacotron2.py
@@ -2,7 +2,7 @@ import tensorflow as tf
 from tensorflow import keras
 from TTS.tts.tf.utils.tf_utils import shape_list
 from TTS.tts.tf.layers.common_layers import Prenet, Attention
-# from tensorflow_addons.seq2seq import AttentionWrapper
+
 
 # NOTE: linter has a problem with the current TF release
 #pylint: disable=no-value-for-parameter
diff --git a/tests/inputs/test_speedy_speech.json b/tests/inputs/test_speedy_speech.json
index 2a4b3a45..ae4b8b2d 100644
--- a/tests/inputs/test_speedy_speech.json
+++ b/tests/inputs/test_speedy_speech.json
@@ -37,7 +37,7 @@
         "symmetric_norm": true, // move normalization to range [-1, 1]
         "max_norm": 4.0,        // scale normalization to range [-max_norm, max_norm] or [0, max_norm]
         "clip_norm": true,      // clip normalized values into the range.
-        "stats_path": "/home/erogol/Data/LJSpeech-1.1/scale_stats.npy" // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based notmalization is used and other normalization params are ignored
+        "stats_path": null // DO NOT USE WITH MULTI_SPEAKER MODEL. scaler stats file computed by 'compute_statistics.py'. If it is defined, mean-std based normalization is used and other normalization params are ignored
     },
 
     // VOCABULARY PARAMETERS
diff --git a/tests/test_speedy_speech_layers.py b/tests/test_speedy_speech_layers.py
index 33a5e615..a5567ac3 100644
--- a/tests/test_speedy_speech_layers.py
+++ b/tests/test_speedy_speech_layers.py
@@ -50,10 +50,44 @@ def test_decoder():
     input_mask = torch.unsqueeze(
         sequence_mask(input_lengths, input_dummy.size(2)), 1).to(device)
 
+    # residual bn conv decoder
     layer = Decoder(out_channels=11, in_hidden_channels=128).to(device)
     output = layer(input_dummy, input_mask)
     assert list(output.shape) == [8, 11, 37]
 
+    # transformer decoder
+    layer = Decoder(out_channels=11,
+                    in_hidden_channels=128,
+                    decoder_type='transformer',
+                    decoder_params={
+                        'hidden_channels_ffn': 128,
+                        'num_heads': 2,
+                        "kernel_size": 3,
+                        "dropout_p": 0.1,
+                        "num_layers": 8,
+                        "rel_attn_window_size": 4,
+                        "input_length": None
+                    }).to(device)
+    output = layer(input_dummy, input_mask)
+    assert list(output.shape) == [8, 11, 37]
+
+    # wavenet decoder
+    layer = Decoder(out_channels=11,
+                    in_hidden_channels=128,
+                    decoder_type='wavenet',
+                    decoder_params={
+                        "num_blocks": 12,
+                        "hidden_channels": 192,
+                        "kernel_size": 5,
+                        "dilation_rate": 1,
+                        "num_layers": 4,
+                        "dropout_p": 0.05
+                    }).to(device)
+    output = layer(input_dummy, input_mask)
+    assert list(output.shape) == [8, 11, 37]
+
 
 def test_duration_predictor():
     input_dummy = torch.rand(8, 128, 27).to(device)
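
Usage sketch, distilled from test_decoder above (the all-ones mask is a stand-in for the sequence_mask-based mask the tests build):

    import torch
    from TTS.tts.layers.speedy_speech.decoder import Decoder

    # dummy batch: 8 sequences, 128 encoder channels, 37 frames
    x = torch.rand(8, 128, 37)
    # the tests derive x_mask from real lengths via sequence_mask();
    # an all-ones mask keeps this sketch self-contained
    x_mask = torch.ones(8, 1, 37)

    # Decoder is now a thin factory: decoder_type selects between
    # ResidualConv1dBNDecoder ('residual_conv_bn', the default),
    # RelativePositionTransformerDecoder ('transformer') and
    # WaveNetDecoder ('wavenet'); each variant owns its own pre/postnet
    decoder = Decoder(out_channels=11, in_hidden_channels=128)
    o = decoder(x, x_mask)
    assert list(o.shape) == [8, 11, 37]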