diff --git a/.gitignore b/.gitignore
index c8a1c717..cc9ef477 100644
--- a/.gitignore
+++ b/.gitignore
@@ -24,6 +24,7 @@ wheels/
 .installed.cfg
 *.egg
 /package
+/temp
 MANIFEST
 
 # PyInstaller
@@ -123,3 +124,7 @@ replace.sh
 
 # Pytorch
 *.pth
+
+
+# audio
+*.wav
diff --git a/docs/source/faq.md b/docs/source/faq.md
index a93fafdc..6ed3b305 100644
--- a/docs/source/faq.md
+++ b/docs/source/faq.md
@@ -29,3 +29,15 @@ reference: [https://huggingface.co/docs/tokenizers/installation#installation-fro
 > ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
 
 Due to version incompatibilities between dependency libraries, version conflicts may occur; in most cases this does not affect normal operation.
+
+### 3. Version error when installing pytorch
+
+> ERROR: Ignored the following versions that require a different python version: 1.1.0 Requires-Python >=3.8; 1.1.0rc1 Requires-Python >=3.8; 1.1.1 Requires-Python >=3.8
+> ERROR: Could not find a version that satisfies the requirement torch==1.8.1+cu111 (from versions: 1.0.0, 1.0.1, 1.0.1.post2, 1.1.0, 1.2.0, 1.3.0, 1.3.1, 1.4.0, 1.5.0, 1.5.1, 1.6.0, 1.7.0, 1.7.1, 1.8.0, 1.8.1, 1.9.0, 1.9.1, 1.10.0, 1.10.1, 1.10.2, 1.11.0)
+> ERROR: No matching distribution found for torch==1.8.1+cu111
+
+Use the following command to install:
+
+```shell
+pip install -f https://download.pytorch.org/whl/torch_stable.html -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
+```
diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md
index 0f4cbbc3..7148f27f 100644
--- a/docs/source/quick_start.md
+++ b/docs/source/quick_start.md
@@ -25,6 +25,10 @@ ModelScope Library currently supports the two deep learning frameworks tensorflow and pytorch for
 * [Pytorch installation guide](https://pytorch.org/get-started/locally/)
 * [Tensorflow installation guide](https://www.tensorflow.org/install/pip)
 
+Some third-party dependency libraries require numpy to be installed in advance
+```
+pip install numpy
+```
 
 ## ModelScope library installation
 
diff --git a/modelscope/models/__init__.py b/modelscope/models/__init__.py
index d9a89d35..7d70e6ca 100644
--- a/modelscope/models/__init__.py
+++ b/modelscope/models/__init__.py
@@ -1,5 +1,7 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
+from .audio.tts.am import SambertNetHifi16k +from .audio.tts.vocoder import Hifigan16k from .base import Model from .builder import MODELS, build_model from .nlp import BertForSequenceClassification, SbertForSentenceSimilarity diff --git a/modelscope/models/audio/tts/__init__.py b/modelscope/models/audio/tts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/models/audio/tts/am/__init__.py b/modelscope/models/audio/tts/am/__init__.py new file mode 100644 index 00000000..2ebbda1c --- /dev/null +++ b/modelscope/models/audio/tts/am/__init__.py @@ -0,0 +1 @@ +from .sambert_hifi_16k import * # noqa F403 diff --git a/modelscope/models/audio/tts/am/models/__init__.py b/modelscope/models/audio/tts/am/models/__init__.py new file mode 100755 index 00000000..9e198e7a --- /dev/null +++ b/modelscope/models/audio/tts/am/models/__init__.py @@ -0,0 +1,8 @@ +from .robutrans import RobuTrans + + +def create_model(name, hparams): + if name == 'robutrans': + return RobuTrans(hparams) + else: + raise Exception('Unknown model: ' + name) diff --git a/modelscope/models/audio/tts/am/models/compat.py b/modelscope/models/audio/tts/am/models/compat.py new file mode 100755 index 00000000..bb810841 --- /dev/null +++ b/modelscope/models/audio/tts/am/models/compat.py @@ -0,0 +1,82 @@ +"""Functions for compatibility with different TensorFlow versions.""" + +import tensorflow as tf + + +def is_tf2(): + """Returns ``True`` if running TensorFlow 2.0.""" + return tf.__version__.startswith('2') + + +def tf_supports(symbol): + """Returns ``True`` if TensorFlow defines :obj:`symbol`.""" + return _string_to_tf_symbol(symbol) is not None + + +def tf_any(*symbols): + """Returns the first supported symbol.""" + for symbol in symbols: + module = _string_to_tf_symbol(symbol) + if module is not None: + return module + return None + + +def tf_compat(v2=None, v1=None): # pylint: disable=invalid-name + """Returns the compatible symbol based on the current TensorFlow version. + + Args: + v2: The candidate v2 symbol name. + v1: The candidate v1 symbol name. + + Returns: + A TensorFlow symbol. + + Raises: + ValueError: if no symbol can be found. 
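+
+    Example (a usage sketch; it simply mirrors the module-level aliases
+    defined at the bottom of this file):
+
+        gfile_exists = tf_compat(v2='io.gfile.exists', v1='gfile.Exists')
+        logging = tf_compat(v1='logging')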
+ """ + candidates = [] + if v2 is not None: + candidates.append(v2) + if v1 is not None: + candidates.append(v1) + candidates.append('compat.v1.%s' % v1) + symbol = tf_any(*candidates) + if symbol is None: + raise ValueError('Failure to resolve the TensorFlow symbol') + return symbol + + +def name_from_variable_scope(name=''): + """Creates a name prefixed by the current variable scope.""" + var_scope = tf_compat(v1='get_variable_scope')().name + compat_name = '' + if name: + compat_name = '%s/' % name + if var_scope: + compat_name = '%s/%s' % (var_scope, compat_name) + return compat_name + + +def reuse(): + """Returns ``True`` if the current variable scope is marked for reuse.""" + return tf_compat(v1='get_variable_scope')().reuse + + +def _string_to_tf_symbol(symbol): + modules = symbol.split('.') + namespace = tf + for module in modules: + namespace = getattr(namespace, module, None) + if namespace is None: + return None + return namespace + + +# pylint: disable=invalid-name +gfile_copy = tf_compat(v2='io.gfile.copy', v1='gfile.Copy') +gfile_exists = tf_compat(v2='io.gfile.exists', v1='gfile.Exists') +gfile_open = tf_compat(v2='io.gfile.GFile', v1='gfile.GFile') +is_tensor = tf_compat(v2='is_tensor', v1='contrib.framework.is_tensor') +logging = tf_compat(v1='logging') +nest = tf_compat(v2='nest', v1='contrib.framework.nest') diff --git a/modelscope/models/audio/tts/am/models/fsmn.py b/modelscope/models/audio/tts/am/models/fsmn.py new file mode 100755 index 00000000..875c27f0 --- /dev/null +++ b/modelscope/models/audio/tts/am/models/fsmn.py @@ -0,0 +1,273 @@ +import tensorflow as tf + + +def build_sequence_mask(sequence_length, + maximum_length=None, + dtype=tf.float32): + """Builds the dot product mask. + + Args: + sequence_length: The sequence length. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + + Returns: + A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, max_length]``. + """ + mask = tf.sequence_mask( + sequence_length, maxlen=maximum_length, dtype=dtype) + + return mask + + +def norm(inputs): + """Layer normalizes :obj:`inputs`.""" + return tf.contrib.layers.layer_norm(inputs, begin_norm_axis=-1) + + +def pad_in_time(x, padding_shape): + """Helper function to pad a tensor in the time dimension and retain the static depth dimension. + + Agrs: + x: [Batch, Time, Frequency] + padding_length: padding size of constant value (0) before the time dimension + + return: + padded x + """ + + depth = x.get_shape().as_list()[-1] + x = tf.pad(x, [[0, 0], padding_shape, [0, 0]]) + x.set_shape((None, None, depth)) + + return x + + +def pad_in_time_right(x, padding_length): + """Helper function to pad a tensor in the time dimension and retain the static depth dimension. + + Agrs: + x: [Batch, Time, Frequency] + padding_length: padding size of constant value (0) before the time dimension + + return: + padded x + """ + depth = x.get_shape().as_list()[-1] + x = tf.pad(x, [[0, 0], [0, padding_length], [0, 0]]) + x.set_shape((None, None, depth)) + + return x + + +def feed_forward(x, ffn_dim, memory_units, mode, dropout=0.0): + """Implements the Transformer's "Feed Forward" layer. + + .. math:: + + ffn(x) = max(0, x*W_1 + b_1)*W_2 + + Args: + x: The input. + ffn_dim: The number of units of the nonlinear transformation. + memory_units: the number of units of linear transformation + mode: A ``tf.estimator.ModeKeys`` mode. 
+ dropout: The probability to drop units from the inner transformation. + + Returns: + The transformed input. + """ + inner = tf.layers.conv1d(x, ffn_dim, 1, activation=tf.nn.relu) + inner = tf.layers.dropout( + inner, rate=dropout, training=mode == tf.estimator.ModeKeys.TRAIN) + outer = tf.layers.conv1d(inner, memory_units, 1, use_bias=False) + + return outer + + +def drop_and_add(inputs, outputs, mode, dropout=0.0): + """Drops units in the outputs and adds the previous values. + + Args: + inputs: The input of the previous layer. + outputs: The output of the previous layer. + mode: A ``tf.estimator.ModeKeys`` mode. + dropout: The probability to drop units in :obj:`outputs`. + + Returns: + The residual and normalized output. + """ + outputs = tf.layers.dropout(outputs, rate=dropout, training=mode) + + input_dim = inputs.get_shape().as_list()[-1] + output_dim = outputs.get_shape().as_list()[-1] + + if input_dim == output_dim: + outputs += inputs + + return outputs + + +def MemoryBlock( + inputs, + filter_size, + mode, + mask=None, + dropout=0.0, +): + """ + Define the bidirectional memory block in FSMN + + Agrs: + inputs: The output of the previous layer. [Batch, Time, Frequency] + filter_size: memory block filter size + mode: Training or Evaluation + mask: A ``tf.Tensor`` applied to the memory block output + + return: + output: 3-D tensor ([Batch, Time, Frequency]) + """ + static_shape = inputs.get_shape().as_list() + depth = static_shape[-1] + inputs = tf.expand_dims(inputs, axis=1) # [Batch, 1, Time, Frequency] + depthwise_filter = tf.get_variable( + 'depth_conv_w', + shape=[1, filter_size, depth, 1], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + memory = tf.nn.depthwise_conv2d( + input=inputs, + filter=depthwise_filter, + strides=[1, 1, 1, 1], + padding='SAME', + rate=[1, 1], + data_format='NHWC') + memory = memory + inputs + output = tf.layers.dropout(memory, rate=dropout, training=mode) + output = tf.reshape( + output, + [tf.shape(output)[0], tf.shape(output)[2], depth]) + if mask is not None: + output = output * tf.expand_dims(mask, -1) + + return output + + +def MemoryBlockV2( + inputs, + filter_size, + mode, + shift=0, + mask=None, + dropout=0.0, +): + """ + Define the bidirectional memory block in FSMN + + Agrs: + inputs: The output of the previous layer. 
[Batch, Time, Frequency] + filter_size: memory block filter size + mode: Training or Evaluation + shift: left padding, to control delay + mask: A ``tf.Tensor`` applied to the memory block output + + return: + output: 3-D tensor ([Batch, Time, Frequency]) + """ + if mask is not None: + inputs = inputs * tf.expand_dims(mask, -1) + + static_shape = inputs.get_shape().as_list() + depth = static_shape[-1] + # padding + left_padding = int(round((filter_size - 1) / 2)) + right_padding = int((filter_size - 1) / 2) + if shift > 0: + left_padding = left_padding + shift + right_padding = right_padding - shift + pad_inputs = pad_in_time(inputs, [left_padding, right_padding]) + pad_inputs = tf.expand_dims( + pad_inputs, axis=1) # [Batch, 1, Time, Frequency] + depthwise_filter = tf.get_variable( + 'depth_conv_w', + shape=[1, filter_size, depth, 1], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + memory = tf.nn.depthwise_conv2d( + input=pad_inputs, + filter=depthwise_filter, + strides=[1, 1, 1, 1], + padding='VALID', + rate=[1, 1], + data_format='NHWC') + memory = tf.reshape( + memory, + [tf.shape(memory)[0], tf.shape(memory)[2], depth]) + memory = memory + inputs + output = tf.layers.dropout(memory, rate=dropout, training=mode) + if mask is not None: + output = output * tf.expand_dims(mask, -1) + + return output + + +def UniMemoryBlock( + inputs, + filter_size, + mode, + cache=None, + mask=None, + dropout=0.0, +): + """ + Define the unidirectional memory block in FSMN + + Agrs: + inputs: The output of the previous layer. [Batch, Time, Frequency] + filter_size: memory block filter size + cache: for streaming inference + mode: Training or Evaluation + mask: A ``tf.Tensor`` applied to the memory block output + dropout: dorpout factor + return: + output: 3-D tensor ([Batch, Time, Frequency]) + """ + if cache is not None: + static_shape = cache['queries'].get_shape().as_list() + depth = static_shape[-1] + queries = tf.slice(cache['queries'], [0, 1, 0], [ + tf.shape(cache['queries'])[0], + tf.shape(cache['queries'])[1] - 1, depth + ]) + queries = tf.concat([queries, inputs], axis=1) + cache['queries'] = queries + else: + padding_length = filter_size - 1 + queries = pad_in_time(inputs, [padding_length, 0]) + + queries = tf.expand_dims(queries, axis=1) # [Batch, 1, Time, Frequency] + static_shape = queries.get_shape().as_list() + depth = static_shape[-1] + depthwise_filter = tf.get_variable( + 'depth_conv_w', + shape=[1, filter_size, depth, 1], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + memory = tf.nn.depthwise_conv2d( + input=queries, + filter=depthwise_filter, + strides=[1, 1, 1, 1], + padding='VALID', + rate=[1, 1], + data_format='NHWC') + memory = tf.reshape( + memory, + [tf.shape(memory)[0], tf.shape(memory)[2], depth]) + memory = memory + inputs + output = tf.layers.dropout(memory, rate=dropout, training=mode) + if mask is not None: + output = output * tf.expand_dims(mask, -1) + + return output diff --git a/modelscope/models/audio/tts/am/models/fsmn_encoder.py b/modelscope/models/audio/tts/am/models/fsmn_encoder.py new file mode 100755 index 00000000..2c650624 --- /dev/null +++ b/modelscope/models/audio/tts/am/models/fsmn_encoder.py @@ -0,0 +1,178 @@ +import tensorflow as tf + +from . import fsmn + + +class FsmnEncoder(): + """Encoder using Fsmn + """ + + def __init__(self, + filter_size, + fsmn_num_layers, + dnn_num_layers, + num_memory_units=512, + ffn_inner_dim=2048, + dropout=0.0, + position_encoder=None): + """Initializes the parameters of the encoder. 
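+
+        The encoder stacks ``fsmn_num_layers`` blocks of feed-forward +
+        FSMN memory layers, followed by ``dnn_num_layers`` plain
+        feed-forward layers (see ``encode`` below).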
+ + Args: + filter_size: the total order of memory block + fsmn_num_layers: The number of fsmn layers. + dnn_num_layers: The number of dnn layers + num_units: The number of memory units. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout: The probability to drop units from the outputs. + position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. + """ + super(FsmnEncoder, self).__init__() + self.filter_size = filter_size + self.fsmn_num_layers = fsmn_num_layers + self.dnn_num_layers = dnn_num_layers + self.num_memory_units = num_memory_units + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.position_encoder = position_encoder + + def encode(self, inputs, sequence_length=None, mode=True): + if self.position_encoder is not None: + inputs = self.position_encoder(inputs) + + inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) + + mask = fsmn.build_sequence_mask( + sequence_length, maximum_length=tf.shape(inputs)[1]) + + state = () + + for layer in range(self.fsmn_num_layers): + with tf.variable_scope('fsmn_layer_{}'.format(layer)): + with tf.variable_scope('ffn'): + context = fsmn.feed_forward( + inputs, + self.ffn_inner_dim, + self.num_memory_units, + mode, + dropout=self.dropout) + + with tf.variable_scope('memory'): + memory = fsmn.MemoryBlock( + context, + self.filter_size, + mode, + mask=mask, + dropout=self.dropout) + + memory = fsmn.drop_and_add( + inputs, memory, mode, dropout=self.dropout) + + inputs = memory + state += (tf.reduce_mean(inputs, axis=1), ) + + for layer in range(self.dnn_num_layers): + with tf.variable_scope('dnn_layer_{}'.format(layer)): + transformed = fsmn.feed_forward( + inputs, + self.ffn_inner_dim, + self.num_memory_units, + mode, + dropout=self.dropout) + + inputs = transformed + state += (tf.reduce_mean(inputs, axis=1), ) + + outputs = inputs + return (outputs, state, sequence_length) + + +class FsmnEncoderV2(): + """Encoder using Fsmn + """ + + def __init__(self, + filter_size, + fsmn_num_layers, + dnn_num_layers, + num_memory_units=512, + ffn_inner_dim=2048, + dropout=0.0, + shift=0, + position_encoder=None): + """Initializes the parameters of the encoder. + + Args: + filter_size: the total order of memory block + fsmn_num_layers: The number of fsmn layers. + dnn_num_layers: The number of dnn layers + num_units: The number of memory units. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout: The probability to drop units from the outputs. + shift: left padding, to control delay + position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. 
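+
+        Example (a construction sketch; the hyper-parameter values are
+        illustrative rather than defaults from this repository, and
+        ``inputs``/``input_lengths``/``is_training`` stand for
+        caller-provided tensors and flags):
+
+            encoder = FsmnEncoderV2(
+                filter_size=11,
+                fsmn_num_layers=6,
+                dnn_num_layers=1,
+                num_memory_units=512,
+                ffn_inner_dim=2048,
+                dropout=0.1,
+                shift=0)
+            outputs, states, lengths = encoder.encode(
+                inputs, sequence_length=input_lengths, mode=is_training)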
+ """ + super(FsmnEncoderV2, self).__init__() + self.filter_size = filter_size + self.fsmn_num_layers = fsmn_num_layers + self.dnn_num_layers = dnn_num_layers + self.num_memory_units = num_memory_units + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.shift = shift + if not isinstance(shift, list): + self.shift = [shift for _ in range(self.fsmn_num_layers)] + self.position_encoder = position_encoder + + def encode(self, inputs, sequence_length=None, mode=True): + if self.position_encoder is not None: + inputs = self.position_encoder(inputs) + + inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) + + mask = fsmn.build_sequence_mask( + sequence_length, maximum_length=tf.shape(inputs)[1]) + + state = () + for layer in range(self.fsmn_num_layers): + with tf.variable_scope('fsmn_layer_{}'.format(layer)): + with tf.variable_scope('ffn'): + context = fsmn.feed_forward( + inputs, + self.ffn_inner_dim, + self.num_memory_units, + mode, + dropout=self.dropout) + + with tf.variable_scope('memory'): + memory = fsmn.MemoryBlockV2( + context, + self.filter_size, + mode, + shift=self.shift[layer], + mask=mask, + dropout=self.dropout) + + memory = fsmn.drop_and_add( + inputs, memory, mode, dropout=self.dropout) + + inputs = memory + state += (tf.reduce_mean(inputs, axis=1), ) + + for layer in range(self.dnn_num_layers): + with tf.variable_scope('dnn_layer_{}'.format(layer)): + transformed = fsmn.feed_forward( + inputs, + self.ffn_inner_dim, + self.num_memory_units, + mode, + dropout=self.dropout) + + inputs = transformed + state += (tf.reduce_mean(inputs, axis=1), ) + + outputs = inputs + return (outputs, state, sequence_length) diff --git a/modelscope/models/audio/tts/am/models/helpers.py b/modelscope/models/audio/tts/am/models/helpers.py new file mode 100755 index 00000000..f3e53277 --- /dev/null +++ b/modelscope/models/audio/tts/am/models/helpers.py @@ -0,0 +1,160 @@ +import numpy as np +import tensorflow as tf +from tensorflow.contrib.seq2seq import Helper + + +class VarTestHelper(Helper): + + def __init__(self, batch_size, inputs, dim): + with tf.name_scope('VarTestHelper'): + self._batch_size = batch_size + self._inputs = inputs + self._dim = dim + + num_steps = tf.shape(self._inputs)[1] + self._lengths = tf.tile([num_steps], [self._batch_size]) + + self._inputs = tf.roll(inputs, shift=-1, axis=1) + self._init_inputs = inputs[:, 0, :] + + @property + def batch_size(self): + return self._batch_size + + @property + def sample_ids_shape(self): + return tf.TensorShape([]) + + @property + def sample_ids_dtype(self): + return np.int32 + + def initialize(self, name=None): + return (tf.tile([False], [self._batch_size]), + _go_frames(self._batch_size, self._dim, self._init_inputs)) + + def sample(self, time, outputs, state, name=None): + return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + with tf.name_scope('VarTestHelper'): + finished = (time + 1 >= self._lengths) + next_inputs = tf.concat([outputs, self._inputs[:, time, :]], + axis=-1) + return (finished, next_inputs, state) + + +class VarTrainingHelper(Helper): + + def __init__(self, targets, inputs, dim): + with tf.name_scope('VarTrainingHelper'): + self._targets = targets # [N, T_in, 1] + self._batch_size = tf.shape(inputs)[0] # N + self._inputs = inputs + self._dim = dim + + num_steps = tf.shape(self._targets)[1] + self._lengths = tf.tile([num_steps], [self._batch_size]) + + self._inputs = tf.roll(inputs, shift=-1, axis=1) + 
self._init_inputs = inputs[:, 0, :] + + @property + def batch_size(self): + return self._batch_size + + @property + def sample_ids_shape(self): + return tf.TensorShape([]) + + @property + def sample_ids_dtype(self): + return np.int32 + + def initialize(self, name=None): + return (tf.tile([False], [self._batch_size]), + _go_frames(self._batch_size, self._dim, self._init_inputs)) + + def sample(self, time, outputs, state, name=None): + return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + with tf.name_scope(name or 'VarTrainingHelper'): + finished = (time + 1 >= self._lengths) + next_inputs = tf.concat( + [self._targets[:, time, :], self._inputs[:, time, :]], axis=-1) + return (finished, next_inputs, state) + + +class VarTrainingSSHelper(Helper): + + def __init__(self, targets, inputs, dim, global_step, schedule_begin, + alpha, decay_steps): + with tf.name_scope('VarTrainingSSHelper'): + self._targets = targets # [N, T_in, 1] + self._batch_size = tf.shape(inputs)[0] # N + self._inputs = inputs + self._dim = dim + + num_steps = tf.shape(self._targets)[1] + self._lengths = tf.tile([num_steps], [self._batch_size]) + + self._inputs = tf.roll(inputs, shift=-1, axis=1) + self._init_inputs = inputs[:, 0, :] + + # for schedule sampling + self._global_step = global_step + self._schedule_begin = schedule_begin + self._alpha = alpha + self._decay_steps = decay_steps + + @property + def batch_size(self): + return self._batch_size + + @property + def sample_ids_shape(self): + return tf.TensorShape([]) + + @property + def sample_ids_dtype(self): + return np.int32 + + def initialize(self, name=None): + self._ratio = _tf_decay(self._global_step, self._schedule_begin, + self._alpha, self._decay_steps) + return (tf.tile([False], [self._batch_size]), + _go_frames(self._batch_size, self._dim, self._init_inputs)) + + def sample(self, time, outputs, state, name=None): + return tf.tile([0], [self._batch_size]) # Return all 0; we ignore them + + def next_inputs(self, time, outputs, state, sample_ids, name=None): + with tf.name_scope(name or 'VarTrainingHelper'): + finished = (time + 1 >= self._lengths) + next_inputs_tmp = tf.cond( + tf.less( + tf.random_uniform([], minval=0, maxval=1, + dtype=tf.float32), self._ratio), + lambda: self._targets[:, time, :], lambda: outputs) + next_inputs = tf.concat( + [next_inputs_tmp, self._inputs[:, time, :]], axis=-1) + return (finished, next_inputs, state) + + +def _go_frames(batch_size, dim, init_inputs): + '''Returns all-zero frames for a given batch size and output dimension''' + return tf.concat([tf.tile([[0.0]], [batch_size, dim]), init_inputs], + axis=-1) + + +def _tf_decay(global_step, schedule_begin, alpha, decay_steps): + tfr = tf.train.exponential_decay( + 1.0, + global_step=global_step - schedule_begin, + decay_steps=decay_steps, + decay_rate=alpha, + name='tfr_decay') + final_tfr = tf.cond( + tf.less(global_step, schedule_begin), lambda: 1.0, lambda: tfr) + return final_tfr diff --git a/modelscope/models/audio/tts/am/models/modules.py b/modelscope/models/audio/tts/am/models/modules.py new file mode 100755 index 00000000..1433fd7e --- /dev/null +++ b/modelscope/models/audio/tts/am/models/modules.py @@ -0,0 +1,461 @@ +import tensorflow as tf +from tensorflow.contrib.cudnn_rnn import CudnnLSTM +from tensorflow.contrib.cudnn_rnn.python.ops import cudnn_rnn_ops +from tensorflow.contrib.rnn import LSTMBlockCell + + +def encoder_prenet(inputs, + n_conv_layers, + filters, + 
kernel_size, + dense_units, + is_training, + mask=None, + scope='encoder_prenet'): + x = inputs + with tf.variable_scope(scope): + for i in range(n_conv_layers): + x = conv1d( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + mask=mask, + scope='conv1d_{}'.format(i)) + x = tf.layers.dense( + x, units=dense_units, activation=None, name='dense') + return x + + +def decoder_prenet(inputs, + prenet_units, + dense_units, + is_training, + scope='decoder_prenet'): + x = inputs + with tf.variable_scope(scope): + for i, units in enumerate(prenet_units): + x = tf.layers.dense( + x, + units=units, + activation=tf.nn.relu, + name='dense_{}'.format(i)) + x = tf.layers.dropout( + x, rate=0.5, training=is_training, name='dropout_{}'.format(i)) + x = tf.layers.dense( + x, units=dense_units, activation=None, name='dense') + return x + + +def encoder(inputs, + input_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker, + mask=None, + scope='encoder'): + with tf.variable_scope(scope): + x = conv_and_lstm( + inputs, + input_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker, + mask=mask) + return x + + +def prenet(inputs, prenet_units, is_training, scope='prenet'): + x = inputs + with tf.variable_scope(scope): + for i, units in enumerate(prenet_units): + x = tf.layers.dense( + x, + units=units, + activation=tf.nn.relu, + name='dense_{}'.format(i)) + x = tf.layers.dropout( + x, rate=0.5, training=is_training, name='dropout_{}'.format(i)) + return x + + +def postnet_residual_ulstm(inputs, + n_conv_layers, + filters, + kernel_size, + lstm_units, + output_units, + is_training, + scope='postnet_residual_ulstm'): + with tf.variable_scope(scope): + x = conv_and_ulstm(inputs, None, n_conv_layers, filters, kernel_size, + lstm_units, is_training) + x = conv1d( + x, + output_units, + kernel_size, + is_training, + activation=None, + dropout=False, + scope='conv1d_{}'.format(n_conv_layers - 1)) + return x + + +def postnet_residual_lstm(inputs, + n_conv_layers, + filters, + kernel_size, + lstm_units, + output_units, + is_training, + scope='postnet_residual_lstm'): + with tf.variable_scope(scope): + x = conv_and_lstm(inputs, None, n_conv_layers, filters, kernel_size, + lstm_units, is_training) + x = conv1d( + x, + output_units, + kernel_size, + is_training, + activation=None, + dropout=False, + scope='conv1d_{}'.format(n_conv_layers - 1)) + return x + + +def postnet_linear_ulstm(inputs, + n_conv_layers, + filters, + kernel_size, + lstm_units, + output_units, + is_training, + scope='postnet_linear'): + with tf.variable_scope(scope): + x = conv_and_ulstm(inputs, None, n_conv_layers, filters, kernel_size, + lstm_units, is_training) + x = tf.layers.dense(x, units=output_units) + return x + + +def postnet_linear_lstm(inputs, + n_conv_layers, + filters, + kernel_size, + lstm_units, + output_units, + output_lengths, + is_training, + embedded_inputs_speaker2, + mask=None, + scope='postnet_linear'): + with tf.variable_scope(scope): + x = conv_and_lstm_dec( + inputs, + output_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker2, + mask=mask) + x = tf.layers.dense(x, units=output_units) + return x + + +def postnet_linear(inputs, + n_conv_layers, + filters, + kernel_size, + lstm_units, + output_units, + output_lengths, + is_training, + embedded_inputs_speaker2, + mask=None, + scope='postnet_linear'): + with tf.variable_scope(scope): + 
x = conv_dec( + inputs, + output_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker2, + mask=mask) + return x + + +def conv_and_lstm(inputs, + sequence_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker, + mask=None, + scope='conv_and_lstm'): + x = inputs + with tf.variable_scope(scope): + for i in range(n_conv_layers): + x = conv1d( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + mask=mask, + scope='conv1d_{}'.format(i)) + + x = tf.concat([x, embedded_inputs_speaker], axis=2) + + outputs, states = tf.nn.bidirectional_dynamic_rnn( + LSTMBlockCell(lstm_units), + LSTMBlockCell(lstm_units), + x, + sequence_length=sequence_lengths, + dtype=tf.float32) + x = tf.concat(outputs, axis=-1) + + return x + + +def conv_and_lstm_dec(inputs, + sequence_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker2, + mask=None, + scope='conv_and_lstm'): + x = inputs + with tf.variable_scope(scope): + for i in range(n_conv_layers): + x = conv1d( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + mask=mask, + scope='conv1d_{}'.format(i)) + + x = tf.concat([x, embedded_inputs_speaker2], axis=2) + + outputs, states = tf.nn.bidirectional_dynamic_rnn( + LSTMBlockCell(lstm_units), + LSTMBlockCell(lstm_units), + x, + sequence_length=sequence_lengths, + dtype=tf.float32) + x = tf.concat(outputs, axis=-1) + return x + + +def conv_dec(inputs, + sequence_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + embedded_inputs_speaker2, + mask=None, + scope='conv_and_lstm'): + x = inputs + with tf.variable_scope(scope): + for i in range(n_conv_layers): + x = conv1d( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + mask=mask, + scope='conv1d_{}'.format(i)) + x = tf.concat([x, embedded_inputs_speaker2], axis=2) + return x + + +def conv_and_ulstm(inputs, + sequence_lengths, + n_conv_layers, + filters, + kernel_size, + lstm_units, + is_training, + scope='conv_and_ulstm'): + x = inputs + with tf.variable_scope(scope): + for i in range(n_conv_layers): + x = conv1d( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + scope='conv1d_{}'.format(i)) + + outputs, states = tf.nn.dynamic_rnn( + LSTMBlockCell(lstm_units), + x, + sequence_length=sequence_lengths, + dtype=tf.float32) + + return outputs + + +def conv1d(inputs, + filters, + kernel_size, + is_training, + activation=None, + dropout=False, + mask=None, + scope='conv1d'): + with tf.variable_scope(scope): + if mask is not None: + inputs = inputs * tf.expand_dims(mask, -1) + x = tf.layers.conv1d( + inputs, filters=filters, kernel_size=kernel_size, padding='same') + if mask is not None: + x = x * tf.expand_dims(mask, -1) + + x = tf.layers.batch_normalization(x, training=is_training) + if activation is not None: + x = activation(x) + if dropout: + x = tf.layers.dropout(x, rate=0.5, training=is_training) + return x + + +def conv1d_dp(inputs, + filters, + kernel_size, + is_training, + activation=None, + dropout=False, + dropoutrate=0.5, + mask=None, + scope='conv1d'): + with tf.variable_scope(scope): + if mask is not None: + inputs = inputs * tf.expand_dims(mask, -1) + x = tf.layers.conv1d( + inputs, filters=filters, kernel_size=kernel_size, padding='same') + if mask is not None: + x = x * tf.expand_dims(mask, -1) + + x = 
tf.contrib.layers.layer_norm(x) + if activation is not None: + x = activation(x) + if dropout: + x = tf.layers.dropout(x, rate=dropoutrate, training=is_training) + return x + + +def duration_predictor(inputs, + n_conv_layers, + filters, + kernel_size, + lstm_units, + input_lengths, + is_training, + embedded_inputs_speaker, + mask=None, + scope='duration_predictor'): + with tf.variable_scope(scope): + x = inputs + for i in range(n_conv_layers): + x = conv1d_dp( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + dropoutrate=0.1, + mask=mask, + scope='conv1d_{}'.format(i)) + + x = tf.concat([x, embedded_inputs_speaker], axis=2) + + outputs, states = tf.nn.bidirectional_dynamic_rnn( + LSTMBlockCell(lstm_units), + LSTMBlockCell(lstm_units), + x, + sequence_length=input_lengths, + dtype=tf.float32) + x = tf.concat(outputs, axis=-1) + + x = tf.layers.dense(x, units=1) + x = tf.nn.relu(x) + return x + + +def duration_predictor2(inputs, + n_conv_layers, + filters, + kernel_size, + input_lengths, + is_training, + mask=None, + scope='duration_predictor'): + with tf.variable_scope(scope): + x = inputs + for i in range(n_conv_layers): + x = conv1d_dp( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + dropoutrate=0.1, + mask=mask, + scope='conv1d_{}'.format(i)) + + x = tf.layers.dense(x, units=1) + x = tf.nn.relu(x) + return x + + +def conv_prenet(inputs, + n_conv_layers, + filters, + kernel_size, + is_training, + mask=None, + scope='conv_prenet'): + x = inputs + with tf.variable_scope(scope): + for i in range(n_conv_layers): + x = conv1d( + x, + filters, + kernel_size, + is_training, + activation=tf.nn.relu, + dropout=True, + mask=mask, + scope='conv1d_{}'.format(i)) + + return x diff --git a/modelscope/models/audio/tts/am/models/position.py b/modelscope/models/audio/tts/am/models/position.py new file mode 100755 index 00000000..bca658dd --- /dev/null +++ b/modelscope/models/audio/tts/am/models/position.py @@ -0,0 +1,174 @@ +"""Define position encoder classes.""" + +import abc +import math + +import tensorflow as tf + +from .reducer import SumReducer + + +class PositionEncoder(tf.keras.layers.Layer): + """Base class for position encoders.""" + + def __init__(self, reducer=None, **kwargs): + """Initializes the position encoder. + Args: + reducer: A :class:`opennmt.layers.Reducer` to merge inputs and position + encodings. Defaults to :class:`opennmt.layers.SumReducer`. + **kwargs: Additional layer keyword arguments. + """ + super(PositionEncoder, self).__init__(**kwargs) + if reducer is None: + reducer = SumReducer(dtype=kwargs.get('dtype')) + self.reducer = reducer + + def call(self, inputs, position=None): # pylint: disable=arguments-differ + """Add position encodings to :obj:`inputs`. + Args: + inputs: The inputs to encode. + position: The single position to encode, to use when this layer is called + step by step. + Returns: + A ``tf.Tensor`` whose shape depends on the configured ``reducer``. + """ + batch_size = tf.shape(inputs)[0] + timesteps = tf.shape(inputs)[1] + input_dim = inputs.shape[-1].value + positions = tf.range(timesteps) + 1 if position is None else [position] + position_encoding = self._encode([positions], input_dim) + position_encoding = tf.tile(position_encoding, [batch_size, 1, 1]) + return self.reducer([inputs, position_encoding]) + + @abc.abstractmethod + def _encode(self, positions, depth): + """Creates position encodings. + Args: + positions: The positions to encode of shape :math:`[B, ...]`. 
+ depth: The encoding depth :math:`D`. + Returns: + A ``tf.Tensor`` of shape :math:`[B, ..., D]`. + """ + raise NotImplementedError() + + +class PositionEmbedder(PositionEncoder): + """Encodes position with a lookup table.""" + + def __init__(self, maximum_position=128, reducer=None, **kwargs): + """Initializes the position encoder. + Args: + maximum_position: The maximum position to embed. Positions greater + than this value will be set to :obj:`maximum_position`. + reducer: A :class:`opennmt.layers.Reducer` to merge inputs and position + encodings. Defaults to :class:`opennmt.layers.SumReducer`. + **kwargs: Additional layer keyword arguments. + """ + super(PositionEmbedder, self).__init__(reducer=reducer, **kwargs) + self.maximum_position = maximum_position + self.embedding = None + + def build(self, input_shape): + shape = [self.maximum_position + 1, input_shape[-1]] + self.embedding = self.add_weight('position_embedding', shape) + super(PositionEmbedder, self).build(input_shape) + + def _encode(self, positions, depth): + positions = tf.minimum(positions, self.maximum_position) + return tf.nn.embedding_lookup(self.embedding, positions) + + +class SinusoidalPositionEncoder(PositionEncoder): + """Encodes positions with sine waves as described in + https://arxiv.org/abs/1706.03762. + """ + + def _encode(self, positions, depth): + if depth % 2 != 0: + raise ValueError( + 'SinusoidalPositionEncoder expects the depth to be divisble ' + 'by 2 but got %d' % depth) + + batch_size = tf.shape(positions)[0] + positions = tf.cast(positions, tf.float32) + + log_timescale_increment = math.log(10000) / (depth / 2 - 1) + inv_timescales = tf.exp( + tf.range(depth / 2, dtype=tf.float32) * -log_timescale_increment) + inv_timescales = tf.reshape( + tf.tile(inv_timescales, [batch_size]), [batch_size, depth // 2]) + scaled_time = tf.expand_dims(positions, -1) * tf.expand_dims( + inv_timescales, 1) + encoding = tf.concat( + [tf.sin(scaled_time), tf.cos(scaled_time)], axis=2) + return tf.cast(encoding, self.dtype) + + +class SinusodalPositionalEncoding(tf.keras.layers.Layer): + + def __init__(self, name='SinusodalPositionalEncoding'): + super(SinusodalPositionalEncoding, self).__init__(name=name) + + @staticmethod + def positional_encoding(len, dim, step=1.): + """ + :param len: int scalar + :param dim: int scalar + :param step: + :return: position embedding + """ + pos_mat = tf.tile( + tf.expand_dims( + tf.range(0, tf.cast(len, dtype=tf.float32), dtype=tf.float32) + * step, + axis=-1), [1, dim]) + dim_mat = tf.tile( + tf.expand_dims( + tf.range(0, tf.cast(dim, dtype=tf.float32), dtype=tf.float32), + axis=0), [len, 1]) + dim_mat_int = tf.cast(dim_mat, dtype=tf.int32) + pos_encoding = tf.where( # [time, dims] + tf.math.equal(tf.math.mod(dim_mat_int, 2), 0), + x=tf.math.sin( + pos_mat / tf.pow(10000., dim_mat / tf.cast(dim, tf.float32))), + y=tf.math.cos(pos_mat + / tf.pow(10000., + (dim_mat - 1) / tf.cast(dim, tf.float32)))) + return pos_encoding + + +class BatchSinusodalPositionalEncoding(tf.keras.layers.Layer): + + def __init__(self, name='BatchSinusodalPositionalEncoding'): + super(BatchSinusodalPositionalEncoding, self).__init__(name=name) + + @staticmethod + def positional_encoding(batch_size, len, dim, pos_mat, step=1.): + """ + :param len: int scalar + :param dim: int scalar + :param step: + :param pos_mat: [B, len] = [len, 1] * dim + :return: position embedding + """ + pos_mat = tf.tile( + tf.expand_dims(tf.cast(pos_mat, dtype=tf.float32) * step, axis=-1), + [1, 1, dim]) # [B, len, dim] + + dim_mat = 
tf.tile( + tf.expand_dims( + tf.expand_dims( + tf.range( + 0, tf.cast(dim, dtype=tf.float32), dtype=tf.float32), + axis=0), + axis=0), [batch_size, len, 1]) # [B, len, dim] + + dim_mat_int = tf.cast(dim_mat, dtype=tf.int32) + pos_encoding = tf.where( # [B, time, dims] + tf.math.equal(tf.mod(dim_mat_int, 2), 0), + x=tf.math.sin( + pos_mat / tf.pow(10000., dim_mat / tf.cast(dim, tf.float32))), + y=tf.math.cos(pos_mat + / tf.pow(10000., + (dim_mat - 1) / tf.cast(dim, tf.float32)))) + return pos_encoding diff --git a/modelscope/models/audio/tts/am/models/reducer.py b/modelscope/models/audio/tts/am/models/reducer.py new file mode 100755 index 00000000..a4c9ae17 --- /dev/null +++ b/modelscope/models/audio/tts/am/models/reducer.py @@ -0,0 +1,155 @@ +"""Define reducers: objects that merge inputs.""" + +import abc +import functools + +import tensorflow as tf + + +def pad_in_time(x, padding_length): + """Helper function to pad a tensor in the time dimension and retain the static depth dimension.""" + return tf.pad(x, [[0, 0], [0, padding_length], [0, 0]]) + + +def align_in_time(x, length): + """Aligns the time dimension of :obj:`x` with :obj:`length`.""" + time_dim = tf.shape(x)[1] + return tf.cond( + tf.less(time_dim, length), + true_fn=lambda: pad_in_time(x, length - time_dim), + false_fn=lambda: x[:, :length]) + + +def pad_with_identity(x, + sequence_length, + max_sequence_length, + identity_values=0, + maxlen=None): + """Pads a tensor with identity values up to :obj:`max_sequence_length`. + Args: + x: A ``tf.Tensor`` of shape ``[batch_size, time, depth]``. + sequence_length: The true sequence length of :obj:`x`. + max_sequence_length: The sequence length up to which the tensor must contain + :obj:`identity values`. + identity_values: The identity value. + maxlen: Size of the output time dimension. Default is the maximum value in + obj:`max_sequence_length`. + Returns: + A ``tf.Tensor`` of shape ``[batch_size, maxlen, depth]``. + """ + if maxlen is None: + maxlen = tf.reduce_max(max_sequence_length) + + mask = tf.sequence_mask(sequence_length, maxlen=maxlen, dtype=x.dtype) + mask = tf.expand_dims(mask, axis=-1) + mask_combined = tf.sequence_mask( + max_sequence_length, maxlen=maxlen, dtype=x.dtype) + mask_combined = tf.expand_dims(mask_combined, axis=-1) + + identity_mask = mask_combined * (1.0 - mask) + + x = pad_in_time(x, maxlen - tf.shape(x)[1]) + x = x * mask + (identity_mask * identity_values) + + return x + + +def pad_n_with_identity(inputs, sequence_lengths, identity_values=0): + """Pads each input tensors with identity values up to + ``max(sequence_lengths)`` for each batch. + Args: + inputs: A list of ``tf.Tensor``. + sequence_lengths: A list of sequence length. + identity_values: The identity value. + Returns: + A tuple ``(padded, max_sequence_length)`` which are respectively a list of + ``tf.Tensor`` where each tensor are padded with identity and the combined + sequence length. + """ + max_sequence_length = tf.reduce_max(sequence_lengths, axis=0) + maxlen = tf.reduce_max([tf.shape(x)[1] for x in inputs]) + padded = [ + pad_with_identity( + x, + length, + max_sequence_length, + identity_values=identity_values, + maxlen=maxlen) for x, length in zip(inputs, sequence_lengths) + ] + return padded, max_sequence_length + + +class Reducer(tf.keras.layers.Layer): + """Base class for reducers.""" + + def zip_and_reduce(self, x, y): + """Zips the :obj:`x` with :obj:`y` structures together and reduces all + elements. If the structures are nested, they will be flattened first. 
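+        For example, ``SumReducer().zip_and_reduce([a, b], [c, d])`` returns
+        ``[a + c, b + d]``.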
+ Args: + x: The first structure. + y: The second structure. + Returns: + The same structure as :obj:`x` and :obj:`y` where each element from + :obj:`x` is reduced with the correspond element from :obj:`y`. + Raises: + ValueError: if the two structures are not the same. + """ + tf.nest.assert_same_structure(x, y) + x_flat = tf.nest.flatten(x) + y_flat = tf.nest.flatten(y) + reduced = list(map(self, zip(x_flat, y_flat))) + return tf.nest.pack_sequence_as(x, reduced) + + def call(self, inputs, sequence_length=None): # pylint: disable=arguments-differ + """Reduces all input elements. + Args: + inputs: A list of ``tf.Tensor``. + sequence_length: The length of each input, if reducing sequences. + Returns: + If :obj:`sequence_length` is set, a tuple + ``(reduced_input, reduced_length)``, otherwise a reduced ``tf.Tensor`` + only. + """ + if sequence_length is None: + return self.reduce(inputs) + else: + return self.reduce_sequence( + inputs, sequence_lengths=sequence_length) + + @abc.abstractmethod + def reduce(self, inputs): + """See :meth:`opennmt.layers.Reducer.__call__`.""" + raise NotImplementedError() + + @abc.abstractmethod + def reduce_sequence(self, inputs, sequence_lengths): + """See :meth:`opennmt.layers.Reducer.__call__`.""" + raise NotImplementedError() + + +class SumReducer(Reducer): + """A reducer that sums the inputs.""" + + def reduce(self, inputs): + if len(inputs) == 1: + return inputs[0] + if len(inputs) == 2: + return inputs[0] + inputs[1] + return tf.add_n(inputs) + + def reduce_sequence(self, inputs, sequence_lengths): + padded, combined_length = pad_n_with_identity( + inputs, sequence_lengths, identity_values=0) + return self.reduce(padded), combined_length + + +class MultiplyReducer(Reducer): + """A reducer that multiplies the inputs.""" + + def reduce(self, inputs): + return functools.reduce(lambda a, x: a * x, inputs) + + def reduce_sequence(self, inputs, sequence_lengths): + padded, combined_length = pad_n_with_identity( + inputs, sequence_lengths, identity_values=1) + return self.reduce(padded), combined_length diff --git a/modelscope/models/audio/tts/am/models/rnn_wrappers.py b/modelscope/models/audio/tts/am/models/rnn_wrappers.py new file mode 100755 index 00000000..8f0d612b --- /dev/null +++ b/modelscope/models/audio/tts/am/models/rnn_wrappers.py @@ -0,0 +1,240 @@ +import numpy as np +import tensorflow as tf +from tensorflow.contrib.rnn import RNNCell +from tensorflow.contrib.seq2seq import AttentionWrapperState +from tensorflow.python.ops import rnn_cell_impl + +from .modules import prenet + + +class VarPredictorCell(RNNCell): + '''Wrapper wrapper knock knock.''' + + def __init__(self, var_predictor_cell, is_training, dim, prenet_units): + super(VarPredictorCell, self).__init__() + self._var_predictor_cell = var_predictor_cell + self._is_training = is_training + self._dim = dim + self._prenet_units = prenet_units + + @property + def state_size(self): + return tuple([self.output_size, self._var_predictor_cell.state_size]) + + @property + def output_size(self): + return self._dim + + def zero_state(self, batch_size, dtype): + return tuple([ + rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, + dtype), + self._var_predictor_cell.zero_state(batch_size, dtype) + ]) + + def call(self, inputs, state): + '''Run the Tacotron2 super decoder cell.''' + super_cell_out, decoder_state = state + + # split + prenet_input = inputs[:, 0:self._dim] + encoder_output = inputs[:, self._dim:] + + # prenet and concat + prenet_output = prenet( + prenet_input, + 
self._prenet_units, + self._is_training, + scope='var_prenet') + decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) + + # decoder LSTM/GRU + new_super_cell_out, new_decoder_state = self._var_predictor_cell( + decoder_input, decoder_state) + + # projection + new_super_cell_out = tf.layers.dense( + new_super_cell_out, units=self._dim) + + new_states = tuple([new_super_cell_out, new_decoder_state]) + + return new_super_cell_out, new_states + + +class DurPredictorCell(RNNCell): + '''Wrapper wrapper knock knock.''' + + def __init__(self, var_predictor_cell, is_training, dim, prenet_units): + super(DurPredictorCell, self).__init__() + self._var_predictor_cell = var_predictor_cell + self._is_training = is_training + self._dim = dim + self._prenet_units = prenet_units + + @property + def state_size(self): + return tuple([self.output_size, self._var_predictor_cell.state_size]) + + @property + def output_size(self): + return self._dim + + def zero_state(self, batch_size, dtype): + return tuple([ + rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, + dtype), + self._var_predictor_cell.zero_state(batch_size, dtype) + ]) + + def call(self, inputs, state): + '''Run the Tacotron2 super decoder cell.''' + super_cell_out, decoder_state = state + + # split + prenet_input = inputs[:, 0:self._dim] + encoder_output = inputs[:, self._dim:] + + # prenet and concat + prenet_output = prenet( + prenet_input, + self._prenet_units, + self._is_training, + scope='dur_prenet') + decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) + + # decoder LSTM/GRU + new_super_cell_out, new_decoder_state = self._var_predictor_cell( + decoder_input, decoder_state) + + # projection + new_super_cell_out = tf.layers.dense( + new_super_cell_out, units=self._dim) + new_super_cell_out = tf.nn.relu(new_super_cell_out) + # new_super_cell_out = tf.log(tf.cast(tf.round(tf.exp(new_super_cell_out) - 1), tf.float32) + 1) + + new_states = tuple([new_super_cell_out, new_decoder_state]) + + return new_super_cell_out, new_states + + +class DurPredictorCECell(RNNCell): + '''Wrapper wrapper knock knock.''' + + def __init__(self, var_predictor_cell, is_training, dim, prenet_units, + max_dur, dur_embedding_dim): + super(DurPredictorCECell, self).__init__() + self._var_predictor_cell = var_predictor_cell + self._is_training = is_training + self._dim = dim + self._prenet_units = prenet_units + self._max_dur = max_dur + self._dur_embedding_dim = dur_embedding_dim + + @property + def state_size(self): + return tuple([self.output_size, self._var_predictor_cell.state_size]) + + @property + def output_size(self): + return self._max_dur + + def zero_state(self, batch_size, dtype): + return tuple([ + rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, + dtype), + self._var_predictor_cell.zero_state(batch_size, dtype) + ]) + + def call(self, inputs, state): + '''Run the Tacotron2 super decoder cell.''' + super_cell_out, decoder_state = state + + # split + prenet_input = tf.squeeze( + tf.cast(inputs[:, 0:self._dim], tf.int32), axis=-1) # [N] + prenet_input = tf.one_hot( + prenet_input, self._max_dur, on_value=1.0, off_value=0.0, + axis=-1) # [N, 120] + prenet_input = tf.layers.dense( + prenet_input, units=self._dur_embedding_dim) + encoder_output = inputs[:, self._dim:] + + # prenet and concat + prenet_output = prenet( + prenet_input, + self._prenet_units, + self._is_training, + scope='dur_prenet') + decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) + + # decoder LSTM/GRU + 
new_super_cell_out, new_decoder_state = self._var_predictor_cell( + decoder_input, decoder_state) + + # projection + new_super_cell_out = tf.layers.dense( + new_super_cell_out, units=self._max_dur) # [N, 120] + new_super_cell_out = tf.nn.softmax(new_super_cell_out) # [N, 120] + + new_states = tuple([new_super_cell_out, new_decoder_state]) + + return new_super_cell_out, new_states + + +class VarPredictorCell2(RNNCell): + '''Wrapper wrapper knock knock.''' + + def __init__(self, var_predictor_cell, is_training, dim, prenet_units): + super(VarPredictorCell2, self).__init__() + self._var_predictor_cell = var_predictor_cell + self._is_training = is_training + self._dim = dim + self._prenet_units = prenet_units + + @property + def state_size(self): + return tuple([self.output_size, self._var_predictor_cell.state_size]) + + @property + def output_size(self): + return self._dim + + def zero_state(self, batch_size, dtype): + return tuple([ + rnn_cell_impl._zero_state_tensors(self.output_size, batch_size, + dtype), + self._var_predictor_cell.zero_state(batch_size, dtype) + ]) + + def call(self, inputs, state): + '''Run the Tacotron2 super decoder cell.''' + super_cell_out, decoder_state = state + + # split + prenet_input = inputs[:, 0:self._dim] + encoder_output = inputs[:, self._dim:] + + # prenet and concat + prenet_output = prenet( + prenet_input, + self._prenet_units, + self._is_training, + scope='var_prenet') + decoder_input = tf.concat([prenet_output, encoder_output], axis=-1) + + # decoder LSTM/GRU + new_super_cell_out, new_decoder_state = self._var_predictor_cell( + decoder_input, decoder_state) + + # projection + new_super_cell_out = tf.layers.dense( + new_super_cell_out, units=self._dim) + + # split and relu + new_super_cell_out = tf.concat([ + tf.nn.relu(new_super_cell_out[:, 0:1]), new_super_cell_out[:, 1:] + ], axis=-1) # yapf:disable + + new_states = tuple([new_super_cell_out, new_decoder_state]) + + return new_super_cell_out, new_states diff --git a/modelscope/models/audio/tts/am/models/robutrans.py b/modelscope/models/audio/tts/am/models/robutrans.py new file mode 100755 index 00000000..34b4da7a --- /dev/null +++ b/modelscope/models/audio/tts/am/models/robutrans.py @@ -0,0 +1,760 @@ +import tensorflow as tf +from tensorflow.contrib.rnn import LSTMBlockCell, MultiRNNCell +from tensorflow.contrib.seq2seq import BasicDecoder +from tensorflow.python.ops.ragged.ragged_util import repeat + +from .fsmn_encoder import FsmnEncoderV2 +from .helpers import VarTestHelper, VarTrainingHelper +from .modules import conv_prenet, decoder_prenet, encoder_prenet +from .position import (BatchSinusodalPositionalEncoding, + SinusodalPositionalEncoding) +from .rnn_wrappers import DurPredictorCell, VarPredictorCell +from .self_attention_decoder import SelfAttentionDecoder +from .self_attention_encoder import SelfAttentionEncoder + + +class RobuTrans(): + + def __init__(self, hparams): + self._hparams = hparams + + def initialize(self, + inputs, + inputs_emotion, + inputs_speaker, + input_lengths, + output_lengths=None, + mel_targets=None, + durations=None, + pitch_contours=None, + uv_masks=None, + pitch_scales=None, + duration_scales=None, + energy_contours=None, + energy_scales=None): + '''Initializes the model for inference. + + Sets "mel_outputs", "linear_outputs", "stop_token_outputs", and "alignments" fields. 
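+
+        Rough data flow, as read from the graph construction below:
+        symbol/speaker/emotion embeddings -> self-attention encoder ->
+        FSMN+BLSTM pitch and energy predictors -> LSTM duration predictor
+        -> length regulator with positional encoding -> self-attention
+        decoder.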
+ + Args: + inputs: int32 Tensor with shape [N, T_in] where N is batch size, T_in is number of + steps in the input time series, and values are character IDs + input_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths + of each sequence in inputs. + output_lengths: int32 Tensor with shape [N] where N is batch size and values are the lengths + of each sequence in outputs. + mel_targets: float32 Tensor with shape [N, T_out, M] where N is batch size, T_out is number + of steps in the output time series, M is num_mels, and values are entries in the mel + spectrogram. Only needed for training. + ''' + with tf.variable_scope('inference') as _: + is_training = mel_targets is not None + batch_size = tf.shape(inputs)[0] + hp = self._hparams + + input_mask = None + if input_lengths is not None and is_training: + input_mask = tf.sequence_mask( + input_lengths, tf.shape(inputs)[1], dtype=tf.float32) + + if input_mask is not None: + inputs = inputs * tf.expand_dims(input_mask, -1) + + # speaker embedding + embedded_inputs_speaker = tf.layers.dense( + inputs_speaker, + 32, + activation=None, + use_bias=False, + kernel_initializer=tf.truncated_normal_initializer(stddev=0.5)) + + # emotion embedding + embedded_inputs_emotion = tf.layers.dense( + inputs_emotion, + 32, + activation=None, + use_bias=False, + kernel_initializer=tf.truncated_normal_initializer(stddev=0.5)) + + # symbol embedding + with tf.variable_scope('Embedding'): + embedded_inputs = tf.layers.dense( + inputs, + hp.embedding_dim, + activation=None, + use_bias=False, + kernel_initializer=tf.truncated_normal_initializer( + stddev=0.5)) + + # Encoder + with tf.variable_scope('Encoder'): + Encoder = SelfAttentionEncoder( + num_layers=hp.encoder_num_layers, + num_units=hp.encoder_num_units, + num_heads=hp.encoder_num_heads, + ffn_inner_dim=hp.encoder_ffn_inner_dim, + dropout=hp.encoder_dropout, + attention_dropout=hp.encoder_attention_dropout, + relu_dropout=hp.encoder_relu_dropout) + encoder_outputs, state_mo, sequence_length_mo, attns = Encoder.encode( + embedded_inputs, + sequence_length=input_lengths, + mode=is_training) + encoder_outputs = tf.layers.dense( + encoder_outputs, + hp.encoder_projection_units, + activation=None, + use_bias=False, + kernel_initializer=tf.truncated_normal_initializer( + stddev=0.5)) + + # pitch and energy + var_inputs = tf.concat([ + encoder_outputs, embedded_inputs_speaker, + embedded_inputs_emotion + ], 2) + if input_mask is not None: + var_inputs = var_inputs * tf.expand_dims(input_mask, -1) + + with tf.variable_scope('Pitch_Predictor'): + Pitch_Predictor_FSMN = FsmnEncoderV2( + filter_size=hp.predictor_filter_size, + fsmn_num_layers=hp.predictor_fsmn_num_layers, + dnn_num_layers=hp.predictor_dnn_num_layers, + num_memory_units=hp.predictor_num_memory_units, + ffn_inner_dim=hp.predictor_ffn_inner_dim, + dropout=hp.predictor_dropout, + shift=hp.predictor_shift, + position_encoder=None) + pitch_contour_outputs, _, _ = Pitch_Predictor_FSMN.encode( + tf.concat([ + encoder_outputs, embedded_inputs_speaker, + embedded_inputs_emotion + ], 2), + sequence_length=input_lengths, + mode=is_training) + pitch_contour_outputs, _ = tf.nn.bidirectional_dynamic_rnn( + LSTMBlockCell(hp.predictor_lstm_units), + LSTMBlockCell(hp.predictor_lstm_units), + pitch_contour_outputs, + sequence_length=input_lengths, + dtype=tf.float32) + pitch_contour_outputs = tf.concat( + pitch_contour_outputs, axis=-1) + pitch_contour_outputs = tf.layers.dense( + pitch_contour_outputs, units=1) # [N, T_in, 1] + 
pitch_contour_outputs = tf.squeeze( + pitch_contour_outputs, axis=2) # [N, T_in] + + with tf.variable_scope('Energy_Predictor'): + Energy_Predictor_FSMN = FsmnEncoderV2( + filter_size=hp.predictor_filter_size, + fsmn_num_layers=hp.predictor_fsmn_num_layers, + dnn_num_layers=hp.predictor_dnn_num_layers, + num_memory_units=hp.predictor_num_memory_units, + ffn_inner_dim=hp.predictor_ffn_inner_dim, + dropout=hp.predictor_dropout, + shift=hp.predictor_shift, + position_encoder=None) + energy_contour_outputs, _, _ = Energy_Predictor_FSMN.encode( + tf.concat([ + encoder_outputs, embedded_inputs_speaker, + embedded_inputs_emotion + ], 2), + sequence_length=input_lengths, + mode=is_training) + energy_contour_outputs, _ = tf.nn.bidirectional_dynamic_rnn( + LSTMBlockCell(hp.predictor_lstm_units), + LSTMBlockCell(hp.predictor_lstm_units), + energy_contour_outputs, + sequence_length=input_lengths, + dtype=tf.float32) + energy_contour_outputs = tf.concat( + energy_contour_outputs, axis=-1) + energy_contour_outputs = tf.layers.dense( + energy_contour_outputs, units=1) # [N, T_in, 1] + energy_contour_outputs = tf.squeeze( + energy_contour_outputs, axis=2) # [N, T_in] + + if is_training: + pitch_embeddings = tf.expand_dims( + pitch_contours, axis=2) # [N, T_in, 1] + pitch_embeddings = tf.layers.conv1d( + pitch_embeddings, + filters=hp.encoder_projection_units, + kernel_size=9, + padding='same', + name='pitch_embeddings') # [N, T_in, 32] + + energy_embeddings = tf.expand_dims( + energy_contours, axis=2) # [N, T_in, 1] + energy_embeddings = tf.layers.conv1d( + energy_embeddings, + filters=hp.encoder_projection_units, + kernel_size=9, + padding='same', + name='energy_embeddings') # [N, T_in, 32] + else: + pitch_contour_outputs *= pitch_scales + pitch_embeddings = tf.expand_dims( + pitch_contour_outputs, axis=2) # [N, T_in, 1] + pitch_embeddings = tf.layers.conv1d( + pitch_embeddings, + filters=hp.encoder_projection_units, + kernel_size=9, + padding='same', + name='pitch_embeddings') # [N, T_in, 32] + + energy_contour_outputs *= energy_scales + energy_embeddings = tf.expand_dims( + energy_contour_outputs, axis=2) # [N, T_in, 1] + energy_embeddings = tf.layers.conv1d( + energy_embeddings, + filters=hp.encoder_projection_units, + kernel_size=9, + padding='same', + name='energy_embeddings') # [N, T_in, 32] + + encoder_outputs_ = encoder_outputs + pitch_embeddings + energy_embeddings + + # duration + dur_inputs = tf.concat([ + encoder_outputs_, embedded_inputs_speaker, + embedded_inputs_emotion + ], 2) + if input_mask is not None: + dur_inputs = dur_inputs * tf.expand_dims(input_mask, -1) + with tf.variable_scope('Duration_Predictor'): + duration_predictor_cell = MultiRNNCell([ + LSTMBlockCell(hp.predictor_lstm_units), + LSTMBlockCell(hp.predictor_lstm_units) + ], state_is_tuple=True) # yapf:disable + duration_output_cell = DurPredictorCell( + duration_predictor_cell, is_training, 1, + hp.predictor_prenet_units) + duration_predictor_init_state = duration_output_cell.zero_state( + batch_size=batch_size, dtype=tf.float32) + if is_training: + duration_helper = VarTrainingHelper( + tf.expand_dims( + tf.log(tf.cast(durations, tf.float32) + 1), + axis=2), dur_inputs, 1) + else: + duration_helper = VarTestHelper(batch_size, dur_inputs, 1) + ( + duration_outputs, _ + ), final_duration_predictor_state, _ = tf.contrib.seq2seq.dynamic_decode( + BasicDecoder(duration_output_cell, duration_helper, + duration_predictor_init_state), + maximum_iterations=1000) + duration_outputs = tf.squeeze( + duration_outputs, axis=2) # [N, 
T_in] + if input_mask is not None: + duration_outputs = duration_outputs * input_mask + duration_outputs_ = tf.exp(duration_outputs) - 1 + + # Length Regulator + with tf.variable_scope('Length_Regulator'): + if is_training: + i = tf.constant(1) + # position embedding + j = tf.constant(1) + dur_len = tf.shape(durations)[-1] + embedded_position_i = tf.range(1, durations[0, 0] + 1) + + def condition_pos(j, e): + return tf.less(j, dur_len) + + def loop_body_pos(j, embedded_position_i): + embedded_position_i = tf.concat([ + embedded_position_i, + tf.range(1, durations[0, j] + 1) + ], axis=0) # yapf:disable + return [j + 1, embedded_position_i] + + j, embedded_position_i = tf.while_loop( + condition_pos, + loop_body_pos, [j, embedded_position_i], + shape_invariants=[ + j.get_shape(), + tf.TensorShape([None]) + ]) + embedded_position = tf.reshape(embedded_position_i, + (1, -1)) + + # others + LR_outputs = repeat( + encoder_outputs_[0:1, :, :], durations[0, :], axis=1) + embedded_outputs_speaker = repeat( + embedded_inputs_speaker[0:1, :, :], + durations[0, :], + axis=1) + embedded_outputs_emotion = repeat( + embedded_inputs_emotion[0:1, :, :], + durations[0, :], + axis=1) + + def condition(i, pos, layer, s, e): + return tf.less(i, tf.shape(mel_targets)[0]) + + def loop_body(i, embedded_position, LR_outputs, + embedded_outputs_speaker, + embedded_outputs_emotion): + # position embedding + jj = tf.constant(1) + embedded_position_i = tf.range(1, durations[i, 0] + 1) + + def condition_pos_i(j, e): + return tf.less(j, dur_len) + + def loop_body_pos_i(j, embedded_position_i): + embedded_position_i = tf.concat([ + embedded_position_i, + tf.range(1, durations[i, j] + 1) + ], axis=0) # yapf:disable + return [j + 1, embedded_position_i] + + jj, embedded_position_i = tf.while_loop( + condition_pos_i, + loop_body_pos_i, [jj, embedded_position_i], + shape_invariants=[ + jj.get_shape(), + tf.TensorShape([None]) + ]) + embedded_position = tf.concat([ + embedded_position, + tf.reshape(embedded_position_i, (1, -1)) + ], 0) + + # others + LR_outputs = tf.concat([ + LR_outputs, + repeat( + encoder_outputs_[i:i + 1, :, :], + durations[i, :], + axis=1) + ], 0) + embedded_outputs_speaker = tf.concat([ + embedded_outputs_speaker, + repeat( + embedded_inputs_speaker[i:i + 1, :, :], + durations[i, :], + axis=1) + ], 0) + embedded_outputs_emotion = tf.concat([ + embedded_outputs_emotion, + repeat( + embedded_inputs_emotion[i:i + 1, :, :], + durations[i, :], + axis=1) + ], 0) + return [ + i + 1, embedded_position, LR_outputs, + embedded_outputs_speaker, embedded_outputs_emotion + ] + + i, embedded_position, LR_outputs, + embedded_outputs_speaker, + embedded_outputs_emotion = tf.while_loop( + condition, + loop_body, [ + i, embedded_position, LR_outputs, + embedded_outputs_speaker, embedded_outputs_emotion + ], + shape_invariants=[ + i.get_shape(), + tf.TensorShape([None, None]), + tf.TensorShape([None, None, None]), + tf.TensorShape([None, None, None]), + tf.TensorShape([None, None, None]) + ], + parallel_iterations=hp.batch_size) + + ori_framenum = tf.shape(mel_targets)[1] + else: + # position + j = tf.constant(1) + dur_len = tf.shape(duration_outputs_)[-1] + embedded_position_i = tf.range( + 1, + tf.cast(tf.round(duration_outputs_)[0, 0], tf.int32) + + 1) + + def condition_pos(j, e): + return tf.less(j, dur_len) + + def loop_body_pos(j, embedded_position_i): + embedded_position_i = tf.concat([ + embedded_position_i, + tf.range( + 1, + tf.cast( + tf.round(duration_outputs_)[0, j], + tf.int32) + 1) + ], axis=0) # 
yapf:disable + return [j + 1, embedded_position_i] + + j, embedded_position_i = tf.while_loop( + condition_pos, + loop_body_pos, [j, embedded_position_i], + shape_invariants=[ + j.get_shape(), + tf.TensorShape([None]) + ]) + embedded_position = tf.reshape(embedded_position_i, + (1, -1)) + # others + duration_outputs_ *= duration_scales + LR_outputs = repeat( + encoder_outputs_[0:1, :, :], + tf.cast(tf.round(duration_outputs_)[0, :], tf.int32), + axis=1) + embedded_outputs_speaker = repeat( + embedded_inputs_speaker[0:1, :, :], + tf.cast(tf.round(duration_outputs_)[0, :], tf.int32), + axis=1) + embedded_outputs_emotion = repeat( + embedded_inputs_emotion[0:1, :, :], + tf.cast(tf.round(duration_outputs_)[0, :], tf.int32), + axis=1) + ori_framenum = tf.shape(LR_outputs)[1] + + left = hp.outputs_per_step - tf.mod( + ori_framenum, hp.outputs_per_step) + LR_outputs = tf.cond( + tf.equal(left, + hp.outputs_per_step), lambda: LR_outputs, + lambda: tf.pad(LR_outputs, [[0, 0], [0, left], [0, 0]], + 'CONSTANT')) + embedded_outputs_speaker = tf.cond( + tf.equal(left, hp.outputs_per_step), + lambda: embedded_outputs_speaker, lambda: tf.pad( + embedded_outputs_speaker, [[0, 0], [0, left], + [0, 0]], 'CONSTANT')) + embedded_outputs_emotion = tf.cond( + tf.equal(left, hp.outputs_per_step), + lambda: embedded_outputs_emotion, lambda: tf.pad( + embedded_outputs_emotion, [[0, 0], [0, left], + [0, 0]], 'CONSTANT')) + embedded_position = tf.cond( + tf.equal(left, hp.outputs_per_step), + lambda: embedded_position, + lambda: tf.pad(embedded_position, [[0, 0], [0, left]], + 'CONSTANT')) + + # Pos_Embedding + with tf.variable_scope('Position_Embedding'): + Pos_Embedding = BatchSinusodalPositionalEncoding() + position_embeddings = Pos_Embedding.positional_encoding( + batch_size, + tf.shape(LR_outputs)[1], hp.encoder_projection_units, + embedded_position) + LR_outputs += position_embeddings + + # multi-frame + LR_outputs = tf.reshape(LR_outputs, [ + batch_size, -1, + hp.outputs_per_step * hp.encoder_projection_units + ]) + embedded_outputs_speaker = tf.reshape( + embedded_outputs_speaker, + [batch_size, -1, hp.outputs_per_step * 32])[:, :, :32] + embedded_outputs_emotion = tf.reshape( + embedded_outputs_emotion, + [batch_size, -1, hp.outputs_per_step * 32])[:, :, :32] + # [N, T_out, D_LR_outputs] (D_LR_outputs = hp.outputs_per_step * hp.encoder_projection_units + 64) + LR_outputs = tf.concat([ + LR_outputs, embedded_outputs_speaker, embedded_outputs_emotion + ], -1) + + # auto bandwidth + if is_training: + durations_mask = tf.cast(durations, + tf.float32) * input_mask # [N, T_in] + else: + durations_mask = duration_outputs_ + X_band_width = tf.cast( + tf.round(tf.reduce_max(durations_mask) / hp.outputs_per_step), + tf.int32) + H_band_width = X_band_width + + with tf.variable_scope('Decoder'): + Decoder = SelfAttentionDecoder( + num_layers=hp.decoder_num_layers, + num_units=hp.decoder_num_units, + num_heads=hp.decoder_num_heads, + ffn_inner_dim=hp.decoder_ffn_inner_dim, + dropout=hp.decoder_dropout, + attention_dropout=hp.decoder_attention_dropout, + relu_dropout=hp.decoder_relu_dropout, + prenet_units=hp.prenet_units, + dense_units=hp.prenet_proj_units, + num_mels=hp.num_mels, + outputs_per_step=hp.outputs_per_step, + X_band_width=X_band_width, + H_band_width=H_band_width, + position_encoder=None) + if is_training: + if hp.free_run: + r = hp.outputs_per_step + init_decoder_input = tf.expand_dims( + tf.tile([[0.0]], [batch_size, hp.num_mels]), + axis=1) # [N, 1, hp.num_mels] + decoder_input_lengths = tf.cast( + 
output_lengths / r, tf.int32) + decoder_outputs, attention_x, attention_h = Decoder.dynamic_decode_and_search( + init_decoder_input, + maximum_iterations=tf.shape(LR_outputs)[1], + mode=is_training, + memory=LR_outputs, + memory_sequence_length=decoder_input_lengths) + else: + r = hp.outputs_per_step + decoder_input = mel_targets[:, r - 1:: + r, :] # [N, T_out / r, hp.num_mels] + init_decoder_input = tf.expand_dims( + tf.tile([[0.0]], [batch_size, hp.num_mels]), + axis=1) # [N, 1, hp.num_mels] + decoder_input = tf.concat( + [init_decoder_input, decoder_input], + axis=1) # [N, T_out / r + 1, hp.num_mels] + decoder_input = decoder_input[:, : + -1, :] # [N, T_out / r, hp.num_mels] + decoder_input_lengths = tf.cast( + output_lengths / r, tf.int32) + decoder_outputs, attention_x, attention_h = Decoder.decode_from_inputs( + decoder_input, + decoder_input_lengths, + mode=is_training, + memory=LR_outputs, + memory_sequence_length=decoder_input_lengths) + else: + init_decoder_input = tf.expand_dims( + tf.tile([[0.0]], [batch_size, hp.num_mels]), + axis=1) # [N, 1, hp.num_mels] + decoder_outputs, attention_x, attention_h = Decoder.dynamic_decode_and_search( + init_decoder_input, + maximum_iterations=tf.shape(LR_outputs)[1], + mode=is_training, + memory=LR_outputs, + memory_sequence_length=tf.expand_dims( + tf.shape(LR_outputs)[1], axis=0)) + + if is_training: + mel_outputs_ = tf.reshape(decoder_outputs, + [batch_size, -1, hp.num_mels]) + else: + mel_outputs_ = tf.reshape( + decoder_outputs, + [batch_size, -1, hp.num_mels])[:, :ori_framenum, :] + mel_outputs = mel_outputs_ + + with tf.variable_scope('Postnet'): + Postnet_FSMN = FsmnEncoderV2( + filter_size=hp.postnet_filter_size, + fsmn_num_layers=hp.postnet_fsmn_num_layers, + dnn_num_layers=hp.postnet_dnn_num_layers, + num_memory_units=hp.postnet_num_memory_units, + ffn_inner_dim=hp.postnet_ffn_inner_dim, + dropout=hp.postnet_dropout, + shift=hp.postnet_shift, + position_encoder=None) + if is_training: + postnet_fsmn_outputs, _, _ = Postnet_FSMN.encode( + mel_outputs, + sequence_length=output_lengths, + mode=is_training) + hidden_lstm_outputs, _ = tf.nn.dynamic_rnn( + LSTMBlockCell(hp.postnet_lstm_units), + postnet_fsmn_outputs, + sequence_length=output_lengths, + dtype=tf.float32) + else: + postnet_fsmn_outputs, _, _ = Postnet_FSMN.encode( + mel_outputs, + sequence_length=[tf.shape(mel_outputs_)[1]], + mode=is_training) + hidden_lstm_outputs, _ = tf.nn.dynamic_rnn( + LSTMBlockCell(hp.postnet_lstm_units), + postnet_fsmn_outputs, + sequence_length=[tf.shape(mel_outputs_)[1]], + dtype=tf.float32) + + mel_residual_outputs = tf.layers.dense( + hidden_lstm_outputs, units=hp.num_mels) + mel_outputs += mel_residual_outputs + + self.inputs = inputs + self.inputs_speaker = inputs_speaker + self.inputs_emotion = inputs_emotion + self.input_lengths = input_lengths + self.durations = durations + self.output_lengths = output_lengths + self.mel_outputs_ = mel_outputs_ + self.mel_outputs = mel_outputs + self.mel_targets = mel_targets + self.duration_outputs = duration_outputs + self.duration_outputs_ = duration_outputs_ + self.duration_scales = duration_scales + self.pitch_contour_outputs = pitch_contour_outputs + self.pitch_contours = pitch_contours + self.pitch_scales = pitch_scales + self.energy_contour_outputs = energy_contour_outputs + self.energy_contours = energy_contours + self.energy_scales = energy_scales + self.uv_masks_ = uv_masks + + self.embedded_inputs_emotion = embedded_inputs_emotion + self.embedding_fsmn_outputs = embedded_inputs + 
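+ # Remaining intermediates are exported for inspection; attention_h is also
+ # consumed by the guided attention term in add_loss.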
self.encoder_outputs = encoder_outputs + self.encoder_outputs_ = encoder_outputs_ + self.LR_outputs = LR_outputs + self.postnet_fsmn_outputs = postnet_fsmn_outputs + + self.pitch_embeddings = pitch_embeddings + self.energy_embeddings = energy_embeddings + + self.attns = attns + self.attention_x = attention_x + self.attention_h = attention_h + self.X_band_width = X_band_width + self.H_band_width = H_band_width + + def add_loss(self): + '''Adds loss to the model. Sets "loss" field. initialize must have been called.''' + with tf.variable_scope('loss') as _: + hp = self._hparams + mask = tf.sequence_mask( + self.output_lengths, + tf.shape(self.mel_targets)[1], + dtype=tf.float32) + valid_outputs = tf.reduce_sum(mask) + + mask_input = tf.sequence_mask( + self.input_lengths, + tf.shape(self.durations)[1], + dtype=tf.float32) + valid_inputs = tf.reduce_sum(mask_input) + + # mel loss + if self.uv_masks_ is not None: + valid_outputs_mask = tf.reduce_sum( + tf.expand_dims(mask, -1) * self.uv_masks_) + self.mel_loss_ = tf.reduce_sum( + tf.abs(self.mel_targets - self.mel_outputs_) + * tf.expand_dims(mask, -1) * self.uv_masks_) / ( + valid_outputs_mask * hp.num_mels) + self.mel_loss = tf.reduce_sum( + tf.abs(self.mel_targets - self.mel_outputs) + * tf.expand_dims(mask, -1) * self.uv_masks_) / ( + valid_outputs_mask * hp.num_mels) + else: + self.mel_loss_ = tf.reduce_sum( + tf.abs(self.mel_targets - self.mel_outputs_) + * tf.expand_dims(mask, -1)) / ( + valid_outputs * hp.num_mels) + self.mel_loss = tf.reduce_sum( + tf.abs(self.mel_targets - self.mel_outputs) + * tf.expand_dims(mask, -1)) / ( + valid_outputs * hp.num_mels) + + # duration loss + self.duration_loss = tf.reduce_sum( + tf.abs( + tf.log(tf.cast(self.durations, tf.float32) + 1) + - self.duration_outputs) * mask_input) / valid_inputs + + # pitch contour loss + self.pitch_contour_loss = tf.reduce_sum( + tf.abs(self.pitch_contours - self.pitch_contour_outputs) + * mask_input) / valid_inputs + + # energy contour loss + self.energy_contour_loss = tf.reduce_sum( + tf.abs(self.energy_contours - self.energy_contour_outputs) + * mask_input) / valid_inputs + + # final loss + self.loss = self.mel_loss_ + self.mel_loss + self.duration_loss \ + + self.pitch_contour_loss + self.energy_contour_loss + + # guided attention loss + self.guided_attention_loss = tf.constant(0.0) + if hp.guided_attention: + i0 = tf.constant(0) + loss0 = tf.constant(0.0) + + def c(i, _): + return tf.less(i, tf.shape(mel_targets)[0]) + + def loop_body(i, loss): + decoder_input_lengths = tf.cast( + self.output_lengths / hp.outputs_per_step, tf.int32) + input_len = decoder_input_lengths[i] + output_len = decoder_input_lengths[i] + input_w = tf.expand_dims( + tf.range(tf.cast(input_len, dtype=tf.float32)), + axis=1) / tf.cast( + input_len, dtype=tf.float32) # [T_in, 1] + output_w = tf.expand_dims( + tf.range(tf.cast(output_len, dtype=tf.float32)), + axis=0) / tf.cast( + output_len, dtype=tf.float32) # [1, T_out] + guided_attention_w = 1.0 - tf.exp( + -(1 / hp.guided_attention_2g_squared) + * tf.square(input_w - output_w)) # [T_in, T_out] + guided_attention_w = tf.expand_dims( + guided_attention_w, axis=0) # [1, T_in, T_out] + # [hp.decoder_num_heads, T_in, T_out] + guided_attention_w = tf.tile(guided_attention_w, + [hp.decoder_num_heads, 1, 1]) + loss_i = tf.constant(0.0) + for j in range(hp.decoder_num_layers): + loss_i += tf.reduce_mean( + self.attention_h[j][i, :, :input_len, :output_len] + * guided_attention_w) + + return [tf.add(i, 1), tf.add(loss, loss_i)] + + _, loss = 
tf.while_loop( + c, + loop_body, + loop_vars=[i0, loss0], + parallel_iterations=hp.batch_size) + self.guided_attention_loss = loss / hp.batch_size + self.loss += hp.guided_attention_loss_weight * self.guided_attention_loss + + def add_optimizer(self, global_step): + '''Adds optimizer. Sets "gradients" and "optimize" fields. add_loss must have been called. + + Args: + global_step: int32 scalar Tensor representing current global step in training + ''' + with tf.variable_scope('optimizer') as _: + hp = self._hparams + if hp.decay_learning_rate: + self.learning_rate = _learning_rate_decay( + hp.initial_learning_rate, global_step) + else: + self.learning_rate = tf.convert_to_tensor( + hp.initial_learning_rate) + optimizer = tf.train.AdamOptimizer(self.learning_rate, + hp.adam_beta1, hp.adam_beta2) + gradients, variables = zip(*optimizer.compute_gradients(self.loss)) + self.gradients = gradients + clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0) + + # Add dependency on UPDATE_OPS; otherwise batchnorm won't work correctly. See: + # https://github.com/tensorflow/tensorflow/issues/1122 + with tf.control_dependencies( + tf.get_collection(tf.GraphKeys.UPDATE_OPS)): + self.optimize = optimizer.apply_gradients( + zip(clipped_gradients, variables), global_step=global_step) + + +def _learning_rate_decay(init_lr, global_step): + # Noam scheme from tensor2tensor: + warmup_steps = 4000.0 + step = tf.cast(global_step + 1, dtype=tf.float32) + return init_lr * warmup_steps**0.5 * tf.minimum(step * warmup_steps**-1.5, + step**-0.5) diff --git a/modelscope/models/audio/tts/am/models/self_attention_decoder.py b/modelscope/models/audio/tts/am/models/self_attention_decoder.py new file mode 100755 index 00000000..4e64342c --- /dev/null +++ b/modelscope/models/audio/tts/am/models/self_attention_decoder.py @@ -0,0 +1,817 @@ +"""Define self-attention decoder.""" + +import sys + +import tensorflow as tf + +from . import compat, transformer +from .modules import decoder_prenet +from .position import SinusoidalPositionEncoder + + +class SelfAttentionDecoder(): + """Decoder using self-attention as described in + https://arxiv.org/abs/1706.03762. + """ + + def __init__(self, + num_layers, + num_units=512, + num_heads=8, + ffn_inner_dim=2048, + dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + prenet_units=256, + dense_units=128, + num_mels=80, + outputs_per_step=3, + X_band_width=None, + H_band_width=None, + position_encoder=SinusoidalPositionEncoder(), + self_attention_type='scaled_dot'): + """Initializes the parameters of the decoder. + + Args: + num_layers: The number of layers. + num_units: The number of hidden units. + num_heads: The number of heads in the multi-head attention. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout: The probability to drop units from the outputs. + attention_dropout: The probability to drop units from the attention. + relu_dropout: The probability to drop units from the ReLU activation in + the feed forward layer. + position_encoder: A :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. + self_attention_type: Type of self attention, "scaled_dot" or "average" (case + insensitive). + + Raises: + ValueError: if :obj:`self_attention_type` is invalid. 
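+ Additional arguments specific to this acoustic-model decoder:
+ prenet_units: The number of units in the decoder prenet.
+ dense_units: The size of the projection applied to the concatenated
+ prenet output and memory frame.
+ num_mels: The number of mel bins predicted per output frame.
+ outputs_per_step: The reduction factor, i.e. how many mel frames are
+ generated per decoder step.
+ X_band_width: Optional band width limiting how far back the causal
+ self-attention may look (``None`` disables the limit).
+ H_band_width: Optional band width applied to the memory (history)
+ attention mask.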
+ """ + super(SelfAttentionDecoder, self).__init__() + self.num_layers = num_layers + self.num_units = num_units + self.num_heads = num_heads + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.relu_dropout = relu_dropout + self.position_encoder = position_encoder + self.self_attention_type = self_attention_type.lower() + if self.self_attention_type not in ('scaled_dot', 'average'): + raise ValueError('invalid attention type %s' + % self.self_attention_type) + if self.self_attention_type == 'average': + tf.logging.warning( + 'Support for average attention network is experimental ' + 'and may change in future versions.') + self.prenet_units = prenet_units + self.dense_units = dense_units + self.num_mels = num_mels + self.outputs_per_step = outputs_per_step + self.X_band_width = X_band_width + self.H_band_width = H_band_width + + @property + def output_size(self): + """Returns the decoder output size.""" + return self.num_units + + @property + def support_alignment_history(self): + return True + + @property + def support_multi_source(self): + return True + + def _init_cache(self, batch_size, dtype=tf.float32, num_sources=1): + cache = {} + + for layer in range(self.num_layers): + proj_cache_shape = [ + batch_size, self.num_heads, 0, self.num_units // self.num_heads + ] + layer_cache = {} + layer_cache['memory'] = [{ + 'memory_keys': + tf.zeros(proj_cache_shape, dtype=dtype), + 'memory_values': + tf.zeros(proj_cache_shape, dtype=dtype) + } for _ in range(num_sources)] + if self.self_attention_type == 'scaled_dot': + layer_cache['self_keys'] = tf.zeros( + proj_cache_shape, dtype=dtype) + layer_cache['self_values'] = tf.zeros( + proj_cache_shape, dtype=dtype) + elif self.self_attention_type == 'average': + layer_cache['prev_g'] = tf.zeros( + [batch_size, 1, self.num_units], dtype=dtype) + cache['layer_{}'.format(layer)] = layer_cache + + return cache + + def _init_attn(self, dtype=tf.float32): + attn = [] + for layer in range(self.num_layers): + attn.append(tf.TensorArray(tf.float32, size=0, dynamic_size=True)) + return attn + + def _self_attention_stack(self, + inputs, + sequence_length=None, + mode=True, + cache=None, + memory=None, + memory_sequence_length=None, + step=None): + + # [N, T_out, self.dense_units] or [N, 1, self.dense_units] + prenet_outputs = decoder_prenet(inputs, self.prenet_units, + self.dense_units, mode) + if step is None: + decoder_inputs = tf.concat( + [memory, prenet_outputs], + axis=-1) # [N, T_out, memory_size + self.dense_units] + else: + decoder_inputs = tf.concat( + [memory[:, step:step + 1, :], prenet_outputs], + axis=-1) # [N, 1, memory_size + self.dense_units] + decoder_inputs = tf.layers.dense( + decoder_inputs, units=self.dense_units) + + inputs = decoder_inputs + inputs *= self.num_units**0.5 + if self.position_encoder is not None: + inputs = self.position_encoder( + inputs, position=step + 1 if step is not None else None) + + inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) + + decoder_mask = None + memory_mask = None + # last_attention = None + + X_band_width_tmp = -1 + H_band_width_tmp = -1 + if self.X_band_width is not None: + X_band_width_tmp = tf.cast( + tf.cond( + tf.less(tf.shape(memory)[1], self.X_band_width), + lambda: -1, lambda: self.X_band_width), + dtype=tf.int64) + if self.H_band_width is not None: + H_band_width_tmp = tf.cast( + tf.cond( + tf.less(tf.shape(memory)[1], self.H_band_width), + lambda: -1, lambda: self.H_band_width), + dtype=tf.int64) + + if 
self.self_attention_type == 'scaled_dot': + if sequence_length is not None: + decoder_mask = transformer.build_future_mask( + sequence_length, + num_heads=self.num_heads, + maximum_length=tf.shape(inputs)[1], + band=X_band_width_tmp) # [N, 1, T_out, T_out] + elif self.self_attention_type == 'average': + if cache is None: + if sequence_length is None: + sequence_length = tf.fill([tf.shape(inputs)[0]], + tf.shape(inputs)[1]) + decoder_mask = transformer.cumulative_average_mask( + sequence_length, + maximum_length=tf.shape(inputs)[1], + dtype=inputs.dtype) + + if memory is not None and not tf.contrib.framework.nest.is_sequence( + memory): + memory = (memory, ) + if memory_sequence_length is not None: + if not tf.contrib.framework.nest.is_sequence( + memory_sequence_length): + memory_sequence_length = (memory_sequence_length, ) + if step is None: + memory_mask = [ + transformer.build_history_mask( + length, + num_heads=self.num_heads, + maximum_length=tf.shape(m)[1], + band=H_band_width_tmp) + for m, length in zip(memory, memory_sequence_length) + ] + else: + memory_mask = [ + transformer.build_history_mask( + length, + num_heads=self.num_heads, + maximum_length=tf.shape(m)[1], + band=H_band_width_tmp)[:, :, step:step + 1, :] + for m, length in zip(memory, memory_sequence_length) + ] + + # last_attention = None + attns_x = [] + attns_h = [] + for layer in range(self.num_layers): + layer_name = 'layer_{}'.format(layer) + layer_cache = cache[layer_name] if cache is not None else None + with tf.variable_scope(layer_name): + if memory is not None: + for i, (mem, mask) in enumerate(zip(memory, memory_mask)): + memory_cache = None + if layer_cache is not None: + memory_cache = layer_cache['memory'][i] + scope_name = 'multi_head_{}'.format(i) + if i == 0: + scope_name = 'multi_head' + with tf.variable_scope(scope_name): + encoded, attn_x, attn_h = transformer.multi_head_attention_PNCA( + self.num_heads, + transformer.norm(inputs), + mem, + mode, + num_units=self.num_units, + mask=decoder_mask, + mask_h=mask, + cache=layer_cache, + cache_h=memory_cache, + dropout=self.attention_dropout, + return_attention=True, + layer_name=layer_name, + X_band_width=self.X_band_width) + attns_x.append(attn_x) + attns_h.append(attn_h) + context = transformer.drop_and_add( + inputs, encoded, mode, dropout=self.dropout) + + with tf.variable_scope('ffn'): + transformed = transformer.feed_forward_ori( + transformer.norm(context), + self.ffn_inner_dim, + mode, + dropout=self.relu_dropout) + transformed = transformer.drop_and_add( + context, transformed, mode, dropout=self.dropout) + + inputs = transformed + + outputs = transformer.norm(inputs) + outputs = tf.layers.dense( + outputs, units=self.num_mels * self.outputs_per_step) + return outputs, attns_x, attns_h + + def decode_from_inputs(self, + inputs, + sequence_length, + initial_state=None, + mode=True, + memory=None, + memory_sequence_length=None): + outputs, attention_x, attention_h = self._self_attention_stack( + inputs, + sequence_length=sequence_length, + mode=mode, + memory=memory, + memory_sequence_length=memory_sequence_length) + return outputs, attention_x, attention_h + + def step_fn(self, + mode, + batch_size, + initial_state=None, + memory=None, + memory_sequence_length=None, + dtype=tf.float32): + if memory is None: + num_sources = 0 + elif tf.contrib.framework.nest.is_sequence(memory): + num_sources = len(memory) + else: + num_sources = 1 + cache = self._init_cache( + batch_size, dtype=dtype, num_sources=num_sources) + attention_x = 
self._init_attn(dtype=dtype) + attention_h = self._init_attn(dtype=dtype) + + def _fn(step, inputs, cache): + outputs, attention_x, attention_h = self._self_attention_stack( + inputs, + mode=mode, + cache=cache, + memory=memory, + memory_sequence_length=memory_sequence_length, + step=step) + attention_x_tmp = [] + for layer in range(len(attention_h)): + attention_x_tmp_l = tf.zeros_like(attention_h[layer]) + if self.X_band_width is not None: + pred = tf.less(step, self.X_band_width + 1) + attention_x_tmp_l_1 = tf.cond(pred, # yapf:disable + lambda: attention_x_tmp_l[:, :, :, :step + 1] + attention_x[layer], + lambda: tf.concat([ + attention_x_tmp_l[:, :, :, + :step - self.X_band_width], + attention_x_tmp_l[:, :, :, + step - self.X_band_width:step + 1] + + attention_x[layer]], + axis=-1)) # yapf:disable + attention_x_tmp_l_2 = attention_x_tmp_l[:, :, :, step + 1:] + attention_x_tmp.append( + tf.concat([attention_x_tmp_l_1, attention_x_tmp_l_2], + axis=-1)) + else: + attention_x_tmp_l_1 = attention_x_tmp_l[:, :, :, :step + 1] + attention_x_tmp_l_2 = attention_x_tmp_l[:, :, :, step + 1:] + attention_x_tmp.append( + tf.concat([ + attention_x_tmp_l_1 + attention_x[layer], + attention_x_tmp_l_2 + ], axis=-1)) # yapf:disable + attention_x = attention_x_tmp + return outputs, cache, attention_x, attention_h + + return _fn, cache, attention_x, attention_h + + def dynamic_decode_and_search(self, init_decoder_input, maximum_iterations, + mode, memory, memory_sequence_length): + batch_size = tf.shape(init_decoder_input)[0] + step_fn, init_cache, init_attn_x, init_attn_h = self.step_fn( + mode, + batch_size, + memory=memory, + memory_sequence_length=memory_sequence_length) + + outputs, attention_x, attention_h, cache = self.dynamic_decode( + step_fn, + init_decoder_input, + init_cache=init_cache, + init_attn_x=init_attn_x, + init_attn_h=init_attn_h, + maximum_iterations=maximum_iterations, + batch_size=batch_size) + return outputs, attention_x, attention_h + + def dynamic_decode_and_search_teacher_forcing(self, decoder_input, + maximum_iterations, mode, + memory, + memory_sequence_length): + batch_size = tf.shape(decoder_input)[0] + step_fn, init_cache, init_attn_x, init_attn_h = self.step_fn( + mode, + batch_size, + memory=memory, + memory_sequence_length=memory_sequence_length) + + outputs, attention_x, attention_h, cache = self.dynamic_decode_teacher_forcing( + step_fn, + decoder_input, + init_cache=init_cache, + init_attn_x=init_attn_x, + init_attn_h=init_attn_h, + maximum_iterations=maximum_iterations, + batch_size=batch_size) + return outputs, attention_x, attention_h + + def dynamic_decode(self, + step_fn, + init_decoder_input, + init_cache=None, + init_attn_x=None, + init_attn_h=None, + maximum_iterations=None, + batch_size=None): + + def _cond(step, cache, inputs, outputs, attention_x, attention_h): # pylint: disable=unused-argument + return tf.less(step, maximum_iterations) + + def _body(step, cache, inputs, outputs, attention_x, attention_h): + # output: [1, 1, num_mels * r] + # attn: [1, 1, T_out] + output, cache, attn_x, attn_h = step_fn( + step, inputs, cache) # outputs, cache, attention, attns + for layer in range(len(attention_x)): + attention_x[layer] = attention_x[layer].write( + step, tf.cast(attn_x[layer], tf.float32)) + + for layer in range(len(attention_h)): + attention_h[layer] = attention_h[layer].write( + step, tf.cast(attn_h[layer], tf.float32)) + + outputs = outputs.write(step, tf.cast(output, tf.float32)) + return step + 1, cache, output[:, :, -self. 
+ num_mels:], outputs, attention_x, attention_h + + step = tf.constant(0, dtype=tf.int32) + outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True) + + _, cache, _, outputs, attention_x, attention_h = tf.while_loop( + _cond, + _body, + loop_vars=(step, init_cache, init_decoder_input, outputs, + init_attn_x, init_attn_h), + shape_invariants=(step.shape, + compat.nest.map_structure( + self._get_shape_invariants, init_cache), + compat.nest.map_structure( + self._get_shape_invariants, + init_decoder_input), tf.TensorShape(None), + compat.nest.map_structure( + self._get_shape_invariants, init_attn_x), + compat.nest.map_structure( + self._get_shape_invariants, init_attn_h)), + parallel_iterations=1, + back_prop=False, + maximum_iterations=maximum_iterations) + # element of outputs: [N, 1, num_mels * r] + outputs_stack = outputs.stack() # [T_out, N, 1, num_mels * r] + outputs_stack = tf.transpose( + outputs_stack, perm=[2, 1, 0, 3]) # [1, N, T_out, num_mels * r] + outputs_stack = tf.squeeze( + outputs_stack, axis=0) # [N, T_out, num_mels * r] + + attention_x_stack = [] + for layer in range(len(attention_x)): + attention_x_stack_tmp = attention_x[layer].stack( + ) # [T_out, N, H, 1, T_out] + attention_x_stack_tmp = tf.transpose( + attention_x_stack_tmp, perm=[3, 1, 2, 0, + 4]) # [1, N, H, T_out, T_out] + attention_x_stack_tmp = tf.squeeze( + attention_x_stack_tmp, axis=0) # [N, H, T_out, T_out] + attention_x_stack.append(attention_x_stack_tmp) + + attention_h_stack = [] + for layer in range(len(attention_h)): + attention_h_stack_tmp = attention_h[layer].stack( + ) # [T_out, N, H, 1, T_out] + attention_h_stack_tmp = tf.transpose( + attention_h_stack_tmp, perm=[3, 1, 2, 0, + 4]) # [1, N, H, T_out, T_out] + attention_h_stack_tmp = tf.squeeze( + attention_h_stack_tmp, axis=0) # [N, H, T_out, T_out] + attention_h_stack.append(attention_h_stack_tmp) + + return outputs_stack, attention_x_stack, attention_h_stack, cache + + def dynamic_decode_teacher_forcing(self, + step_fn, + decoder_input, + init_cache=None, + init_attn_x=None, + init_attn_h=None, + maximum_iterations=None, + batch_size=None): + + def _cond(step, cache, inputs, outputs, attention_x, attention_h): # pylint: disable=unused-argument + return tf.less(step, maximum_iterations) + + def _body(step, cache, inputs, outputs, attention_x, attention_h): + # output: [1, 1, num_mels * r] + # attn: [1, 1, T_out] + output, cache, attn_x, attn_h = step_fn( + step, inputs[:, step:step + 1, :], + cache) # outputs, cache, attention, attns + for layer in range(len(attention_x)): + attention_x[layer] = attention_x[layer].write( + step, tf.cast(attn_x[layer], tf.float32)) + + for layer in range(len(attention_h)): + attention_h[layer] = attention_h[layer].write( + step, tf.cast(attn_h[layer], tf.float32)) + outputs = outputs.write(step, tf.cast(output, tf.float32)) + return step + 1, cache, inputs, outputs, attention_x, attention_h + + step = tf.constant(0, dtype=tf.int32) + outputs = tf.TensorArray(tf.float32, size=0, dynamic_size=True) + + _, cache, _, outputs, attention_x, attention_h = tf.while_loop( + _cond, + _body, + loop_vars=(step, init_cache, decoder_input, outputs, init_attn_x, + init_attn_h), + shape_invariants=(step.shape, + compat.nest.map_structure( + self._get_shape_invariants, + init_cache), decoder_input.shape, + tf.TensorShape(None), + compat.nest.map_structure( + self._get_shape_invariants, init_attn_x), + compat.nest.map_structure( + self._get_shape_invariants, init_attn_h)), + parallel_iterations=1, + back_prop=False, + 
maximum_iterations=maximum_iterations) + # element of outputs: [N, 1, num_mels * r] + outputs_stack = outputs.stack() # [T_out, N, 1, num_mels * r] + outputs_stack = tf.transpose( + outputs_stack, perm=[2, 1, 0, 3]) # [1, N, T_out, num_mels * r] + outputs_stack = tf.squeeze( + outputs_stack, axis=0) # [N, T_out, num_mels * r] + + attention_x_stack = [] + for layer in range(len(attention_x)): + attention_x_stack_tmp = attention_x[layer].stack( + ) # [T_out, N, H, 1, T_out] + attention_x_stack_tmp = tf.transpose( + attention_x_stack_tmp, perm=[3, 1, 2, 0, + 4]) # [1, N, H, T_out, T_out] + attention_x_stack_tmp = tf.squeeze( + attention_x_stack_tmp, axis=0) # [N, H, T_out, T_out] + attention_x_stack.append(attention_x_stack_tmp) + + attention_h_stack = [] + for layer in range(len(attention_h)): + attention_h_stack_tmp = attention_h[layer].stack( + ) # [T_out, N, H, 1, T_out] + attention_h_stack_tmp = tf.transpose( + attention_h_stack_tmp, perm=[3, 1, 2, 0, + 4]) # [1, N, H, T_out, T_out] + attention_h_stack_tmp = tf.squeeze( + attention_h_stack_tmp, axis=0) # [N, H, T_out, T_out] + attention_h_stack.append(attention_h_stack_tmp) + + return outputs_stack, attention_x_stack, attention_h_stack, cache + + def _get_shape_invariants(self, tensor): + """Returns the shape of the tensor but sets middle dims to None.""" + if isinstance(tensor, tf.TensorArray): + shape = None + else: + shape = tensor.shape.as_list() + for i in range(1, len(shape) - 1): + shape[i] = None + return tf.TensorShape(shape) + + +class SelfAttentionDecoderOri(): + """Decoder using self-attention as described in + https://arxiv.org/abs/1706.03762. + """ + + def __init__(self, + num_layers, + num_units=512, + num_heads=8, + ffn_inner_dim=2048, + dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + position_encoder=SinusoidalPositionEncoder(), + self_attention_type='scaled_dot'): + """Initializes the parameters of the decoder. + + Args: + num_layers: The number of layers. + num_units: The number of hidden units. + num_heads: The number of heads in the multi-head attention. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout: The probability to drop units from the outputs. + attention_dropout: The probability to drop units from the attention. + relu_dropout: The probability to drop units from the ReLU activation in + the feed forward layer. + position_encoder: A :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. + self_attention_type: Type of self attention, "scaled_dot" or "average" (case + insensitive). + + Raises: + ValueError: if :obj:`self_attention_type` is invalid. 
+ """ + super(SelfAttentionDecoderOri, self).__init__() + self.num_layers = num_layers + self.num_units = num_units + self.num_heads = num_heads + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.relu_dropout = relu_dropout + self.position_encoder = position_encoder + self.self_attention_type = self_attention_type.lower() + if self.self_attention_type not in ('scaled_dot', 'average'): + raise ValueError('invalid attention type %s' + % self.self_attention_type) + if self.self_attention_type == 'average': + tf.logging.warning( + 'Support for average attention network is experimental ' + 'and may change in future versions.') + + @property + def output_size(self): + """Returns the decoder output size.""" + return self.num_units + + @property + def support_alignment_history(self): + return True + + @property + def support_multi_source(self): + return True + + def _init_cache(self, batch_size, dtype=tf.float32, num_sources=1): + cache = {} + + for layer in range(self.num_layers): + proj_cache_shape = [ + batch_size, self.num_heads, 0, self.num_units // self.num_heads + ] + layer_cache = {} + layer_cache['memory'] = [{ + 'memory_keys': + tf.zeros(proj_cache_shape, dtype=dtype), + 'memory_values': + tf.zeros(proj_cache_shape, dtype=dtype) + } for _ in range(num_sources)] + if self.self_attention_type == 'scaled_dot': + layer_cache['self_keys'] = tf.zeros( + proj_cache_shape, dtype=dtype) + layer_cache['self_values'] = tf.zeros( + proj_cache_shape, dtype=dtype) + elif self.self_attention_type == 'average': + layer_cache['prev_g'] = tf.zeros( + [batch_size, 1, self.num_units], dtype=dtype) + cache['layer_{}'.format(layer)] = layer_cache + + return cache + + def _self_attention_stack(self, + inputs, + sequence_length=None, + mode=True, + cache=None, + memory=None, + memory_sequence_length=None, + step=None): + inputs *= self.num_units**0.5 + if self.position_encoder is not None: + inputs = self.position_encoder( + inputs, position=step + 1 if step is not None else None) + + inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) + + decoder_mask = None + memory_mask = None + last_attention = None + + if self.self_attention_type == 'scaled_dot': + if sequence_length is not None: + decoder_mask = transformer.build_future_mask( + sequence_length, + num_heads=self.num_heads, + maximum_length=tf.shape(inputs)[1]) + elif self.self_attention_type == 'average': + if cache is None: + if sequence_length is None: + sequence_length = tf.fill([tf.shape(inputs)[0]], + tf.shape(inputs)[1]) + decoder_mask = transformer.cumulative_average_mask( + sequence_length, + maximum_length=tf.shape(inputs)[1], + dtype=inputs.dtype) + + if memory is not None and not tf.contrib.framework.nest.is_sequence( + memory): + memory = (memory, ) + if memory_sequence_length is not None: + if not tf.contrib.framework.nest.is_sequence( + memory_sequence_length): + memory_sequence_length = (memory_sequence_length, ) + memory_mask = [ + transformer.build_sequence_mask( + length, + num_heads=self.num_heads, + maximum_length=tf.shape(m)[1]) + for m, length in zip(memory, memory_sequence_length) + ] + + for layer in range(self.num_layers): + layer_name = 'layer_{}'.format(layer) + layer_cache = cache[layer_name] if cache is not None else None + with tf.variable_scope(layer_name): + if self.self_attention_type == 'scaled_dot': + with tf.variable_scope('masked_multi_head'): + encoded = transformer.multi_head_attention( + self.num_heads, + transformer.norm(inputs), + None, 
+ mode, + num_units=self.num_units, + mask=decoder_mask, + cache=layer_cache, + dropout=self.attention_dropout) + last_context = transformer.drop_and_add( + inputs, encoded, mode, dropout=self.dropout) + elif self.self_attention_type == 'average': + with tf.variable_scope('average_attention'): + # Cumulative average. + x = transformer.norm(inputs) + y = transformer.cumulative_average( + x, + decoder_mask if cache is None else step, + cache=layer_cache) + # FFN. + y = transformer.feed_forward( + y, + self.ffn_inner_dim, + mode, + dropout=self.relu_dropout) + # Gating layer. + z = tf.layers.dense( + tf.concat([x, y], -1), self.num_units * 2) + i, f = tf.split(z, 2, axis=-1) + y = tf.sigmoid(i) * x + tf.sigmoid(f) * y + last_context = transformer.drop_and_add( + inputs, y, mode, dropout=self.dropout) + + if memory is not None: + for i, (mem, mask) in enumerate(zip(memory, memory_mask)): + memory_cache = layer_cache['memory'][i] if layer_cache is not None else None # yapf:disable + with tf.variable_scope('multi_head' if i + == 0 else 'multi_head_%d' % i): # yapf:disable + context, last_attention = transformer.multi_head_attention( + self.num_heads, + transformer.norm(last_context), + mem, + mode, + mask=mask, + cache=memory_cache, + dropout=self.attention_dropout, + return_attention=True) + last_context = transformer.drop_and_add( + last_context, + context, + mode, + dropout=self.dropout) + if i > 0: # Do not return attention in case of multi source. + last_attention = None + + with tf.variable_scope('ffn'): + transformed = transformer.feed_forward_ori( + transformer.norm(last_context), + self.ffn_inner_dim, + mode, + dropout=self.relu_dropout) + transformed = transformer.drop_and_add( + last_context, transformed, mode, dropout=self.dropout) + + inputs = transformed + + if last_attention is not None: + # The first head of the last layer is returned. + first_head_attention = last_attention[:, 0] + else: + first_head_attention = None + + outputs = transformer.norm(inputs) + return outputs, first_head_attention + + def decode_from_inputs(self, + inputs, + sequence_length, + initial_state=None, + mode=True, + memory=None, + memory_sequence_length=None): + outputs, attention = self._self_attention_stack( + inputs, + sequence_length=sequence_length, + mode=mode, + memory=memory, + memory_sequence_length=memory_sequence_length) + return outputs, None, attention + + def step_fn(self, + mode, + batch_size, + initial_state=None, + memory=None, + memory_sequence_length=None, + dtype=tf.float32): + if memory is None: + num_sources = 0 + elif tf.contrib.framework.nest.is_sequence(memory): + num_sources = len(memory) + else: + num_sources = 1 + cache = self._init_cache( + batch_size, dtype=dtype, num_sources=num_sources) + + def _fn(step, inputs, cache, mode): + inputs = tf.expand_dims(inputs, 1) + outputs, attention = self._self_attention_stack( + inputs, + mode=mode, + cache=cache, + memory=memory, + memory_sequence_length=memory_sequence_length, + step=step) + outputs = tf.squeeze(outputs, axis=1) + if attention is not None: + attention = tf.squeeze(attention, axis=1) + return outputs, cache, attention + + return _fn, cache diff --git a/modelscope/models/audio/tts/am/models/self_attention_encoder.py b/modelscope/models/audio/tts/am/models/self_attention_encoder.py new file mode 100755 index 00000000..ce4193dc --- /dev/null +++ b/modelscope/models/audio/tts/am/models/self_attention_encoder.py @@ -0,0 +1,182 @@ +"""Define the self-attention encoder.""" + +import tensorflow as tf + +from . 
import transformer +from .position import SinusoidalPositionEncoder + + +class SelfAttentionEncoder(): + """Encoder using self-attention as described in + https://arxiv.org/abs/1706.03762. + """ + + def __init__(self, + num_layers, + num_units=512, + num_heads=8, + ffn_inner_dim=2048, + dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + position_encoder=SinusoidalPositionEncoder()): + """Initializes the parameters of the encoder. + + Args: + num_layers: The number of layers. + num_units: The number of hidden units. + num_heads: The number of heads in the multi-head attention. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout: The probability to drop units from the outputs. + attention_dropout: The probability to drop units from the attention. + relu_dropout: The probability to drop units from the ReLU activation in + the feed forward layer. + position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. + """ + super(SelfAttentionEncoder, self).__init__() + self.num_layers = num_layers + self.num_units = num_units + self.num_heads = num_heads + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.relu_dropout = relu_dropout + self.position_encoder = position_encoder + + def encode(self, inputs, sequence_length=None, mode=True): + inputs *= self.num_units**0.5 + if self.position_encoder is not None: + inputs = self.position_encoder(inputs) + + inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) + mask = transformer.build_sequence_mask( + sequence_length, + num_heads=self.num_heads, + maximum_length=tf.shape(inputs)[1]) + + mask_FF = tf.squeeze( + transformer.build_sequence_mask( + sequence_length, maximum_length=tf.shape(inputs)[1]), + axis=1) + + state = () + + attns = [] + for layer in range(self.num_layers): + with tf.variable_scope('layer_{}'.format(layer)): + with tf.variable_scope('multi_head'): + context, attn = transformer.multi_head_attention( + self.num_heads, + transformer.norm(inputs), + None, + mode, + num_units=self.num_units, + mask=mask, + dropout=self.attention_dropout, + return_attention=True) + attns.append(attn) + context = transformer.drop_and_add( + inputs, context, mode, dropout=self.dropout) + + with tf.variable_scope('ffn'): + transformed = transformer.feed_forward( + transformer.norm(context), + self.ffn_inner_dim, + mode, + dropout=self.relu_dropout, + mask=mask_FF) + transformed = transformer.drop_and_add( + context, transformed, mode, dropout=self.dropout) + + inputs = transformed + state += (tf.reduce_mean(inputs, axis=1), ) + + outputs = transformer.norm(inputs) + return (outputs, state, sequence_length, attns) + + +class SelfAttentionEncoderOri(): + """Encoder using self-attention as described in + https://arxiv.org/abs/1706.03762. + """ + + def __init__(self, + num_layers, + num_units=512, + num_heads=8, + ffn_inner_dim=2048, + dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + position_encoder=SinusoidalPositionEncoder()): + """Initializes the parameters of the encoder. + + Args: + num_layers: The number of layers. + num_units: The number of hidden units. + num_heads: The number of heads in the multi-head attention. + ffn_inner_dim: The number of units of the inner linear transformation + in the feed forward layer. + dropout: The probability to drop units from the outputs. + attention_dropout: The probability to drop units from the attention. 
+ relu_dropout: The probability to drop units from the ReLU activation in + the feed forward layer. + position_encoder: The :class:`opennmt.layers.position.PositionEncoder` to + apply on inputs or ``None``. + """ + super(SelfAttentionEncoderOri, self).__init__() + self.num_layers = num_layers + self.num_units = num_units + self.num_heads = num_heads + self.ffn_inner_dim = ffn_inner_dim + self.dropout = dropout + self.attention_dropout = attention_dropout + self.relu_dropout = relu_dropout + self.position_encoder = position_encoder + + def encode(self, inputs, sequence_length=None, mode=True): + inputs *= self.num_units**0.5 + if self.position_encoder is not None: + inputs = self.position_encoder(inputs) + + inputs = tf.layers.dropout(inputs, rate=self.dropout, training=mode) + mask = transformer.build_sequence_mask( + sequence_length, + num_heads=self.num_heads, + maximum_length=tf.shape(inputs)[1]) # [N, 1, 1, T_out] + + state = () + + attns = [] + for layer in range(self.num_layers): + with tf.variable_scope('layer_{}'.format(layer)): + with tf.variable_scope('multi_head'): + context, attn = transformer.multi_head_attention( + self.num_heads, + transformer.norm(inputs), + None, + mode, + num_units=self.num_units, + mask=mask, + dropout=self.attention_dropout, + return_attention=True) + attns.append(attn) + context = transformer.drop_and_add( + inputs, context, mode, dropout=self.dropout) + + with tf.variable_scope('ffn'): + transformed = transformer.feed_forward_ori( + transformer.norm(context), + self.ffn_inner_dim, + mode, + dropout=self.relu_dropout) + transformed = transformer.drop_and_add( + context, transformed, mode, dropout=self.dropout) + + inputs = transformed + state += (tf.reduce_mean(inputs, axis=1), ) + + outputs = transformer.norm(inputs) + return (outputs, state, sequence_length, attns) diff --git a/modelscope/models/audio/tts/am/models/transformer.py b/modelscope/models/audio/tts/am/models/transformer.py new file mode 100755 index 00000000..a9f0bedc --- /dev/null +++ b/modelscope/models/audio/tts/am/models/transformer.py @@ -0,0 +1,1157 @@ +"""Define layers related to the Google's Transformer model.""" + +import tensorflow as tf + +from . import compat, fsmn + + +def tile_sequence_length(sequence_length, num_heads): + """Tiles lengths :obj:`num_heads` times. + + Args: + sequence_length: The sequence length. + num_heads: The number of heads. + + Returns: + A ``tf.Tensor`` where each length is replicated :obj:`num_heads` times. + """ + sequence_length = tf.tile(sequence_length, [num_heads]) + sequence_length = tf.reshape(sequence_length, [num_heads, -1]) + sequence_length = tf.transpose(sequence_length, perm=[1, 0]) + sequence_length = tf.reshape(sequence_length, [-1]) + return sequence_length + + +def build_sequence_mask(sequence_length, + num_heads=None, + maximum_length=None, + dtype=tf.float32): + """Builds the dot product mask. + + Args: + sequence_length: The sequence length. + num_heads: The number of heads. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + + Returns: + A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, 1, 1, max_length]``. 
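+ For example, ``sequence_length=[2, 3]`` with ``maximum_length=3`` yields
+ the mask ``[[1., 1., 0.], [1., 1., 1.]]`` before the singleton dimensions
+ are added.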
+ """ + mask = tf.sequence_mask( + sequence_length, maxlen=maximum_length, dtype=dtype) + mask = tf.expand_dims(mask, axis=1) + if num_heads is not None: + mask = tf.expand_dims(mask, axis=1) + return mask + + +def build_sequence_mask_window(sequence_length, + left_window_size=-1, + right_window_size=-1, + num_heads=None, + maximum_length=None, + dtype=tf.float32): + """Builds the dot product mask. + + Args: + sequence_length: The sequence length. + num_heads: The number of heads. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + + Returns: + A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, 1, 1, max_length]``. + """ + sequence_mask = tf.sequence_mask( + sequence_length, maxlen=maximum_length, dtype=dtype) + mask = _window_mask( + sequence_length, + left_window_size=left_window_size, + right_window_size=right_window_size, + maximum_length=maximum_length, + dtype=dtype) + mask *= tf.expand_dims(sequence_mask, axis=1) + if num_heads is not None: + mask = tf.expand_dims(mask, axis=1) + return mask + + +def _lower_triangle_mask(sequence_length, + maximum_length=None, + dtype=tf.float32, + band=-1): + batch_size = tf.shape(sequence_length)[0] + if maximum_length is None: + maximum_length = tf.reduce_max(sequence_length) + mask = tf.ones([batch_size, maximum_length, maximum_length], dtype=dtype) + mask = compat.tf_compat( + v2='linalg.band_part', v1='matrix_band_part')(mask, band, 0) + return mask + + +def _higher_triangle_mask(sequence_length, + maximum_length=None, + dtype=tf.float32, + band=-1): + batch_size = tf.shape(sequence_length)[0] + if maximum_length is None: + maximum_length = tf.reduce_max(sequence_length) + mask = tf.ones([batch_size, maximum_length, maximum_length], dtype=dtype) + mask = compat.tf_compat( + v2='linalg.band_part', v1='matrix_band_part')(mask, 0, band) + return mask + + +def _window_mask(sequence_length, + left_window_size=-1, + right_window_size=-1, + maximum_length=None, + dtype=tf.float32): + batch_size = tf.shape(sequence_length)[0] + if maximum_length is None: + maximum_length = tf.reduce_max(sequence_length) + mask = tf.ones([batch_size, maximum_length, maximum_length], dtype=dtype) + left_window_size = tf.minimum( + tf.cast(left_window_size, tf.int64), + tf.cast(maximum_length - 1, tf.int64)) + right_window_size = tf.minimum( + tf.cast(right_window_size, tf.int64), + tf.cast(maximum_length - 1, tf.int64)) + mask = tf.matrix_band_part(mask, left_window_size, right_window_size) + return mask + + +def build_future_mask(sequence_length, + num_heads=None, + maximum_length=None, + dtype=tf.float32, + band=-1): + """Builds the dot product mask for future positions. + + Args: + sequence_length: The sequence length. + num_heads: The number of heads. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + + Returns: + A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, 1, max_length, max_length]``. 
+ """ + sequence_mask = tf.sequence_mask( + sequence_length, maxlen=maximum_length, dtype=dtype) + mask = _lower_triangle_mask( + sequence_length, maximum_length=maximum_length, dtype=dtype, band=band) + mask *= tf.expand_dims(sequence_mask, axis=1) + if num_heads is not None: + mask = tf.expand_dims(mask, axis=1) + return mask + + +def build_history_mask(sequence_length, + num_heads=None, + maximum_length=None, + dtype=tf.float32, + band=-1): + """Builds the dot product mask for future positions. + + Args: + sequence_length: The sequence length. + num_heads: The number of heads. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + + Returns: + A broadcastable ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, 1, max_length, max_length]``. + """ + sequence_mask = tf.sequence_mask( + sequence_length, maxlen=maximum_length, dtype=dtype) + mask = _higher_triangle_mask( + sequence_length, maximum_length=maximum_length, dtype=dtype, band=band) + mask *= tf.expand_dims(sequence_mask, axis=1) + if num_heads is not None: + mask = tf.expand_dims(mask, axis=1) + return mask + + +def cumulative_average_mask(sequence_length, + maximum_length=None, + dtype=tf.float32): + """Builds the mask to compute the cumulative average as described in + https://arxiv.org/abs/1805.00631. + + Args: + sequence_length: The sequence length. + maximum_length: Optional size of the returned time dimension. Otherwise + it is the maximum of :obj:`sequence_length`. + dtype: The type of the mask tensor. + + Returns: + A ``tf.Tensor`` of type :obj:`dtype` and shape + ``[batch_size, max_length, max_length]``. + """ + sequence_mask = tf.sequence_mask( + sequence_length, maxlen=maximum_length, dtype=dtype) + mask = _lower_triangle_mask( + sequence_length, maximum_length=maximum_length, dtype=dtype) + mask *= tf.expand_dims(sequence_mask, axis=2) + weight = tf.range(1, tf.cast(tf.shape(mask)[1] + 1, dtype), dtype=dtype) + mask /= tf.expand_dims(weight, 1) + return mask + + +def cumulative_average(inputs, mask_or_step, cache=None): + """Computes the cumulative average as described in + https://arxiv.org/abs/1805.00631. + + Args: + inputs: The sequence to average. A tensor of shape :math:`[B, T, D]`. + mask_or_step: If :obj:`cache` is set, this is assumed to be the current step + of the dynamic decoding. Otherwise, it is the mask matrix used to compute + the cumulative average. + cache: A dictionnary containing the cumulative average of the previous step. + + Returns: + The cumulative average, a tensor of the same shape and type as :obj:`inputs`. + """ + if cache is not None: + step = tf.cast(mask_or_step, inputs.dtype) + aa = (inputs + step * cache['prev_g']) / (step + 1.0) + cache['prev_g'] = aa + return aa + else: + mask = mask_or_step + return tf.matmul(mask, inputs) + + +def fused_projection(inputs, num_units, num_outputs=1): + """Projects the same input into multiple output spaces. + + Args: + inputs: The inputs to project. + num_units: The number of output units of each space. + num_outputs: The number of output spaces. + + Returns: + :obj:`num_outputs` ``tf.Tensor`` of depth :obj:`num_units`. + """ + return tf.split( + tf.layers.conv1d(inputs, num_units * num_outputs, 1), + num_outputs, + axis=2) + + +def split_heads(inputs, num_heads): + """Splits a tensor in depth. + + Args: + inputs: A ``tf.Tensor`` of shape :math:`[B, T, D]`. + num_heads: The number of heads :math:`H`. 
+ + Returns: + A ``tf.Tensor`` of shape :math:`[B, H, T, D / H]`. + """ + static_shape = inputs.get_shape().as_list() + depth = static_shape[-1] + outputs = tf.reshape(inputs, [ + tf.shape(inputs)[0], + tf.shape(inputs)[1], num_heads, depth // num_heads + ]) + outputs = tf.transpose(outputs, perm=[0, 2, 1, 3]) + return outputs + + +def combine_heads(inputs): + """Concatenates heads. + + Args: + inputs: A ``tf.Tensor`` of shape :math:`[B, H, T, D]`. + + Returns: + A ``tf.Tensor`` of shape :math:`[B, T, D * H]`. + """ + static_shape = inputs.get_shape().as_list() + depth = static_shape[-1] + num_heads = static_shape[1] + outputs = tf.transpose(inputs, perm=[0, 2, 1, 3]) + outputs = tf.reshape( + outputs, + [tf.shape(outputs)[0], + tf.shape(outputs)[1], depth * num_heads]) + return outputs + + +def dot_product_attention(queries, keys, values, mode, mask=None, dropout=0.0): + """Computes the dot product attention. + + Args: + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + keys: The sequence use to calculate attention scores. A tensor of shape + :math:`[B, T_2, ...]`. + values: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + mode: A ``tf.estimator.ModeKeys`` mode. + mask: A ``tf.Tensor`` applied to the dot product. + dropout: The probability to drop units from the inputs. + + Returns: + A tuple ``(context vector, attention vector)``. + """ + dot = tf.matmul(queries, keys, transpose_b=True) + + if mask is not None: + dot = tf.cast( + tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), + dot.dtype) + + softmax = tf.nn.softmax(tf.cast(dot, tf.float32)) + attn = tf.cast(softmax, dot.dtype) + drop_attn = tf.layers.dropout(attn, rate=dropout, training=mode) + + context = tf.matmul(drop_attn, values) + + return context, attn + + +def dot_product_attention_wpa(num_heads, + queries, + keys, + values, + mode, + attention_left_window=-1, + attention_right_window=0, + mask=None, + max_id_cache=None, + mono=False, + peak_delay=-1, + dropout=0.0): + """ + Computes the dot product attention. + Args: + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + keys: The sequence use to calculate attention scores. A tensor of shape + :math:`[B, T_2, ...]`. + values: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + mode: A ``tf.estimator.ModeKeys`` mode. + mask: A ``tf.Tensor`` applied to the dot product. + dropout: The probability to drop units from the inputs. + + Returns: + A tuple ``(context vector, attention vector)``. + """ + # Dot product between queries and keys. 
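+ # After the raw scores are computed, this variant additionally restricts
+ # attention around each query's argmax key position: an optional peak-delay
+ # mask, an optional monotonicity constraint (via max_id_cache), and
+ # left/right attention windows are applied to the scores before the softmax.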
+ dot = tf.matmul(queries, keys, transpose_b=True) + depth = tf.shape(dot)[-1] + if mask is not None: + dot = tf.cast( + tf.cast(dot, tf.float32) * mask + ((1.0 - mask) * tf.float32.min), + dot.dtype) + # wpa + max_id = tf.math.argmax(input=dot, axis=-1) + # peak delay + if peak_delay > 0: + if max_id_cache is not None: + M = tf.cast(max_id_cache['pre_max_id'], dtype=max_id.dtype) + inputs_len = tf.math.minimum( + M + peak_delay, tf.cast(depth - 1, dtype=max_id.dtype)) + delay_mask = tf.sequence_mask( + inputs_len, maxlen=depth, dtype=tf.float32) + dot = tf.cast( + tf.cast(dot, tf.float32) * delay_mask + + ((1.0 - delay_mask) * tf.float32.min), dot.dtype) # yapf:disable + max_id = tf.math.argmax(input=dot, axis=-1) + # mono + if mono: + if max_id_cache is None: + d = tf.shape(max_id)[-1] + tmp_max_id = tf.reshape(max_id, [-1, num_heads, d]) + tmp_max_id = tf.slice( + tmp_max_id, [0, 0, 0], + [tf.shape(tmp_max_id)[0], + tf.shape(tmp_max_id)[1], d - 1]) + zeros = tf.zeros( + shape=(tf.shape(tmp_max_id)[0], tf.shape(tmp_max_id)[1], 1), + dtype=max_id.dtype) + tmp_max_id = tf.concat([zeros, tmp_max_id], axis=-1) + mask1 = tf.sequence_mask( + tmp_max_id, maxlen=depth, dtype=tf.float32) + dot = tf.cast( + tf.cast(dot, tf.float32) + * (1.0 - mask1) + mask1 * tf.float32.min, dot.dtype) # yapf:disable + max_id = tf.math.argmax(input=dot, axis=-1) + else: + # eval + tmp_max_id = tf.reshape(max_id, [-1, num_heads, 1]) + max_id_cache['pre_max_id'] = tmp_max_id + # right_mask + right_offset = tf.constant(attention_right_window, dtype=max_id.dtype) + right_len = tf.math.minimum(max_id + right_offset, + tf.cast(depth - 1, dtype=max_id.dtype)) + right_mask = tf.sequence_mask(right_len, maxlen=depth, dtype=tf.float32) + dot = tf.cast( + tf.cast(dot, tf.float32) * right_mask + + ((1.0 - right_mask) * tf.float32.min), dot.dtype) # yapf:disable + # left_mask + if attention_left_window > 0: + left_offset = tf.constant(attention_left_window, dtype=max_id.dtype) + left_len = tf.math.maximum(max_id - left_offset, + tf.cast(0, dtype=max_id.dtype)) + left_mask = tf.sequence_mask(left_len, maxlen=depth, dtype=tf.float32) + dot = tf.cast( + tf.cast(dot, tf.float32) * (1.0 - left_mask) + + (left_mask * tf.float32.min), dot.dtype) # yapf:disable + # Compute attention weights. + attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) + drop_attn = tf.layers.dropout(attn, rate=dropout, training=mode) + + # Compute attention context. + context = tf.matmul(drop_attn, values) + + return context, attn + + +def multi_head_attention(num_heads, + queries, + memory, + mode, + num_units=None, + mask=None, + cache=None, + dropout=0.0, + return_attention=False): + """Computes the multi-head attention as described in + https://arxiv.org/abs/1706.03762. + + Args: + num_heads: The number of attention heads. + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + If ``None``, computes self-attention. + mode: A ``tf.estimator.ModeKeys`` mode. + num_units: The number of hidden units. If not set, it is set to the input + dimension. + mask: A ``tf.Tensor`` applied to the dot product. + cache: A dictionary containing pre-projected keys and values. + dropout: The probability to drop units from the inputs. + return_attention: Return the attention head probabilities in addition to the + context. + + Returns: + The concatenated attention context of each head and the attention + probabilities (if :obj:`return_attention` is set). 
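+
+    Example (illustrative; ``queries``, ``memory``, ``mode`` and ``mask`` are
+    assumed to be defined as described above)::
+
+        context = multi_head_attention(
+            8, queries, memory, mode, num_units=512, mask=mask, dropout=0.1)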
+ """ + num_units = num_units or queries.get_shape().as_list()[-1] + + if num_units % num_heads != 0: + raise ValueError('Multi head attention requires that num_units is a' + ' multiple of {}'.format(num_heads)) + + if memory is None: + queries, keys, values = fused_projection( + queries, num_units, num_outputs=3) + + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + if cache is not None: + keys = tf.concat([cache['self_keys'], keys], axis=2) + values = tf.concat([cache['self_values'], values], axis=2) + cache['self_keys'] = keys + cache['self_values'] = values + else: + queries = tf.layers.conv1d(queries, num_units, 1) + + if cache is not None: + + def _project_and_split(): + k, v = fused_projection(memory, num_units, num_outputs=2) + return split_heads(k, num_heads), split_heads(v, num_heads) + + keys, values = tf.cond( + tf.equal(tf.shape(cache['memory_keys'])[2], 0), + true_fn=_project_and_split, + false_fn=lambda: + (cache['memory_keys'], cache['memory_values'])) + cache['memory_keys'] = keys + cache['memory_values'] = values + else: + keys, values = fused_projection(memory, num_units, num_outputs=2) + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + queries = split_heads(queries, num_heads) + queries *= (num_units // num_heads)**-0.5 + + heads, attn = dot_product_attention( + queries, keys, values, mode, mask=mask, dropout=dropout) + + # Concatenate all heads output. + combined = combine_heads(heads) + outputs = tf.layers.conv1d(combined, num_units, 1) + + if not return_attention: + return outputs + return outputs, attn + + +def multi_head_attention_PNCA(num_heads, + queries, + memory, + mode, + num_units=None, + mask=None, + mask_h=None, + cache=None, + cache_h=None, + dropout=0.0, + return_attention=False, + X_band_width=None, + layer_name='multi_head'): + """Computes the multi-head attention as described in + https://arxiv.org/abs/1706.03762. + + Args: + num_heads: The number of attention heads. + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + If ``None``, computes self-attention. + mode: A ``tf.estimator.ModeKeys`` mode. + num_units: The number of hidden units. If not set, it is set to the input + dimension. + mask: A ``tf.Tensor`` applied to the dot product. + cache: A dictionary containing pre-projected keys and values. + dropout: The probability to drop units from the inputs. + return_attention: Return the attention head probabilities in addition to the + context. + + Returns: + The concatenated attention context of each head and the attention + probabilities (if :obj:`return_attention` is set). 
+ """ + num_units = num_units or queries.get_shape().as_list()[-1] + + if num_units % num_heads != 0: + raise ValueError('Multi head attention requires that num_units is a' + ' multiple of {}'.format(num_heads)) + + # X + queries, keys, values = fused_projection(queries, num_units, num_outputs=3) + + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + if cache is not None: + keys = tf.concat([cache['self_keys'], keys], axis=2) + values = tf.concat([cache['self_values'], values], axis=2) + if X_band_width is not None: + keys_band = tf.cond( + tf.less(X_band_width, 0), lambda: keys, lambda: tf.cond( + tf.less(tf.shape(keys)[2], X_band_width), lambda: keys, + lambda: keys[:, :, -X_band_width:, :]) + ) # not support X_band_width == 0 + values_band = tf.cond( + tf.less(X_band_width, 0), lambda: values, lambda: tf.cond( + tf.less(tf.shape(values)[2], X_band_width), lambda: values, + lambda: values[:, :, -X_band_width:, :])) + cache['self_keys'] = keys_band + cache['self_values'] = values_band + else: + cache['self_keys'] = keys + cache['self_values'] = values + + queries = split_heads(queries, num_heads) + queries *= (num_units // num_heads)**-0.5 + + heads, attn = dot_product_attention( + queries, keys, values, mode, mask=mask, dropout=dropout) + + # Concatenate all heads output. + combined = combine_heads(heads) + outputs = tf.layers.conv1d(combined, num_units, 1) + + # H + if cache_h is not None: + + def _project_and_split(): + k, v = fused_projection(memory, num_units, num_outputs=2) + return split_heads(k, num_heads), split_heads(v, num_heads) + + keys_h, values_h = tf.cond( + tf.equal(tf.shape(cache_h['memory_keys'])[2], 0), + true_fn=_project_and_split, + false_fn=lambda: + (cache_h['memory_keys'], cache_h['memory_values'])) + cache_h['memory_keys'] = keys_h + cache_h['memory_values'] = values_h + else: + keys_h, values_h = fused_projection(memory, num_units, num_outputs=2) + keys_h = split_heads(keys_h, num_heads) + values_h = split_heads(values_h, num_heads) + + heads_h, attn_h = dot_product_attention( + queries, keys_h, values_h, mode, mask=mask_h, dropout=dropout) + + # Concatenate all heads output. + combined_h = combine_heads(heads_h) + outputs_h = tf.layers.conv1d(combined_h, num_units, 1) + + # ADD + outputs = outputs + outputs_h + + # RETURN + return outputs, attn, attn_h + + +def multi_head_attention_memory(num_heads, + queries, + memory, + mode, + num_memory=None, + num_units=None, + mask=None, + cache=None, + dropout=0.0, + return_attention=False): + """Computes the multi-head attention as described in + https://arxiv.org/abs/1706.03762. + + Args: + num_heads: The number of attention heads. + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + If ``None``, computes self-attention. + mode: A ``tf.estimator.ModeKeys`` mode. + num_units: The number of hidden units. If not set, it is set to the input + dimension. + mask: A ``tf.Tensor`` applied to the dot product. + cache: A dictionary containing pre-projected keys and values. + dropout: The probability to drop units from the inputs. + return_attention: Return the attention head probabilities in addition to the + context. + + Returns: + The concatenated attention context of each head and the attention + probabilities (if :obj:`return_attention` is set). 
+ """ + num_units = num_units or queries.get_shape().as_list()[-1] + + if num_units % num_heads != 0: + raise ValueError('Multi head attention requires that num_units is a' + ' multiple of {}'.format(num_heads)) + + # PERSISTENT MEMORY + # key memory + if num_memory is not None: + key_m = tf.get_variable( + 'key_m', + shape=[num_memory, num_units], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + # value memory + value_m = tf.get_variable( + 'value_m', + shape=[num_memory, num_units], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + if memory is None: + queries, keys, values = fused_projection( + queries, num_units, num_outputs=3) + + # concat memory + if num_memory is not None: + key_m_expand = tf.tile( + tf.expand_dims(key_m, 0), [tf.shape(keys)[0], 1, 1]) + value_m_expand = tf.tile( + tf.expand_dims(value_m, 0), [tf.shape(values)[0], 1, 1]) + keys = tf.concat([key_m_expand, keys], axis=1) + values = tf.concat([value_m_expand, values], axis=1) + + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + if cache is not None: + keys = tf.concat([cache['self_keys'], keys], axis=2) + values = tf.concat([cache['self_values'], values], axis=2) + cache['self_keys'] = keys + cache['self_values'] = values + else: + queries = tf.layers.conv1d(queries, num_units, 1) + + if cache is not None: + + def _project_and_split(): + k, v = fused_projection(memory, num_units, num_outputs=2) + return split_heads(k, num_heads), split_heads(v, num_heads) + + keys, values = tf.cond( + tf.equal(tf.shape(cache['memory_keys'])[2], 0), + true_fn=_project_and_split, + false_fn=lambda: + (cache['memory_keys'], cache['memory_values'])) + cache['memory_keys'] = keys + cache['memory_values'] = values + else: + keys, values = fused_projection(memory, num_units, num_outputs=2) + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + queries = split_heads(queries, num_heads) + queries *= (num_units // num_heads)**-0.5 + + heads, attn = dot_product_attention( + queries, keys, values, mode, mask=mask, dropout=dropout) + + # Concatenate all heads output. + combined = combine_heads(heads) + outputs = tf.layers.conv1d(combined, num_units, 1) + + if not return_attention: + return outputs + return outputs, attn + + +def Ci_Cd_Memory(num_heads, + queries, + mode, + filter_size=None, + num_memory=None, + num_units=None, + fsmn_mask=None, + san_mask=None, + cache=None, + shift=None, + dropout=0.0, + return_attention=False): + """ + Args: + num_heads: The number of attention heads. + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + If ``None``, computes self-attention. + mode: A ``tf.estimator.ModeKeys`` mode. + num_units: The number of hidden units. If not set, it is set to the input + dimension. + mask: A ``tf.Tensor`` applied to the dot product. + cache: A dictionary containing pre-projected keys and values. + dropout: The probability to drop units from the inputs. + return_attention: Return the attention head probabilities in addition to the + context. + + Returns: + The concatenated attention context of each head and the attention + probabilities (if :obj:`return_attention` is set). 
+ """ + num_units = num_units or queries.get_shape().as_list()[-1] + + if num_units % num_heads != 0: + raise ValueError('Multi head attention requires that num_units is a' + ' multiple of {}'.format(num_heads)) + # PERSISTENT MEMORY + if num_memory is not None: + key_m = tf.get_variable( + 'key_m', + shape=[num_memory, num_units], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + value_m = tf.get_variable( + 'value_m', + shape=[num_memory, num_units], + initializer=tf.glorot_uniform_initializer(), + dtype=tf.float32) + + queries, keys, values = fused_projection(queries, num_units, num_outputs=3) + # fsmn memory block + if shift is not None: + # encoder + fsmn_memory = fsmn.MemoryBlockV2( + values, + filter_size, + mode, + shift=shift, + mask=fsmn_mask, + dropout=dropout) + else: + # decoder + fsmn_memory = fsmn.UniMemoryBlock( + values, + filter_size, + mode, + cache=cache, + mask=fsmn_mask, + dropout=dropout) + + # concat persistent memory + if num_memory is not None: + key_m_expand = tf.tile( + tf.expand_dims(key_m, 0), [tf.shape(keys)[0], 1, 1]) + value_m_expand = tf.tile( + tf.expand_dims(value_m, 0), [tf.shape(values)[0], 1, 1]) + keys = tf.concat([key_m_expand, keys], axis=1) + values = tf.concat([value_m_expand, values], axis=1) + + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + if cache is not None: + keys = tf.concat([cache['self_keys'], keys], axis=2) + values = tf.concat([cache['self_values'], values], axis=2) + cache['self_keys'] = keys + cache['self_values'] = values + + queries = split_heads(queries, num_heads) + queries *= (num_units // num_heads)**-0.5 + + heads, attn = dot_product_attention( + queries, keys, values, mode, mask=san_mask, dropout=dropout) + + # Concatenate all heads output. + combined = combine_heads(heads) + outputs = tf.layers.conv1d(combined, num_units, 1) + outputs = outputs + fsmn_memory + + if not return_attention: + return outputs + return outputs, attn + + +def multi_head_attention_wpa(num_heads, + queries, + memory, + mode, + attention_left_window=-1, + attention_right_window=0, + num_units=None, + mask=None, + cache=None, + max_id_cache=None, + dropout=0.0, + mono=False, + peak_delay=-1, + return_attention=False): + """Computes the multi-head attention as described in + https://arxiv.org/abs/1706.03762. + + Args: + num_heads: The number of attention heads. + queries: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + If ``None``, computes self-attention. + mode: A ``tf.estimator.ModeKeys`` mode. + num_units: The number of hidden units. If not set, it is set to the input + dimension. + mask: A ``tf.Tensor`` applied to the dot product. + cache: A dictionary containing pre-projected keys and values. + dropout: The probability to drop units from the inputs. + return_attention: Return the attention head probabilities in addition to the + context. + + Returns: + The concatenated attention context of each head and the attention + probabilities (if :obj:`return_attention` is set). 
+ """ + num_units = num_units or queries.get_shape().as_list()[-1] + + if num_units % num_heads != 0: + raise ValueError('Multi head attention requires that num_units is a' + ' multiple of {}'.format(num_heads)) + + if memory is None: + queries, keys, values = fused_projection( + queries, num_units, num_outputs=3) + + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + if cache is not None: + keys = tf.concat([cache['self_keys'], keys], axis=2) + values = tf.concat([cache['self_values'], values], axis=2) + cache['self_keys'] = keys + cache['self_values'] = values + else: + queries = tf.layers.conv1d(queries, num_units, 1) + + if cache is not None: + + def _project_and_split(): + k, v = fused_projection(memory, num_units, num_outputs=2) + return split_heads(k, num_heads), split_heads(v, num_heads) + + keys, values = tf.cond( + tf.equal(tf.shape(cache['memory_keys'])[2], 0), + true_fn=_project_and_split, + false_fn=lambda: + (cache['memory_keys'], cache['memory_values'])) + cache['memory_keys'] = keys + cache['memory_values'] = values + else: + keys, values = fused_projection(memory, num_units, num_outputs=2) + keys = split_heads(keys, num_heads) + values = split_heads(values, num_heads) + + queries = split_heads(queries, num_heads) + queries *= (num_units // num_heads)**-0.5 + + heads, attn = dot_product_attention_wpa( + num_heads, + queries, + keys, + values, + mode, + attention_left_window=attention_left_window, + attention_right_window=attention_right_window, + mask=mask, + max_id_cache=max_id_cache, + mono=mono, + peak_delay=peak_delay, + dropout=dropout) + + # Concatenate all heads output. + combined = combine_heads(heads) + outputs = tf.layers.conv1d(combined, num_units, 1) + + if not return_attention: + return outputs + return outputs, attn + + +def feed_forward(x, inner_dim, mode, dropout=0.0, mask=None): + """Implements the Transformer's "Feed Forward" layer. + + .. math:: + + ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2 + + Args: + x: The input. + inner_dim: The number of units of the inner linear transformation. + mode: A ``tf.estimator.ModeKeys`` mode. + dropout: The probability to drop units from the inner transformation. + + Returns: + The transformed input. + """ + input_dim = x.get_shape().as_list()[-1] + + if mask is not None: + x = x * tf.expand_dims(mask, -1) + + inner = tf.layers.conv1d( + x, inner_dim, 3, padding='same', activation=tf.nn.relu) + + if mask is not None: + inner = inner * tf.expand_dims(mask, -1) + inner = tf.layers.dropout(inner, rate=dropout, training=mode) + outer = tf.layers.conv1d(inner, input_dim, 1) + + return outer + + +def feed_forward_ori(x, inner_dim, mode, dropout=0.0): + """Implements the Transformer's "Feed Forward" layer. + + .. math:: + + ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2 + + Args: + x: The input. + inner_dim: The number of units of the inner linear transformation. + mode: A ``tf.estimator.ModeKeys`` mode. + dropout: The probability to drop units from the inner transformation. + + Returns: + The transformed input. + """ + input_dim = x.get_shape().as_list()[-1] + + inner = tf.layers.conv1d(x, inner_dim, 1, activation=tf.nn.relu) + inner = tf.layers.dropout(inner, rate=dropout, training=mode) + outer = tf.layers.conv1d(inner, input_dim, 1) + + return outer + + +def norm(inputs): + """Layer normalizes :obj:`inputs`.""" + return tf.contrib.layers.layer_norm(inputs, begin_norm_axis=-1) + + +def drop_and_add(inputs, outputs, mode, dropout=0.1): + """Drops units in the outputs and adds the previous values. 
+ + Args: + inputs: The input of the previous layer. + outputs: The output of the previous layer. + mode: A ``tf.estimator.ModeKeys`` mode. + dropout: The probability to drop units in :obj:`outputs`. + + Returns: + The residual and normalized output. + """ + outputs = tf.layers.dropout(outputs, rate=dropout, training=mode) + + input_dim = inputs.get_shape().as_list()[-1] + output_dim = outputs.get_shape().as_list()[-1] + + if input_dim == output_dim: + outputs += inputs + return outputs + + +class FeedForwardNetwork(tf.keras.layers.Layer): + """Implements the Transformer's "Feed Forward" layer. + + .. math:: + + ffn(x) = max(0, x*W_1 + b_1)*W_2 + b_2 + + Note: + Object-oriented implementation for TensorFlow 2.0. + """ + + def __init__(self, + inner_dim, + output_dim, + dropout=0.1, + activation=tf.nn.relu, + **kwargs): + """Initializes this layer. + + Args: + inner_dim: The number of units of the inner linear transformation. + output_dim: The number of units of the ouput linear transformation. + dropout: The probability to drop units from the activation output. + activation: The activation function to apply between the two linear + transformations. + kwargs: Additional layer arguments. + """ + super(FeedForwardNetwork, self).__init__(**kwargs) + self.inner = tf.keras.layers.Dense( + inner_dim, activation=activation, name='inner') + self.outer = tf.keras.layers.Dense(output_dim, name='outer') + self.dropout = dropout + + def call(self, inputs, training=None): # pylint: disable=arguments-differ + """Runs the layer.""" + inner = self.inner(inputs) + inner = tf.layers.dropout(inner, self.dropout, training=training) + return self.outer(inner) + + +class MultiHeadAttention(tf.keras.layers.Layer): + """Computes the multi-head attention as described in + https://arxiv.org/abs/1706.03762. + + Note: + Object-oriented implementation for TensorFlow 2.0. + """ + + def __init__(self, + num_heads, + num_units, + dropout=0.1, + return_attention=False, + **kwargs): + """Initializes this layers. + + Args: + num_heads: The number of attention heads. + num_units: The number of hidden units. + dropout: The probability to drop units from the inputs. + return_attention: If ``True``, also return the attention weights of the + first head. + kwargs: Additional layer arguments. + """ + super(MultiHeadAttention, self).__init__(**kwargs) + if num_units % num_heads != 0: + raise ValueError( + 'Multi head attention requires that num_units is a' + ' multiple of %s' % num_heads) + self.num_heads = num_heads + self.num_units = num_units + self.linear_queries = tf.keras.layers.Dense( + num_units, name='linear_queries') + self.linear_keys = tf.keras.layers.Dense(num_units, name='linear_keys') + self.linear_values = tf.keras.layers.Dense( + num_units, name='linear_values') + self.linear_output = tf.keras.layers.Dense( + num_units, name='linear_output') + self.dropout = dropout + self.return_attention = return_attention + + def call(self, inputs, memory=None, mask=None, cache=None, training=None): # pylint: disable=arguments-differ + """Runs the layer. + + Args: + inputs: The sequence of queries. A tensor of shape :math:`[B, T_1, ...]`. + memory: The sequence to attend. A tensor of shape :math:`[B, T_2, ...]`. + If ``None``, computes self-attention. + mask: A ``tf.Tensor`` applied to the dot product. + cache: A dictionary containing pre-projected keys and values. + training: Run in training mode. 
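+
+        Note: here :obj:`cache` is a ``(keys, values)`` tuple rather than a
+        dictionary; the updated tuple is returned so it can be fed back on the
+        next decoding step.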
+ + Returns: + A tuple with the attention context, the updated cache and the attention + probabilities of the first head (if :obj:`return_attention` is ``True``). + """ + + def _compute_kv(x): + keys = self.linear_keys(x) + keys = split_heads(keys, self.num_heads) + values = self.linear_values(x) + values = split_heads(values, self.num_heads) + return keys, values + + # Compute queries. + queries = self.linear_queries(inputs) + queries = split_heads(queries, self.num_heads) + queries *= (self.num_units // self.num_heads)**-0.5 + + # Compute keys and values. + if memory is None: + keys, values = _compute_kv(inputs) + if cache: + keys = tf.concat([cache[0], keys], axis=2) + values = tf.concat([cache[1], values], axis=2) + else: + if cache: + if not self.linear_keys.built: + # Ensure that the variable names are not impacted by the tf.cond name + # scope if the layers have not already been built. + with tf.name_scope(self.linear_keys.name): + self.linear_keys.build(memory.shape) + with tf.name_scope(self.linear_values.name): + self.linear_values.build(memory.shape) + keys, values = tf.cond( + tf.equal(tf.shape(cache[0])[2], 0), + true_fn=lambda: _compute_kv(memory), + false_fn=lambda: cache) + else: + keys, values = _compute_kv(memory) + + cache = (keys, values) + + # Dot product attention. + dot = tf.matmul(queries, keys, transpose_b=True) + if mask is not None: + mask = tf.expand_dims(tf.cast(mask, tf.float32), + 1) # Broadcast on heads dimension. + dot = tf.cast( + tf.cast(dot, tf.float32) * mask + + ((1.0 - mask) * tf.float32.min), dot.dtype) # yapf:disable + attn = tf.cast(tf.nn.softmax(tf.cast(dot, tf.float32)), dot.dtype) + drop_attn = tf.layers.dropout(attn, self.dropout, training=training) + heads = tf.matmul(drop_attn, values) + + # Concatenate all heads output. 
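+        # heads: [B, H, T_1, num_units / H] -> combined: [B, T_1, num_units].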
+ combined = combine_heads(heads) + outputs = self.linear_output(combined) + if self.return_attention: + return outputs, cache, attn + return outputs, cache diff --git a/modelscope/models/audio/tts/am/sambert_hifi_16k.py b/modelscope/models/audio/tts/am/sambert_hifi_16k.py new file mode 100644 index 00000000..2db9abc6 --- /dev/null +++ b/modelscope/models/audio/tts/am/sambert_hifi_16k.py @@ -0,0 +1,255 @@ +import io +import os +from typing import Any, Dict, Optional, Union + +import numpy as np +import tensorflow as tf +from sklearn.preprocessing import MultiLabelBinarizer + +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.constant import ModelFile, Tasks +from .models import create_model +from .text.symbols import load_symbols +from .text.symbols_dict import SymbolsDict + +__all__ = ['SambertNetHifi16k'] + + +def multi_label_symbol_to_sequence(my_classes, my_symbol): + one_hot = MultiLabelBinarizer(my_classes) + tokens = my_symbol.strip().split(' ') + sequences = [] + for token in tokens: + sequences.append(tuple(token.split('&'))) + # sequences.append(tuple(['~'])) # sequence length minus 1 to ignore EOS ~ + return one_hot.fit_transform(sequences) + + +@MODELS.register_module(Tasks.text_to_speech, module_name=r'sambert_hifi_16k') +class SambertNetHifi16k(Model): + + def __init__(self, + model_dir, + pitch_control_str='', + duration_control_str='', + energy_control_str='', + *args, + **kwargs): + tf.reset_default_graph() + local_ckpt_path = os.path.join(ModelFile.TF_CHECKPOINT_FOLDER, 'ckpt') + self._ckpt_path = os.path.join(model_dir, local_ckpt_path) + self._dict_path = os.path.join(model_dir, 'dicts') + self._hparams = tf.contrib.training.HParams(**kwargs) + values = self._hparams.values() + hp = [' {}:{}'.format(name, values[name]) for name in sorted(values)] + print('Hyperparameters:\n' + '\n'.join(hp)) + super().__init__(self._ckpt_path, *args, **kwargs) + model_name = 'robutrans' + self._lfeat_type_list = self._hparams.lfeat_type_list.strip().split( + ',') + sy, tone, syllable_flag, word_segment, emo_category, speaker = load_symbols( + self._dict_path) + self._sy = sy + self._tone = tone + self._syllable_flag = syllable_flag + self._word_segment = word_segment + self._emo_category = emo_category + self._speaker = speaker + self._inputs_dim = dict() + for lfeat_type in self._lfeat_type_list: + if lfeat_type == 'sy': + self._inputs_dim[lfeat_type] = len(sy) + elif lfeat_type == 'tone': + self._inputs_dim[lfeat_type] = len(tone) + elif lfeat_type == 'syllable_flag': + self._inputs_dim[lfeat_type] = len(syllable_flag) + elif lfeat_type == 'word_segment': + self._inputs_dim[lfeat_type] = len(word_segment) + elif lfeat_type == 'emo_category': + self._inputs_dim[lfeat_type] = len(emo_category) + elif lfeat_type == 'speaker': + self._inputs_dim[lfeat_type] = len(speaker) + + self._symbols_dict = SymbolsDict(sy, tone, syllable_flag, word_segment, + emo_category, speaker, + self._inputs_dim, + self._lfeat_type_list) + dim_inputs = sum(self._inputs_dim.values( + )) - self._inputs_dim['speaker'] - self._inputs_dim['emo_category'] + inputs = tf.placeholder(tf.float32, [1, None, dim_inputs], 'inputs') + inputs_emotion = tf.placeholder( + tf.float32, [1, None, self._inputs_dim['emo_category']], + 'inputs_emotion') + inputs_speaker = tf.placeholder(tf.float32, + [1, None, self._inputs_dim['speaker']], + 'inputs_speaker') + + input_lengths = tf.placeholder(tf.int32, [1], 'input_lengths') + pitch_contours_scale = 
tf.placeholder(tf.float32, [1, None], + 'pitch_contours_scale') + energy_contours_scale = tf.placeholder(tf.float32, [1, None], + 'energy_contours_scale') + duration_scale = tf.placeholder(tf.float32, [1, None], + 'duration_scale') + + with tf.variable_scope('model') as _: + self._model = create_model(model_name, self._hparams) + self._model.initialize( + inputs, + inputs_emotion, + inputs_speaker, + input_lengths, + duration_scales=duration_scale, + pitch_scales=pitch_contours_scale, + energy_scales=energy_contours_scale) + self._mel_spec = self._model.mel_outputs[0] + self._duration_outputs = self._model.duration_outputs[0] + self._duration_outputs_ = self._model.duration_outputs_[0] + self._pitch_contour_outputs = self._model.pitch_contour_outputs[0] + self._energy_contour_outputs = self._model.energy_contour_outputs[ + 0] + self._embedded_inputs_emotion = self._model.embedded_inputs_emotion[ + 0] + self._embedding_fsmn_outputs = self._model.embedding_fsmn_outputs[ + 0] + self._encoder_outputs = self._model.encoder_outputs[0] + self._pitch_embeddings = self._model.pitch_embeddings[0] + self._energy_embeddings = self._model.energy_embeddings[0] + self._LR_outputs = self._model.LR_outputs[0] + self._postnet_fsmn_outputs = self._model.postnet_fsmn_outputs[0] + self._attention_h = self._model.attention_h + self._attention_x = self._model.attention_x + + print('Loading checkpoint: %s' % self._ckpt_path) + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + self._session = tf.Session(config=config) + self._session.run(tf.global_variables_initializer()) + + saver = tf.train.Saver() + saver.restore(self._session, self._ckpt_path) + + duration_cfg_lst = [] + if len(duration_control_str) != 0: + for item in duration_control_str.strip().split('|'): + percent, scale = item.lstrip('(').rstrip(')').split(',') + duration_cfg_lst.append((float(percent), float(scale))) + + self._duration_cfg_lst = duration_cfg_lst + + pitch_contours_cfg_lst = [] + if len(pitch_control_str) != 0: + for item in pitch_control_str.strip().split('|'): + percent, scale = item.lstrip('(').rstrip(')').split(',') + pitch_contours_cfg_lst.append( + (float(percent), float(scale))) + + self._pitch_contours_cfg_lst = pitch_contours_cfg_lst + + energy_contours_cfg_lst = [] + if len(energy_control_str) != 0: + for item in energy_control_str.strip().split('|'): + percent, scale = item.lstrip('(').rstrip(')').split(',') + energy_contours_cfg_lst.append( + (float(percent), float(scale))) + + self._energy_contours_cfg_lst = energy_contours_cfg_lst + + def forward(self, text): + cleaner_names = [x.strip() for x in self._hparams.cleaners.split(',')] + + lfeat_symbol = text.strip().split(' ') + lfeat_symbol_separate = [''] * int(len(self._lfeat_type_list)) + for this_lfeat_symbol in lfeat_symbol: + this_lfeat_symbol = this_lfeat_symbol.strip('{').strip('}').split( + '$') + if len(this_lfeat_symbol) != len(self._lfeat_type_list): + raise Exception( + 'Length of this_lfeat_symbol in training data' + + ' is not equal to the length of lfeat_type_list, ' + + str(len(this_lfeat_symbol)) + ' VS. 
' + + str(len(self._lfeat_type_list))) + index = 0 + while index < len(lfeat_symbol_separate): + lfeat_symbol_separate[index] = lfeat_symbol_separate[ + index] + this_lfeat_symbol[index] + ' ' + index = index + 1 + + index = 0 + lfeat_type = self._lfeat_type_list[index] + sequence = self._symbols_dict.symbol_to_sequence( + lfeat_symbol_separate[index].strip(), lfeat_type, cleaner_names) + sequence_array = np.asarray( + sequence[:-1], + dtype=np.int32) # sequence length minus 1 to ignore EOS ~ + inputs = np.eye( + self._inputs_dim[lfeat_type], dtype=np.float32)[sequence_array] + index = index + 1 + while index < len(self._lfeat_type_list) - 2: + lfeat_type = self._lfeat_type_list[index] + sequence = self._symbols_dict.symbol_to_sequence( + lfeat_symbol_separate[index].strip(), lfeat_type, + cleaner_names) + sequence_array = np.asarray( + sequence[:-1], + dtype=np.int32) # sequence length minus 1 to ignore EOS ~ + inputs_temp = np.eye( + self._inputs_dim[lfeat_type], dtype=np.float32)[sequence_array] + inputs = np.concatenate((inputs, inputs_temp), axis=1) + index = index + 1 + seq = inputs + + lfeat_type = 'emo_category' + inputs_emotion = multi_label_symbol_to_sequence( + self._emo_category, lfeat_symbol_separate[index].strip()) + # inputs_emotion = inputs_emotion * 1.5 + index = index + 1 + + lfeat_type = 'speaker' + inputs_speaker = multi_label_symbol_to_sequence( + self._speaker, lfeat_symbol_separate[index].strip()) + + duration_scale = np.ones((len(seq), ), dtype=np.float32) + start_idx = 0 + for (percent, scale) in self._duration_cfg_lst: + duration_scale[start_idx:start_idx + + int(percent * len(seq))] = scale + start_idx += int(percent * len(seq)) + + pitch_contours_scale = np.ones((len(seq), ), dtype=np.float32) + start_idx = 0 + for (percent, scale) in self._pitch_contours_cfg_lst: + pitch_contours_scale[start_idx:start_idx + + int(percent * len(seq))] = scale + start_idx += int(percent * len(seq)) + + energy_contours_scale = np.ones((len(seq), ), dtype=np.float32) + start_idx = 0 + for (percent, scale) in self._energy_contours_cfg_lst: + energy_contours_scale[start_idx:start_idx + + int(percent * len(seq))] = scale + start_idx += int(percent * len(seq)) + + feed_dict = { + self._model.inputs: [np.asarray(seq, dtype=np.float32)], + self._model.inputs_emotion: + [np.asarray(inputs_emotion, dtype=np.float32)], + self._model.inputs_speaker: + [np.asarray(inputs_speaker, dtype=np.float32)], + self._model.input_lengths: + np.asarray([len(seq)], dtype=np.int32), + self._model.duration_scales: [duration_scale], + self._model.pitch_scales: [pitch_contours_scale], + self._model.energy_scales: [energy_contours_scale] + } + + result = self._session.run([ + self._mel_spec, self._duration_outputs, self._duration_outputs_, + self._pitch_contour_outputs, self._embedded_inputs_emotion, + self._embedding_fsmn_outputs, self._encoder_outputs, + self._pitch_embeddings, self._LR_outputs, + self._postnet_fsmn_outputs, self._energy_contour_outputs, + self._energy_embeddings, self._attention_x, self._attention_h + ], feed_dict=feed_dict) # yapf:disable + return result[0] diff --git a/modelscope/models/audio/tts/am/text/__init__.py b/modelscope/models/audio/tts/am/text/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/modelscope/models/audio/tts/am/text/cleaners.py b/modelscope/models/audio/tts/am/text/cleaners.py new file mode 100755 index 00000000..19d838d1 --- /dev/null +++ b/modelscope/models/audio/tts/am/text/cleaners.py @@ -0,0 +1,89 @@ +''' +Cleaners are transformations that 
run over the input text at both training and eval time. + +Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" +hyperparameter. Some cleaners are English-specific. You'll typically want to use: + 1. "english_cleaners" for English text + 2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using + the Unidecode library (https://pypi.python.org/pypi/Unidecode) + 3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update + the symbols in symbols.py to match your data). +''' + +import re + +from unidecode import unidecode + +from .numbers import normalize_numbers + +# Regular expression matching whitespace: +_whitespace_re = re.compile(r'\s+') + +# List of (regular expression, replacement) pairs for abbreviations: +_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) + for x in [ + ('mrs', 'misess'), + ('mr', 'mister'), + ('dr', 'doctor'), + ('st', 'saint'), + ('co', 'company'), + ('jr', 'junior'), + ('maj', 'major'), + ('gen', 'general'), + ('drs', 'doctors'), + ('rev', 'reverend'), + ('lt', 'lieutenant'), + ('hon', 'honorable'), + ('sgt', 'sergeant'), + ('capt', 'captain'), + ('esq', 'esquire'), + ('ltd', 'limited'), + ('col', 'colonel'), + ('ft', 'fort'), ]] # yapf:disable + + +def expand_abbreviations(text): + for regex, replacement in _abbreviations: + text = re.sub(regex, replacement, text) + return text + + +def expand_numbers(text): + return normalize_numbers(text) + + +def lowercase(text): + return text.lower() + + +def collapse_whitespace(text): + return re.sub(_whitespace_re, ' ', text) + + +def convert_to_ascii(text): + return unidecode(text) + + +def basic_cleaners(text): + '''Basic pipeline that lowercases and collapses whitespace without transliteration.''' + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def transliteration_cleaners(text): + '''Pipeline for non-English text that transliterates to ASCII.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = collapse_whitespace(text) + return text + + +def english_cleaners(text): + '''Pipeline for English text, including number and abbreviation expansion.''' + text = convert_to_ascii(text) + text = lowercase(text) + text = expand_numbers(text) + text = expand_abbreviations(text) + text = collapse_whitespace(text) + return text diff --git a/modelscope/models/audio/tts/am/text/cmudict.py b/modelscope/models/audio/tts/am/text/cmudict.py new file mode 100755 index 00000000..b4da4be9 --- /dev/null +++ b/modelscope/models/audio/tts/am/text/cmudict.py @@ -0,0 +1,64 @@ +import re + +valid_symbols = [ + 'AA', 'AA0', 'AA1', 'AA2', 'AE', 'AE0', 'AE1', 'AE2', 'AH', 'AH0', 'AH1', + 'AH2', 'AO', 'AO0', 'AO1', 'AO2', 'AW', 'AW0', 'AW1', 'AW2', 'AY', 'AY0', + 'AY1', 'AY2', 'B', 'CH', 'D', 'DH', 'EH', 'EH0', 'EH1', 'EH2', 'ER', 'ER0', + 'ER1', 'ER2', 'EY', 'EY0', 'EY1', 'EY2', 'F', 'G', 'HH', 'IH', 'IH0', + 'IH1', 'IH2', 'IY', 'IY0', 'IY1', 'IY2', 'JH', 'K', 'L', 'M', 'N', 'NG', + 'OW', 'OW0', 'OW1', 'OW2', 'OY', 'OY0', 'OY1', 'OY2', 'P', 'R', 'S', 'SH', + 'T', 'TH', 'UH', 'UH0', 'UH1', 'UH2', 'UW', 'UW0', 'UW1', 'UW2', 'V', 'W', + 'Y', 'Z', 'ZH' +] + +_valid_symbol_set = set(valid_symbols) + + +class CMUDict: + '''Thin wrapper around CMUDict data. 
http://www.speech.cs.cmu.edu/cgi-bin/cmudict''' + + def __init__(self, file_or_path, keep_ambiguous=True): + if isinstance(file_or_path, str): + with open(file_or_path, encoding='latin-1') as f: + entries = _parse_cmudict(f) + else: + entries = _parse_cmudict(file_or_path) + if not keep_ambiguous: + entries = { + word: pron + for word, pron in entries.items() if len(pron) == 1 + } + self._entries = entries + + def __len__(self): + return len(self._entries) + + def lookup(self, word): + '''Returns list of ARPAbet pronunciations of the given word.''' + return self._entries.get(word.upper()) + + +_alt_re = re.compile(r'\([0-9]+\)') + + +def _parse_cmudict(file): + cmudict = {} + for line in file: + if len(line) and (line[0] >= 'A' and line[0] <= 'Z' or line[0] == "'"): + parts = line.split(' ') + word = re.sub(_alt_re, '', parts[0]) + pronunciation = _get_pronunciation(parts[1]) + if pronunciation: + if word in cmudict: + cmudict[word].append(pronunciation) + else: + cmudict[word] = [pronunciation] + return cmudict + + +def _get_pronunciation(s): + parts = s.strip().split(' ') + for part in parts: + if part not in _valid_symbol_set: + return None + return ' '.join(parts) diff --git a/modelscope/models/audio/tts/am/text/numbers.py b/modelscope/models/audio/tts/am/text/numbers.py new file mode 100755 index 00000000..d9453fee --- /dev/null +++ b/modelscope/models/audio/tts/am/text/numbers.py @@ -0,0 +1,70 @@ +import re + +import inflect + +_inflect = inflect.engine() +_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') +_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') +_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') +_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') +_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') +_number_re = re.compile(r'[0-9]+') + + +def _remove_commas(m): + return m.group(1).replace(',', '') + + +def _expand_decimal_point(m): + return m.group(1).replace('.', ' point ') + + +def _expand_dollars(m): + match = m.group(1) + parts = match.split('.') + if len(parts) > 2: + return match + ' dollars' # Unexpected format + dollars = int(parts[0]) if parts[0] else 0 + cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 + if dollars and cents: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) + elif dollars: + dollar_unit = 'dollar' if dollars == 1 else 'dollars' + return '%s %s' % (dollars, dollar_unit) + elif cents: + cent_unit = 'cent' if cents == 1 else 'cents' + return '%s %s' % (cents, cent_unit) + else: + return 'zero dollars' + + +def _expand_ordinal(m): + return _inflect.number_to_words(m.group(0)) + + +def _expand_number(m): + num = int(m.group(0)) + if num > 1000 and num < 3000: + if num == 2000: + return 'two thousand' + elif num > 2000 and num < 2010: + return 'two thousand ' + _inflect.number_to_words(num % 100) + elif num % 100 == 0: + return _inflect.number_to_words(num // 100) + ' hundred' + else: + return _inflect.number_to_words( + num, andword='', zero='oh', group=2).replace(', ', ' ') + else: + return _inflect.number_to_words(num, andword='') + + +def normalize_numbers(text): + text = re.sub(_comma_number_re, _remove_commas, text) + text = re.sub(_pounds_re, r'\1 pounds', text) + text = re.sub(_dollars_re, _expand_dollars, text) + text = re.sub(_decimal_number_re, _expand_decimal_point, text) + text = re.sub(_ordinal_re, _expand_ordinal, text) + text = re.sub(_number_re, _expand_number, text) + return text diff --git 
a/modelscope/models/audio/tts/am/text/symbols.py b/modelscope/models/audio/tts/am/text/symbols.py new file mode 100644 index 00000000..a7715cca --- /dev/null +++ b/modelscope/models/audio/tts/am/text/symbols.py @@ -0,0 +1,95 @@ +''' +Defines the set of symbols used in text input to the model. + +The default is a set of ASCII characters that works well for English or text that has been run +through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. +''' +import codecs +import os + +_pad = '_' +_eos = '~' +_mask = '@[MASK]' + + +def load_symbols(dict_path): + _characters = '' + _ch_symbols = [] + sy_dict_name = 'sy_dict.txt' + sy_dict_path = os.path.join(dict_path, sy_dict_name) + f = codecs.open(sy_dict_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_symbols.append(line) + + _arpabet = ['@' + s for s in _ch_symbols] + + # Export all symbols: + sy = list(_characters) + _arpabet + [_pad, _eos, _mask] + + _characters = '' + + _ch_tones = [] + tone_dict_name = 'tone_dict.txt' + tone_dict_path = os.path.join(dict_path, tone_dict_name) + f = codecs.open(tone_dict_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_tones.append(line) + + # Export all tones: + tone = list(_characters) + _ch_tones + [_pad, _eos, _mask] + + _characters = '' + + _ch_syllable_flags = [] + syllable_flag_name = 'syllable_flag_dict.txt' + syllable_flag_path = os.path.join(dict_path, syllable_flag_name) + f = codecs.open(syllable_flag_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_syllable_flags.append(line) + + # Export all syllable_flags: + syllable_flag = list(_characters) + _ch_syllable_flags + [ + _pad, _eos, _mask + ] + + _characters = '' + + _ch_word_segments = [] + word_segment_name = 'word_segment_dict.txt' + word_segment_path = os.path.join(dict_path, word_segment_name) + f = codecs.open(word_segment_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_word_segments.append(line) + + # Export all syllable_flags: + word_segment = list(_characters) + _ch_word_segments + [_pad, _eos, _mask] + + _characters = '' + + _ch_emo_types = [] + emo_category_name = 'emo_category_dict.txt' + emo_category_path = os.path.join(dict_path, emo_category_name) + f = codecs.open(emo_category_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_emo_types.append(line) + + emo_category = list(_characters) + _ch_emo_types + [_pad, _eos, _mask] + + _characters = '' + + _ch_speakers = [] + speaker_name = 'speaker_dict.txt' + speaker_path = os.path.join(dict_path, speaker_name) + f = codecs.open(speaker_path, 'r') + for line in f: + line = line.strip('\r\n') + _ch_speakers.append(line) + + # Export all syllable_flags: + speaker = list(_characters) + _ch_speakers + [_pad, _eos, _mask] + return sy, tone, syllable_flag, word_segment, emo_category, speaker diff --git a/modelscope/models/audio/tts/am/text/symbols_dict.py b/modelscope/models/audio/tts/am/text/symbols_dict.py new file mode 100644 index 00000000..e8f7ed19 --- /dev/null +++ b/modelscope/models/audio/tts/am/text/symbols_dict.py @@ -0,0 +1,200 @@ +import re +import sys + +from .cleaners import (basic_cleaners, english_cleaners, + transliteration_cleaners) + + +class SymbolsDict: + + def __init__(self, sy, tone, syllable_flag, word_segment, emo_category, + speaker, inputs_dim, lfeat_type_list): + self._inputs_dim = inputs_dim + self._lfeat_type_list = lfeat_type_list + self._sy_to_id = {s: i for i, s in enumerate(sy)} + self._id_to_sy = {i: s for i, s in enumerate(sy)} + self._tone_to_id = {s: 
i for i, s in enumerate(tone)} + self._id_to_tone = {i: s for i, s in enumerate(tone)} + self._syllable_flag_to_id = {s: i for i, s in enumerate(syllable_flag)} + self._id_to_syllable_flag = {i: s for i, s in enumerate(syllable_flag)} + self._word_segment_to_id = {s: i for i, s in enumerate(word_segment)} + self._id_to_word_segment = {i: s for i, s in enumerate(word_segment)} + self._emo_category_to_id = {s: i for i, s in enumerate(emo_category)} + self._id_to_emo_category = {i: s for i, s in enumerate(emo_category)} + self._speaker_to_id = {s: i for i, s in enumerate(speaker)} + self._id_to_speaker = {i: s for i, s in enumerate(speaker)} + print('_sy_to_id: ') + print(self._sy_to_id) + print('_tone_to_id: ') + print(self._tone_to_id) + print('_syllable_flag_to_id: ') + print(self._syllable_flag_to_id) + print('_word_segment_to_id: ') + print(self._word_segment_to_id) + print('_emo_category_to_id: ') + print(self._emo_category_to_id) + print('_speaker_to_id: ') + print(self._speaker_to_id) + self._curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)') + self._cleaners = { + basic_cleaners.__name__: basic_cleaners, + transliteration_cleaners.__name__: transliteration_cleaners, + english_cleaners.__name__: english_cleaners + } + + def _clean_text(self, text, cleaner_names): + for name in cleaner_names: + cleaner = self._cleaners.get(name) + if not cleaner: + raise Exception('Unknown cleaner: %s' % name) + text = cleaner(text) + return text + + def _sy_to_sequence(self, sy): + return [self._sy_to_id[s] for s in sy if self._should_keep_sy(s)] + + def _arpabet_to_sequence(self, text): + return self._sy_to_sequence(['@' + s for s in text.split()]) + + def _should_keep_sy(self, s): + return s in self._sy_to_id and s != '_' and s != '~' + + def symbol_to_sequence(self, this_lfeat_symbol, lfeat_type, cleaner_names): + sequence = [] + if lfeat_type == 'sy': + this_lfeat_symbol = this_lfeat_symbol.strip().split(' ') + this_lfeat_symbol_format = '' + index = 0 + while index < len(this_lfeat_symbol): + this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[ + index] + '}' + ' ' + index = index + 1 + sequence = self.text_to_sequence(this_lfeat_symbol_format, + cleaner_names) + elif lfeat_type == 'tone': + sequence = self.tone_to_sequence(this_lfeat_symbol) + elif lfeat_type == 'syllable_flag': + sequence = self.syllable_flag_to_sequence(this_lfeat_symbol) + elif lfeat_type == 'word_segment': + sequence = self.word_segment_to_sequence(this_lfeat_symbol) + elif lfeat_type == 'emo_category': + sequence = self.emo_category_to_sequence(this_lfeat_symbol) + elif lfeat_type == 'speaker': + sequence = self.speaker_to_sequence(this_lfeat_symbol) + else: + raise Exception('Unknown lfeat type: %s' % lfeat_type) + + return sequence + + def text_to_sequence(self, text, cleaner_names): + '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. + + The text can optionally have ARPAbet sequences enclosed in curly braces embedded + in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street." 
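+
+        An EOS symbol ``~`` is always appended to the returned sequence.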
+ + Args: + text: string to convert to a sequence + cleaner_names: names of the cleaner functions to run the text through + + Returns: + List of integers corresponding to the symbols in the text + ''' + sequence = [] + + # Check for curly braces and treat their contents as ARPAbet: + while len(text): + m = self._curly_re.match(text) + if not m: + sequence += self._sy_to_sequence( + self._clean_text(text, cleaner_names)) + break + sequence += self._sy_to_sequence( + self._clean_text(m.group(1), cleaner_names)) + sequence += self._arpabet_to_sequence(m.group(2)) + text = m.group(3) + + # Append EOS token + sequence.append(self._sy_to_id['~']) + return sequence + + def tone_to_sequence(self, tone): + tones = tone.strip().split(' ') + sequence = [] + for this_tone in tones: + sequence.append(self._tone_to_id[this_tone]) + sequence.append(self._tone_to_id['~']) + return sequence + + def syllable_flag_to_sequence(self, syllable_flag): + syllable_flags = syllable_flag.strip().split(' ') + sequence = [] + for this_syllable_flag in syllable_flags: + sequence.append(self._syllable_flag_to_id[this_syllable_flag]) + sequence.append(self._syllable_flag_to_id['~']) + return sequence + + def word_segment_to_sequence(self, word_segment): + word_segments = word_segment.strip().split(' ') + sequence = [] + for this_word_segment in word_segments: + sequence.append(self._word_segment_to_id[this_word_segment]) + sequence.append(self._word_segment_to_id['~']) + return sequence + + def emo_category_to_sequence(self, emo_type): + emo_categories = emo_type.strip().split(' ') + sequence = [] + for this_category in emo_categories: + sequence.append(self._emo_category_to_id[this_category]) + sequence.append(self._emo_category_to_id['~']) + return sequence + + def speaker_to_sequence(self, speaker): + speakers = speaker.strip().split(' ') + sequence = [] + for this_speaker in speakers: + sequence.append(self._speaker_to_id[this_speaker]) + sequence.append(self._speaker_to_id['~']) + return sequence + + def sequence_to_symbol(self, sequence): + result = '' + pre_lfeat_dim = 0 + for lfeat_type in self._lfeat_type_list: + current_one_hot_sequence = sequence[:, pre_lfeat_dim:pre_lfeat_dim + + self._inputs_dim[lfeat_type]] + current_sequence = current_one_hot_sequence.argmax(1) + length = current_sequence.shape[0] + + index = 0 + while index < length: + this_sequence = current_sequence[index] + s = '' + if lfeat_type == 'sy': + s = self._id_to_sy[this_sequence] + if len(s) > 1 and s[0] == '@': + s = s[1:] + elif lfeat_type == 'tone': + s = self._id_to_tone[this_sequence] + elif lfeat_type == 'syllable_flag': + s = self._id_to_syllable_flag[this_sequence] + elif lfeat_type == 'word_segment': + s = self._id_to_word_segment[this_sequence] + elif lfeat_type == 'emo_category': + s = self._id_to_emo_category[this_sequence] + elif lfeat_type == 'speaker': + s = self._id_to_speaker[this_sequence] + else: + raise Exception('Unknown lfeat type: %s' % lfeat_type) + + if index == 0: + result = result + lfeat_type + ': ' + + result = result + '{' + s + '}' + + if index == length - 1: + result = result + '; ' + + index = index + 1 + pre_lfeat_dim = pre_lfeat_dim + self._inputs_dim[lfeat_type] + return result diff --git a/modelscope/models/audio/tts/frontend/__init__.py b/modelscope/models/audio/tts/frontend/__init__.py new file mode 100644 index 00000000..d7b1015d --- /dev/null +++ b/modelscope/models/audio/tts/frontend/__init__.py @@ -0,0 +1 @@ +from .generic_text_to_speech_frontend import * # noqa F403 diff --git 
a/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py b/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py new file mode 100644 index 00000000..ed34143f --- /dev/null +++ b/modelscope/models/audio/tts/frontend/generic_text_to_speech_frontend.py @@ -0,0 +1,39 @@ +import os +import zipfile +from typing import Any, Dict, List + +import ttsfrd + +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.audio.tts_exceptions import ( + TtsFrontendInitializeFailedException, + TtsFrontendLanguageTypeInvalidException) +from modelscope.utils.constant import Tasks + +__all__ = ['GenericTtsFrontend'] + + +@MODELS.register_module( + Tasks.text_to_speech, module_name=r'generic_tts_frontend') +class GenericTtsFrontend(Model): + + def __init__(self, model_dir='.', lang_type='pinyin', *args, **kwargs): + super().__init__(model_dir, *args, **kwargs) + frontend = ttsfrd.TtsFrontendEngine() + zip_file = os.path.join(model_dir, 'resource.zip') + self._res_path = os.path.join(model_dir, 'resource') + with zipfile.ZipFile(zip_file, 'r') as zip_ref: + zip_ref.extractall(model_dir) + if not frontend.initialize(self._res_path): + raise TtsFrontendInitializeFailedException( + 'resource invalid: {}'.format(self._res_path)) + if not frontend.set_lang_type(lang_type): + raise TtsFrontendLanguageTypeInvalidException( + 'language type invalid: {}, valid is pinyin and chenmix'. + format(lang_type)) + self._frontend = frontend + + def forward(self, data: str) -> Dict[str, List]: + result = self._frontend.gen_tacotron_symbols(data) + return {'texts': [s for s in result.splitlines() if s != '']} diff --git a/modelscope/models/audio/tts/vocoder/__init__.py b/modelscope/models/audio/tts/vocoder/__init__.py new file mode 100644 index 00000000..94f257f8 --- /dev/null +++ b/modelscope/models/audio/tts/vocoder/__init__.py @@ -0,0 +1 @@ +from .hifigan16k import * # noqa F403 diff --git a/modelscope/models/audio/tts/vocoder/hifigan16k.py b/modelscope/models/audio/tts/vocoder/hifigan16k.py new file mode 100644 index 00000000..0d917dbe --- /dev/null +++ b/modelscope/models/audio/tts/vocoder/hifigan16k.py @@ -0,0 +1,73 @@ +from __future__ import (absolute_import, division, print_function, + unicode_literals) +import argparse +import glob +import os +import time + +import json +import numpy as np +import torch +from scipy.io.wavfile import write + +from modelscope.models.base import Model +from modelscope.models.builder import MODELS +from modelscope.utils.audio.tts_exceptions import \ + TtsVocoderMelspecShapeMismatchException +from modelscope.utils.constant import ModelFile, Tasks +from .models import Generator + +__all__ = ['Hifigan16k', 'AttrDict'] +MAX_WAV_VALUE = 32768.0 + + +def load_checkpoint(filepath, device): + assert os.path.isfile(filepath) + print("Loading '{}'".format(filepath)) + checkpoint_dict = torch.load(filepath, map_location=device) + print('Complete.') + return checkpoint_dict + + +class AttrDict(dict): + + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + +@MODELS.register_module(Tasks.text_to_speech, module_name=r'hifigan16k') +class Hifigan16k(Model): + + def __init__(self, model_dir, *args, **kwargs): + self._ckpt_path = os.path.join(model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + self._config = AttrDict(**kwargs) + + super().__init__(self._ckpt_path, *args, **kwargs) + if torch.cuda.is_available(): + torch.manual_seed(self._config.seed) + self._device = 
torch.device('cuda') + else: + self._device = torch.device('cpu') + self._generator = Generator(self._config).to(self._device) + state_dict_g = load_checkpoint(self._ckpt_path, self._device) + self._generator.load_state_dict(state_dict_g['generator']) + self._generator.eval() + self._generator.remove_weight_norm() + + def forward(self, melspec): + dim0 = list(melspec.shape)[-1] + if dim0 != 80: + raise TtsVocoderMelspecShapeMismatchException( + 'input melspec mismatch 0 dim require 80 but {}'.format(dim0)) + with torch.no_grad(): + x = melspec.T + x = torch.FloatTensor(x).to(self._device) + if len(x.shape) == 2: + x = x.unsqueeze(0) + y_g_hat = self._generator(x) + audio = y_g_hat.squeeze() + audio = audio * MAX_WAV_VALUE + audio = audio.cpu().numpy().astype('int16') + return audio diff --git a/modelscope/models/audio/tts/vocoder/models/__init__.py b/modelscope/models/audio/tts/vocoder/models/__init__.py new file mode 100644 index 00000000..b00eec9b --- /dev/null +++ b/modelscope/models/audio/tts/vocoder/models/__init__.py @@ -0,0 +1 @@ +from .models import Generator diff --git a/modelscope/models/audio/tts/vocoder/models/models.py b/modelscope/models/audio/tts/vocoder/models/models.py new file mode 100755 index 00000000..83fc7dc2 --- /dev/null +++ b/modelscope/models/audio/tts/vocoder/models/models.py @@ -0,0 +1,516 @@ +from distutils.version import LooseVersion + +import torch +import torch.nn as nn +import torch.nn.functional as F +from pytorch_wavelets import DWT1DForward +from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d +from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm + +from .utils import get_padding, init_weights + +is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7') + + +def stft(x, fft_size, hop_size, win_length, window): + """Perform STFT and convert to magnitude spectrogram. + + Args: + x (Tensor): Input signal tensor (B, T). + fft_size (int): FFT size. + hop_size (int): Hop size. + win_length (int): Window length. + window (str): Window function type. + + Returns: + Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1). + + """ + if is_pytorch_17plus: + x_stft = torch.stft( + x, fft_size, hop_size, win_length, window, return_complex=False) + else: + x_stft = torch.stft(x, fft_size, hop_size, win_length, window) + real = x_stft[..., 0] + imag = x_stft[..., 1] + + # NOTE(kan-bayashi): clamp is needed to avoid nan or inf + return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1) + + +LRELU_SLOPE = 0.1 + + +def get_padding_casual(kernel_size, dilation=1): + return int(kernel_size * dilation - dilation) + + +class Conv1dCasual(torch.nn.Module): + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + padding_mode='zeros'): + super(Conv1dCasual, self).__init__() + self.pad = padding + self.conv1d = weight_norm( + Conv1d( + in_channels, + out_channels, + kernel_size, + stride, + padding=0, + dilation=dilation, + groups=groups, + bias=bias, + padding_mode=padding_mode)) + self.conv1d.apply(init_weights) + + def forward(self, x): # bdt + # described starting from the last dimension and moving forward. 
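+        # Causal (left-only) padding: e.g. kernel_size=7, dilation=1 gives
+        # pad = 7*1 - 1 = 6 zero frames prepended along time, so the
+        # convolution never looks at future frames.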
+ x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant') + x = self.conv1d(x) + return x + + def remove_weight_norm(self): + remove_weight_norm(self.conv1d) + + +class ConvTranspose1dCausal(torch.nn.Module): + """CausalConvTranspose1d module with customized initialization.""" + + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride, + padding=0): + """Initialize CausalConvTranspose1d module.""" + super(ConvTranspose1dCausal, self).__init__() + self.deconv = weight_norm( + ConvTranspose1d(in_channels, out_channels, kernel_size, stride)) + self.stride = stride + self.deconv.apply(init_weights) + self.pad = kernel_size - stride + + def forward(self, x): + """Calculate forward propagation. + Args: + x (Tensor): Input tensor (B, in_channels, T_in). + Returns: + Tensor: Output tensor (B, out_channels, T_out). + """ + # x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant") + return self.deconv(x)[:, :, :-self.pad] + + def remove_weight_norm(self): + remove_weight_norm(self.deconv) + + +class ResBlock1(torch.nn.Module): + + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock1, self).__init__() + self.h = h + self.convs1 = nn.ModuleList([ + Conv1dCasual( + channels, + channels, + kernel_size, + 1, + dilation=dilation[i], + padding=get_padding_casual(kernel_size, dilation[i])) + for i in range(len(dilation)) + ]) + + self.convs2 = nn.ModuleList([ + Conv1dCasual( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding_casual(kernel_size, 1)) + for i in range(len(dilation)) + ]) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for layer in self.convs1: + layer.remove_weight_norm() + for layer in self.convs2: + layer.remove_weight_norm() + + +class Generator(torch.nn.Module): + + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + print('num_kernels={}, num_upsamples={}'.format( + self.num_kernels, self.num_upsamples)) + self.conv_pre = Conv1dCasual( + 80, h.upsample_initial_channel, 7, 1, padding=7 - 1) + resblock = ResBlock1 if h.resblock == '1' else ResBlock2 + + self.ups = nn.ModuleList() + self.repeat_ups = nn.ModuleList() + for i, (u, k) in enumerate( + zip(h.upsample_rates, h.upsample_kernel_sizes)): + upsample = nn.Sequential( + nn.Upsample(mode='nearest', scale_factor=u), + nn.LeakyReLU(LRELU_SLOPE), + Conv1dCasual( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + kernel_size=7, + stride=1, + padding=7 - 1)) + self.repeat_ups.append(upsample) + self.ups.append( + ConvTranspose1dCausal( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2**(i + 1)), + k, + u, + padding=(k - u) // 2)) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2**(i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = Conv1dCasual(ch, 1, 7, 1, padding=7 - 1) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = torch.sin(x) + x + # transconv + x1 = F.leaky_relu(x, LRELU_SLOPE) + x1 = self.ups[i](x1) + # repeat + x2 = self.repeat_ups[i](x) + x = x1 + x2 + xs = None + for j in 
range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = torch.tanh(x) + return x + + def remove_weight_norm(self): + print('Removing weight norm...') + for layer in self.ups: + layer.remove_weight_norm() + for layer in self.repeat_ups: + layer[-1].remove_weight_norm() + for layer in self.resblocks: + layer.remove_weight_norm() + self.conv_pre.remove_weight_norm() + self.conv_post.remove_weight_norm() + + +class DiscriminatorP(torch.nn.Module): + + def __init__(self, + period, + kernel_size=5, + stride=3, + use_spectral_norm=False): + super(DiscriminatorP, self).__init__() + self.period = period + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f( + Conv2d( + 1, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 32, + 128, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 128, + 512, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f( + Conv2d( + 512, + 1024, (kernel_size, 1), (stride, 1), + padding=(get_padding(5, 1), 0))), + norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0))) + + def forward(self, x): + fmap = [] + + # 1d to 2d + b, c, t = x.shape + if t % self.period != 0: # pad first + n_pad = self.period - (t % self.period) + x = F.pad(x, (0, n_pad), 'reflect') + t = t + n_pad + x = x.view(b, c, t // self.period, self.period) + + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiPeriodDiscriminator(torch.nn.Module): + + def __init__(self): + super(MultiPeriodDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + DiscriminatorP(2), + DiscriminatorP(3), + DiscriminatorP(5), + DiscriminatorP(7), + DiscriminatorP(11), + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorS(torch.nn.Module): + + def __init__(self, use_spectral_norm=False): + super(DiscriminatorS, self).__init__() + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f(Conv1d(1, 128, 15, 1, padding=7)), + norm_f(Conv1d(128, 128, 41, 2, groups=4, padding=20)), + norm_f(Conv1d(128, 256, 41, 2, groups=16, padding=20)), + norm_f(Conv1d(256, 512, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(512, 1024, 41, 4, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 41, 1, groups=16, padding=20)), + norm_f(Conv1d(1024, 1024, 5, 1, padding=2)), + ]) + self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1)) + + def forward(self, x): + fmap = [] + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = torch.flatten(x, 1, -1) + + return x, fmap + + +class MultiScaleDiscriminator(torch.nn.Module): + + def __init__(self): + super(MultiScaleDiscriminator, self).__init__() + self.discriminators = nn.ModuleList([ + 
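+            # Three waveform discriminators: the first (spectral-normalized)
+            # sees the raw waveform; the other two see audio progressively
+            # downsampled in forward() via DWT1DForward plus a 2-to-1 channel
+            # Conv1d (stored in self.meanpools / self.convs below).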
DiscriminatorS(use_spectral_norm=True), + DiscriminatorS(), + DiscriminatorS(), + ]) + self.meanpools = nn.ModuleList( + [DWT1DForward(wave='db3', J=1), + DWT1DForward(wave='db3', J=1)]) + self.convs = nn.ModuleList([ + weight_norm(Conv1d(2, 1, 15, 1, padding=7)), + weight_norm(Conv1d(2, 1, 15, 1, padding=7)) + ]) + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + if i != 0: + yl, yh = self.meanpools[i - 1](y) + y = torch.cat([yl, yh[0]], dim=1) + y = self.convs[i - 1](y) + y = F.leaky_relu(y, LRELU_SLOPE) + + yl_hat, yh_hat = self.meanpools[i - 1](y_hat) + y_hat = torch.cat([yl_hat, yh_hat[0]], dim=1) + y_hat = self.convs[i - 1](y_hat) + y_hat = F.leaky_relu(y_hat, LRELU_SLOPE) + + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +class DiscriminatorSTFT(torch.nn.Module): + + def __init__(self, + kernel_size=11, + stride=2, + use_spectral_norm=False, + fft_size=1024, + shift_size=120, + win_length=600, + window='hann_window'): + super(DiscriminatorSTFT, self).__init__() + self.fft_size = fft_size + self.shift_size = shift_size + self.win_length = win_length + norm_f = weight_norm if use_spectral_norm is False else spectral_norm + self.convs = nn.ModuleList([ + norm_f( + Conv2d( + fft_size // 2 + 1, + 32, (15, 1), (1, 1), + padding=(get_padding(15, 1), 0))), + norm_f( + Conv2d( + 32, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(9, 1), 0))), + norm_f( + Conv2d( + 32, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(9, 1), 0))), + norm_f( + Conv2d( + 32, + 32, (kernel_size, 1), (stride, 1), + padding=(get_padding(9, 1), 0))), + norm_f(Conv2d(32, 32, (5, 1), (1, 1), padding=(2, 0))), + ]) + self.conv_post = norm_f(Conv2d(32, 1, (3, 1), (1, 1), padding=(1, 0))) + self.register_buffer('window', getattr(torch, window)(win_length)) + + def forward(self, wav): + wav = torch.squeeze(wav, 1) + x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length, + self.window) + x = torch.transpose(x_mag, 2, 1).unsqueeze(-1) + fmap = [] + for layer in self.convs: + x = layer(x) + x = F.leaky_relu(x, LRELU_SLOPE) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + x = x.squeeze(-1) + + return x, fmap + + +class MultiSTFTDiscriminator(torch.nn.Module): + + def __init__( + self, + fft_sizes=[1024, 2048, 512], + hop_sizes=[120, 240, 50], + win_lengths=[600, 1200, 240], + window='hann_window', + ): + super(MultiSTFTDiscriminator, self).__init__() + self.discriminators = nn.ModuleList() + for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths): + self.discriminators += [ + DiscriminatorSTFT(fft_size=fs, shift_size=ss, win_length=wl) + ] + + def forward(self, y, y_hat): + y_d_rs = [] + y_d_gs = [] + fmap_rs = [] + fmap_gs = [] + for i, d in enumerate(self.discriminators): + y_d_r, fmap_r = d(y) + y_d_g, fmap_g = d(y_hat) + y_d_rs.append(y_d_r) + fmap_rs.append(fmap_r) + y_d_gs.append(y_d_g) + fmap_gs.append(fmap_g) + + return y_d_rs, y_d_gs, fmap_rs, fmap_gs + + +def feature_loss(fmap_r, fmap_g): + loss = 0 + for dr, dg in zip(fmap_r, fmap_g): + for rl, gl in zip(dr, dg): + loss += torch.mean(torch.abs(rl - gl)) + + return loss * 2 + + +def discriminator_loss(disc_real_outputs, disc_generated_outputs): + loss = 0 + r_losses = [] + g_losses = [] + for dr, dg in zip(disc_real_outputs, disc_generated_outputs): + r_loss = torch.mean((1 - dr)**2) + 
        g_loss = torch.mean(dg**2)
+        loss += (r_loss + g_loss)
+        r_losses.append(r_loss.item())
+        g_losses.append(g_loss.item())
+
+    return loss, r_losses, g_losses
+
+
+def generator_loss(disc_outputs):
+    loss = 0
+    gen_losses = []
+    for dg in disc_outputs:
+        temp_loss = torch.mean((1 - dg)**2)
+        gen_losses.append(temp_loss)
+        loss += temp_loss
+
+    return loss, gen_losses
diff --git a/modelscope/models/audio/tts/vocoder/models/utils.py b/modelscope/models/audio/tts/vocoder/models/utils.py
new file mode 100755
index 00000000..03e1ef8c
--- /dev/null
+++ b/modelscope/models/audio/tts/vocoder/models/utils.py
@@ -0,0 +1,59 @@
+import glob
+import os
+
+import matplotlib
+import matplotlib.pylab as plt
+import torch
+from torch.nn.utils import weight_norm
+
+matplotlib.use('Agg')
+
+
+def plot_spectrogram(spectrogram):
+    fig, ax = plt.subplots(figsize=(10, 2))
+    im = ax.imshow(
+        spectrogram, aspect='auto', origin='lower', interpolation='none')
+    plt.colorbar(im, ax=ax)
+
+    fig.canvas.draw()
+    plt.close()
+
+    return fig
+
+
+def init_weights(m, mean=0.0, std=0.01):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        m.weight.data.normal_(mean, std)
+
+
+def apply_weight_norm(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        weight_norm(m)
+
+
+def get_padding(kernel_size, dilation=1):
+    return int((kernel_size * dilation - dilation) / 2)
+
+
+def load_checkpoint(filepath, device):
+    assert os.path.isfile(filepath)
+    print("Loading '{}'".format(filepath))
+    checkpoint_dict = torch.load(filepath, map_location=device)
+    print('Complete.')
+    return checkpoint_dict
+
+
+def save_checkpoint(filepath, obj):
+    print('Saving checkpoint to {}'.format(filepath))
+    torch.save(obj, filepath)
+    print('Complete.')
+
+
+def scan_checkpoint(cp_dir, prefix):
+    pattern = os.path.join(cp_dir, prefix + '????????')
+    cp_list = glob.glob(pattern)
+    if len(cp_list) == 0:
+        return None
+    return sorted(cp_list)[-1]
diff --git a/modelscope/models/base.py b/modelscope/models/base.py
index 88b1e3b0..ab0d22cc 100644
--- a/modelscope/models/base.py
+++ b/modelscope/models/base.py
@@ -62,4 +62,6 @@ class Model(ABC):
         if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
             model_cfg.type = model_cfg.model_type
         model_cfg.model_dir = local_model_dir
+        for k, v in kwargs.items():
+            setattr(model_cfg, k, v)
         return build_model(model_cfg, task_name)
diff --git a/modelscope/pipelines/audio/__init__.py b/modelscope/pipelines/audio/__init__.py
index eaa31c7c..20c7710a 100644
--- a/modelscope/pipelines/audio/__init__.py
+++ b/modelscope/pipelines/audio/__init__.py
@@ -1 +1,2 @@
 from .linear_aec_pipeline import LinearAECPipeline
+from .text_to_speech_pipeline import *  # noqa F403
diff --git a/modelscope/pipelines/audio/text_to_speech_pipeline.py b/modelscope/pipelines/audio/text_to_speech_pipeline.py
new file mode 100644
index 00000000..ecd9daac
--- /dev/null
+++ b/modelscope/pipelines/audio/text_to_speech_pipeline.py
@@ -0,0 +1,46 @@
+import time
+from typing import Any, Dict, List
+
+import numpy as np
+
+from modelscope.models import Model
+from modelscope.models.audio.tts.am import SambertNetHifi16k
+from modelscope.models.audio.tts.vocoder import Hifigan16k
+from modelscope.pipelines.base import Pipeline
+from modelscope.pipelines.builder import PIPELINES
+from modelscope.preprocessors import TextToTacotronSymbols, build_preprocessor
+from modelscope.utils.constant import Fields, Tasks
+
+__all__ = ['TextToSpeechSambertHifigan16kPipeline']
+
+
+@PIPELINES.register_module(
Tasks.text_to_speech, module_name=r'tts-sambert-hifigan-16k') +class TextToSpeechSambertHifigan16kPipeline(Pipeline): + + def __init__(self, + config_file: str = None, + model: List[Model] = None, + preprocessor: TextToTacotronSymbols = None, + **kwargs): + super().__init__( + config_file=config_file, + model=model, + preprocessor=preprocessor, + **kwargs) + assert len(model) == 2, 'model number should be 2' + self._am = model[0] + self._vocoder = model[1] + self._preprocessor = preprocessor + + def forward(self, inputs: Dict[str, Any]) -> Dict[str, np.ndarray]: + texts = inputs['texts'] + audio_total = np.empty((0), dtype='int16') + for line in texts: + line = line.strip().split('\t') + audio = self._vocoder.forward(self._am.forward(line[1])) + audio_total = np.append(audio_total, audio, axis=0) + return {'output': audio_total} + + def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]: + return inputs diff --git a/modelscope/preprocessors/__init__.py b/modelscope/preprocessors/__init__.py index 5db5b407..e3ae4c40 100644 --- a/modelscope/preprocessors/__init__.py +++ b/modelscope/preprocessors/__init__.py @@ -6,3 +6,4 @@ from .builder import PREPROCESSORS, build_preprocessor from .common import Compose from .image import LoadImage, load_image from .nlp import * # noqa F403 +from .text_to_speech import * # noqa F403 diff --git a/modelscope/preprocessors/text_to_speech.py b/modelscope/preprocessors/text_to_speech.py new file mode 100644 index 00000000..fd41b752 --- /dev/null +++ b/modelscope/preprocessors/text_to_speech.py @@ -0,0 +1,53 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import io +from typing import Any, Dict, Union + +import ttsfrd + +from modelscope.fileio import File +from modelscope.models.audio.tts.frontend import GenericTtsFrontend +from modelscope.models.base import Model +from modelscope.utils.audio.tts_exceptions import * # noqa F403 +from modelscope.utils.constant import Fields +from .base import Preprocessor +from .builder import PREPROCESSORS + +__all__ = ['TextToTacotronSymbols', 'text_to_tacotron_symbols'] + + +@PREPROCESSORS.register_module( + Fields.audio, module_name=r'text_to_tacotron_symbols') +class TextToTacotronSymbols(Preprocessor): + """extract tacotron symbols from text. + + Args: + res_path (str): TTS frontend resource url + lang_type (str): language type, valid values are "pinyin" and "chenmix" + """ + + def __init__(self, model_name, lang_type='pinyin'): + self._frontend_model = Model.from_pretrained( + model_name, lang_type=lang_type) + assert self._frontend_model is not None, 'load model from pretained failed' + + def __call__(self, data: str) -> Dict[str, Any]: + """Call functions to load text and get tacotron symbols. + + Args: + input (str): text with utf-8 + Returns: + symbos (list[str]): texts in tacotron symbols format. 
+ """ + return self._frontend_model.forward(data) + + +def text_to_tacotron_symbols(text='', path='./', lang='pinyin'): + """ simple interface to transform text to tacotron symbols + + Args: + text (str): input text + path (str): resource path + lang (str): language type from one of "pinyin" and "chenmix" + """ + transform = TextToTacotronSymbols(path, lang) + return transform(text) diff --git a/modelscope/utils/audio/__init__.py b/modelscope/utils/audio/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/modelscope/utils/audio/tts_exceptions.py b/modelscope/utils/audio/tts_exceptions.py new file mode 100644 index 00000000..1ca731c3 --- /dev/null +++ b/modelscope/utils/audio/tts_exceptions.py @@ -0,0 +1,42 @@ +""" +Define TTS exceptions +""" + + +class TtsException(Exception): + """ + TTS exception class. + """ + pass + + +class TtsFrontendException(TtsException): + """ + TTS frontend module level exceptions. + """ + pass + + +class TtsFrontendInitializeFailedException(TtsFrontendException): + """ + If tts frontend resource is invalid or not exist, this exception will be raised. + """ + pass + + +class TtsFrontendLanguageTypeInvalidException(TtsFrontendException): + """ + If language type is invalid, this exception will be raised. + """ + + +class TtsVocoderException(TtsException): + """ + Vocoder exception + """ + + +class TtsVocoderMelspecShapeMismatchException(TtsVocoderException): + """ + If vocoder's input melspec shape mismatch, this exception will be raised. + """ diff --git a/modelscope/utils/registry.py b/modelscope/utils/registry.py index 319e54cb..b26b899d 100644 --- a/modelscope/utils/registry.py +++ b/modelscope/utils/registry.py @@ -67,7 +67,6 @@ class Registry(object): if module_name in self._modules[group_key]: raise KeyError(f'{module_name} is already registered in ' f'{self._name}[{group_key}]') - self._modules[group_key][module_name] = module_cls module_cls.group_key = group_key diff --git a/requirements.txt b/requirements.txt index 39eb5e23..b9b4a1c4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,4 +2,5 @@ -r requirements/pipeline.txt -r requirements/multi-modal.txt -r requirements/nlp.txt +-r requirements/audio.txt -r requirements/cv.txt diff --git a/requirements/audio.txt b/requirements/audio.txt new file mode 100644 index 00000000..140836a8 --- /dev/null +++ b/requirements/audio.txt @@ -0,0 +1,26 @@ +#tts +h5py==2.10.0 +#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp36-cp36m-linux_x86_64.whl +https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp37-cp37m-linux_x86_64.whl +https://swap.oss-cn-hangzhou.aliyuncs.com/Jiaqi%2Fmaas%2Ftts%2Frequirements%2Fpytorch_wavelets-1.3.0-py3-none-any.whl?Expires=1685688388&OSSAccessKeyId=LTAI4Ffebq4d9jTVDwiSbY4L&Signature=jcQbg5EZ%2Bdys3%2F4BRn3srrKLdIg%3D +#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp38-cp38-linux_x86_64.whl +#https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/TTS/requirements/ttsfrd-0.0.1-cp39-cp39-linux_x86_64.whl +inflect +keras==2.2.4 +librosa +lxml +matplotlib +nara_wpe +numpy==1.18.* +protobuf==3.20.* +ptflops +PyWavelets>=1.0.0 +scikit-learn==0.23.2 +sox +tensorboard +tensorflow==1.15.* +torch==1.10.* +torchaudio +torchvision +tqdm +unidecode diff --git a/tests/pipelines/test_text_to_speech.py b/tests/pipelines/test_text_to_speech.py new file mode 100644 index 00000000..c9b988a1 --- /dev/null +++ b/tests/pipelines/test_text_to_speech.py @@ -0,0 +1,60 @@ 
+import time +import unittest + +import json +import tensorflow as tf +# NOTICE: Tensorflow 1.15 seems not so compatible with pytorch. +# A segmentation fault may be raise by pytorch cpp library +# if 'import tensorflow' in front of 'import torch'. +# Puting a 'import torch' here can bypass this incompatibility. +import torch +from scipy.io.wavfile import write + +from modelscope.fileio import File +from modelscope.models import Model, build_model +from modelscope.models.audio.tts.am import SambertNetHifi16k +from modelscope.models.audio.tts.vocoder import AttrDict, Hifigan16k +from modelscope.pipelines import pipeline +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.constant import Fields, InputFields, Tasks +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase): + + def test_pipeline(self): + lang_type = 'pinyin' + text = '明天天气怎么样' + preprocessor_model_id = 'damo/speech_binary_tts_frontend_resource' + am_model_id = 'damo/speech_sambert16k_tts_zhitian_emo' + voc_model_id = 'damo/speech_hifigan16k_tts_zhitian_emo' + + cfg_preprocessor = dict( + type='text_to_tacotron_symbols', + model_name=preprocessor_model_id, + lang_type=lang_type) + preprocessor = build_preprocessor(cfg_preprocessor, Fields.audio) + self.assertTrue(preprocessor is not None) + + am = Model.from_pretrained(am_model_id) + self.assertTrue(am is not None) + + voc = Model.from_pretrained(voc_model_id) + self.assertTrue(voc is not None) + + sambert_tts = pipeline( + pipeline_name='tts-sambert-hifigan-16k', + config_file='', + model=[am, voc], + preprocessor=preprocessor) + self.assertTrue(sambert_tts is not None) + + output = sambert_tts(text) + self.assertTrue(len(output['output']) > 0) + write('output.wav', 16000, output['output']) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/preprocessors/test_text_to_speech.py b/tests/preprocessors/test_text_to_speech.py new file mode 100644 index 00000000..18b66987 --- /dev/null +++ b/tests/preprocessors/test_text_to_speech.py @@ -0,0 +1,28 @@ +import shutil +import unittest + +from modelscope.preprocessors import build_preprocessor +from modelscope.utils.constant import Fields, InputFields +from modelscope.utils.logger import get_logger + +logger = get_logger() + + +class TtsPreprocessorTest(unittest.TestCase): + + def test_preprocess(self): + lang_type = 'pinyin' + text = '今天天气不错,我们去散步吧。' + cfg = dict( + type='text_to_tacotron_symbols', + model_name='damo/speech_binary_tts_frontend_resource', + lang_type=lang_type) + preprocessor = build_preprocessor(cfg, Fields.audio) + output = preprocessor(text) + self.assertTrue(output) + for line in output['texts']: + print(line) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/run.py b/tests/run.py index 9f5d62a7..a904ba8e 100644 --- a/tests/run.py +++ b/tests/run.py @@ -7,6 +7,12 @@ import sys import unittest from fnmatch import fnmatch +# NOTICE: Tensorflow 1.15 seems not so compatible with pytorch. +# A segmentation fault may be raise by pytorch cpp library +# if 'import tensorflow' in front of 'import torch'. +# Puting a 'import torch' here can bypass this incompatibility. +import torch + from modelscope.utils.logger import get_logger from modelscope.utils.test_utils import set_test_level, test_level
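
A minimal end-to-end usage sketch of the `tts-sambert-hifigan-16k` pipeline added in this diff, condensed from `tests/pipelines/test_text_to_speech.py`; the `damo/...` model IDs are the ones referenced in that test and are assumed to be available on the model hub.

```python
# Sketch only: mirrors tests/pipelines/test_text_to_speech.py; model IDs are
# the ones used in that test and are assumed to exist on the model hub.
import torch  # import torch early to avoid the TF/torch import-order issue noted in tests/run.py
from scipy.io.wavfile import write

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.preprocessors import build_preprocessor
from modelscope.utils.constant import Fields

# Text frontend: converts raw text into tacotron symbol sequences.
preprocessor = build_preprocessor(
    dict(
        type='text_to_tacotron_symbols',
        model_name='damo/speech_binary_tts_frontend_resource',
        lang_type='pinyin'), Fields.audio)

# Sambert acoustic model and HiFi-GAN vocoder, both 16 kHz.
am = Model.from_pretrained('damo/speech_sambert16k_tts_zhitian_emo')
vocoder = Model.from_pretrained('damo/speech_hifigan16k_tts_zhitian_emo')

tts = pipeline(
    pipeline_name='tts-sambert-hifigan-16k',
    config_file='',
    model=[am, vocoder],
    preprocessor=preprocessor)

audio = tts('明天天气怎么样')['output']  # int16 waveform samples
write('output.wav', 16000, audio)  # 16 kHz mono WAV
```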