diff --git a/TTS/tts/layers/vits/text_encoder.py b/TTS/tts/layers/vits/text_encoder.py
index 20d936cc..ab9a0739 100644
--- a/TTS/tts/layers/vits/text_encoder.py
+++ b/TTS/tts/layers/vits/text_encoder.py
@@ -6,8 +6,7 @@ from torch.nn import functional as F
 from TTS.tts.utils.helpers import sequence_mask
 from TTS.tts.layers.generic.normalization import LayerNorm, LayerNorm2

-# import sys
-# sys.setrecursionlimit(9999999)
+

 class AdaptiveWeightConv(nn.Module):
     def __init__(self, conv_module, in_channels, out_channels, kernel_size, r=0, alpha=1, dropout=0., num_classes=None, **kwargs):
         super(AdaptiveWeightConv, self).__init__()
@@ -558,7 +557,7 @@ class TextEncoder(nn.Module):
         super().__init__()
         self.out_channels = out_channels
         self.hidden_channels = hidden_channels
-
+        self.num_adaptive_weight_classes = num_adaptive_weight_classes
         self.emb = nn.Embedding(n_vocab, hidden_channels)

         nn.init.normal_(self.emb.weight, 0.0, hidden_channels**-0.5)
@@ -582,12 +581,7 @@ class TextEncoder(nn.Module):

         self.proj = Conv1d(hidden_channels, out_channels * 2, 1, r=1 if num_adaptive_weight_classes else 0, num_classes=num_adaptive_weight_classes)

-    def forward(self, x, x_lengths, lang_emb=None, class_id=None):
-        """
-        Shapes:
-            - x: :math:`[B, T]`
-            - x_length: :math:`[B]`
-        """
+    def forward_mini_batch(self, x, x_lengths, lang_emb=None, class_id=None):
         assert x.shape[0] == x_lengths.shape[0]

         x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
@@ -604,6 +598,41 @@ class TextEncoder(nn.Module):
         m, logs = torch.split(stats, self.out_channels, dim=1)
         return x, m, logs, x_mask

+    def forward(self, x, x_lengths, lang_emb=None, class_id=None):
+        """
+        Shapes:
+            - x: :math:`[B, T]`
+            - x_length: :math:`[B]`
+        """
+        batch_size = x.size(0)
+        if self.num_adaptive_weight_classes and batch_size > 1:
+            num_utter_per_class = int(batch_size / self.num_adaptive_weight_classes)
+            # mini-batch inference for each class
+            outs_x = []
+            outs_m = []
+            outs_logs = []
+            outs_x_mask = []
+
+            start = 0
+            for i in range(self.num_adaptive_weight_classes):
+                start = num_utter_per_class * i
+                end = start + num_utter_per_class
+                class_id_item = class_id[start:end][0]
+                x_out, m_out, logs_out, x_mask_out = self.forward_mini_batch(x[start:end], x_lengths[start:end], lang_emb=lang_emb[start:end] if lang_emb is not None else None, class_id=class_id_item)
+                outs_x.append(x_out)
+                outs_m.append(m_out)
+                outs_logs.append(logs_out)
+                outs_x_mask.append(x_mask_out)
+
+            x = torch.stack(outs_x, dim=0).view(batch_size, *x_out.shape[1:])
+            m = torch.stack(outs_m, dim=0).view(batch_size, *m_out.shape[1:])
+            logs = torch.stack(outs_logs, dim=0).view(batch_size, *logs_out.shape[1:])
+            x_mask = torch.stack(outs_x_mask, dim=0).view(batch_size, *x_mask_out.shape[1:])
+            return x, m, logs, x_mask
+        else:
+            return self.forward_mini_batch(x, x_lengths, lang_emb=lang_emb, class_id=class_id)
+
+
 if __name__ == '__main__':
     txt_enc = TextEncoder(
         n_vocab=100,
@@ -642,7 +671,7 @@ if __name__ == '__main__':
         kernel_size=3,
         dropout_p=0.0,
         language_emb_dim=None,
-        num_adaptive_weight_classes=5,
+        num_adaptive_weight_classes=None,
     )

     out = txt_enc(
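Note on the new TextEncoder.forward above: it assumes the batch arrives with utterances of the same class grouped contiguously and split evenly across classes, and it dispatches one forward_mini_batch call per class before re-assembling the outputs. A minimal, self-contained sketch of that dispatch pattern (the helper forward_per_class and the toy fn are illustrative only, not part of the patch):

import torch

def forward_per_class(x, class_id, num_classes, fn):
    # x: [B, ...] with utterances grouped contiguously by class, e.g. class ids [0, 0, 1, 1, 2, 2];
    # fn(chunk, c) stands in for running the class-specific weights on one contiguous chunk.
    batch_size = x.size(0)
    per_class = batch_size // num_classes  # assumes B is an exact multiple of num_classes
    outs = []
    for i in range(num_classes):
        start, end = i * per_class, (i + 1) * per_class
        outs.append(fn(x[start:end], int(class_id[start])))
    # re-assemble the per-class chunks back into a [B, ...] tensor, as the patch does with torch.stack().view()
    return torch.stack(outs, dim=0).view(batch_size, *outs[0].shape[1:])

# toy usage
x = torch.randn(6, 4)
class_id = torch.tensor([0, 0, 1, 1, 2, 2])
out = forward_per_class(x, class_id, num_classes=3, fn=lambda chunk, c: chunk * (c + 1))
assert out.shape == x.shape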
diff --git a/TTS/tts/models/vits.py b/TTS/tts/models/vits.py
index 7c8bb97b..e06e7fc2 100644
--- a/TTS/tts/models/vits.py
+++ b/TTS/tts/models/vits.py
@@ -35,7 +35,7 @@ from TTS.tts.utils.text.characters import BaseCharacters, BaseVocabulary, _characters
 from TTS.tts.utils.text.tokenizer import TTSTokenizer
 from TTS.tts.utils.visual import plot_alignment
 from TTS.utils.io import load_fsspec
-from TTS.utils.samplers import BucketBatchSampler
+from TTS.utils.samplers import BucketBatchSampler, PerfectBatchSampler
 from TTS.vocoder.models.hifigan_generator import HifiganGenerator
 from TTS.vocoder.utils.generic_utils import plot_results

@@ -259,6 +259,7 @@ class VitsDataset(TTSDataset):
         super().__init__(*args, **kwargs)
         self.pad_id = self.tokenizer.characters.pad_id
         self.model_args = model_args
+        self.num_classes = None

     def __getitem__(self, idx):
         item = self.samples[idx]
@@ -317,6 +318,13 @@ class VitsDataset(TTSDataset):
         """
         # convert list of dicts to dict of lists
         B = len(batch)
+        # group the samples of each class contiguously in the batch: the perfect sampler produces [3, 2, 1, 3, 2, 1] but we need [3, 3, 2, 2, 1, 1]
+        if self.model_args.use_perfect_class_batch_sampler:
+            new_batch = []
+            for i in range(self.num_classes):
+                new_batch.extend(batch[i:B:self.num_classes])
+            batch = new_batch
+
         batch = {k: [dic[k] for dic in batch] for k in batch[0]}

         _, ids_sorted_decreasing = torch.sort(
@@ -546,6 +554,9 @@ class VitsArgs(Coqpit):
     out_channels: int = 513
     spec_segment_size: int = 32
     hidden_channels: int = 192
+    use_adaptive_weight_text_encoder: bool = False
+    use_perfect_class_batch_sampler: bool = False
+    perfect_class_batch_sampler_key: str = ""
     hidden_channels_ffn_text_encoder: int = 768
     num_heads_text_encoder: int = 2
     num_layers_text_encoder: int = 6
@@ -660,7 +671,8 @@ class Vits(BaseTTS):
             self.args.num_layers_text_encoder,
             self.args.kernel_size_text_encoder,
             self.args.dropout_p_text_encoder,
-            language_emb_dim=self.embedded_language_dim,
+            language_emb_dim=self.embedded_language_dim if not self.args.use_adaptive_weight_text_encoder else 0,
+            num_adaptive_weight_classes=self.num_languages if self.args.use_adaptive_weight_text_encoder else None,
         )

         self.posterior_encoder = PosteriorEncoder(
@@ -690,7 +702,7 @@ class Vits(BaseTTS):
                 self.args.dropout_p_duration_predictor,
                 4,
                 cond_channels=self.embedded_speaker_dim if self.args.condition_dp_on_speaker else 0,
-                language_emb_dim=self.embedded_language_dim,
+                language_emb_dim=self.embedded_language_dim if not self.args.use_adaptive_weight_text_encoder else 0,
             )
         else:
             self.duration_predictor = DurationPredictor(
@@ -699,7 +711,7 @@ class Vits(BaseTTS):
                 3,
                 self.args.dropout_p_duration_predictor,
                 cond_channels=self.embedded_speaker_dim,
-                language_emb_dim=self.embedded_language_dim,
+                language_emb_dim=self.embedded_language_dim if not self.args.use_adaptive_weight_text_encoder else 0,
             )

         self.waveform_decoder = HifiganGenerator(
@@ -794,12 +806,14 @@ class Vits(BaseTTS):
         if self.args.language_ids_file is not None:
             self.language_manager = LanguageManager(language_ids_file_path=config.language_ids_file)

-        if self.args.use_language_embedding and self.language_manager:
-            print(" > initialization of language-embedding layers.")
+        if self.language_manager:
             self.num_languages = self.language_manager.num_languages
-            self.embedded_language_dim = self.args.embedded_language_dim
-            self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
-            torch.nn.init.xavier_uniform_(self.emb_l.weight)
+            self.embedded_language_dim = 0
+            if self.args.use_language_embedding:
+                print(" > initialization of language-embedding layers.")
+                self.embedded_language_dim = self.args.embedded_language_dim
+                self.emb_l = nn.Embedding(self.num_languages, self.embedded_language_dim)
+                torch.nn.init.xavier_uniform_(self.emb_l.weight)
         else:
             self.embedded_language_dim = 0
@@ -1016,7 +1030,7 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, class_id=lid)

         # posterior encoder
         z, m_q, logs_q, y_mask = self.posterior_encoder(y, y_lengths, g=g)
@@ -1122,7 +1136,7 @@ class Vits(BaseTTS):
         if self.args.use_language_embedding and lid is not None:
             lang_emb = self.emb_l(lid).unsqueeze(-1)

-        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb)
+        x, m_p, logs_p, x_mask = self.text_encoder(x, x_lengths, lang_emb=lang_emb, class_id=lid)

         if durations is None:
             if self.args.use_sdp:
@@ -1413,7 +1427,7 @@ class Vits(BaseTTS):
             speaker_id = self.speaker_manager.name_to_id[speaker_name]

         # get language id
-        if hasattr(self, "language_manager") and config.use_language_embedding and language_name is not None:
+        if hasattr(self, "language_manager") and (config.use_language_embedding or config.use_adaptive_weight_text_encoder) and language_name is not None:
             language_id = self.language_manager.name_to_id[language_name]

         return {
@@ -1482,7 +1496,7 @@ class Vits(BaseTTS):
             d_vectors = torch.FloatTensor(d_vectors)

         # get language ids from language names
-        if self.language_manager is not None and self.language_manager.name_to_id and self.args.use_language_embedding:
+        if self.language_manager is not None and self.language_manager.name_to_id and (self.args.use_language_embedding or self.args.use_adaptive_weight_text_encoder):
             language_ids = [self.language_manager.name_to_id[ln] for ln in batch["language_names"]]

         if language_ids is not None:
@@ -1547,6 +1561,21 @@ class Vits(BaseTTS):
         return batch

     def get_sampler(self, config: Coqpit, dataset: TTSDataset, num_gpus=1, is_eval=False):
+        if self.args.use_perfect_class_batch_sampler:
+            batch_size = config.eval_batch_size if is_eval else config.batch_size
+            data_items = dataset.samples
+            classes = [item[self.args.perfect_class_batch_sampler_key] for item in data_items]
+            classes = set(classes)
+            dataset.num_classes = len(classes)
+            batch_sampler = PerfectBatchSampler(
+                dataset_items=data_items,
+                classes=classes,
+                batch_size=batch_size,
+                num_classes_in_batch=len(classes),
+                label_key=self.args.perfect_class_batch_sampler_key,
+            )
+            return batch_sampler
+
         weights = None
         data_items = dataset.samples
         if getattr(config, "use_weighted_sampler", False):
@@ -1631,7 +1660,7 @@ class Vits(BaseTTS):
                 pin_memory=False,
             )
         else:
-            if num_gpus > 1:
+            if num_gpus > 1 and not self.args.use_perfect_class_batch_sampler:
                 loader = DataLoader(
                     dataset,
                     sampler=sampler,
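Note on the collate change above: it relies on PerfectBatchSampler interleaving the classes in a fixed order within each batch, so taking every num_classes-th item yields contiguous per-class groups. A tiny sketch with a toy batch of dicts (not real VitsDataset items):

# interleaved order as produced by the perfect sampler, e.g. class ids [3, 2, 1, 3, 2, 1]
batch = [{"class": 3}, {"class": 2}, {"class": 1}, {"class": 3}, {"class": 2}, {"class": 1}]
num_classes = 3
B = len(batch)

# same stride-slicing as in the patched collate_fn: every num_classes-th item starting at offset i
regrouped = []
for i in range(num_classes):
    regrouped.extend(batch[i:B:num_classes])

print([item["class"] for item in regrouped])  # -> [3, 3, 2, 2, 1, 1]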
""" language_manager = None - if check_config_and_model_args(config, "use_language_embedding", True): + if check_config_and_model_args(config, "use_language_embedding", True) or check_config_and_model_args(config, "use_adaptive_weight_text_encoder", True): if config.get("language_ids_file", None): language_manager = LanguageManager(language_ids_file_path=config.language_ids_file) language_manager = LanguageManager(config=config) diff --git a/recipes/multilingual/syntacc/train_syntacc.py b/recipes/multilingual/syntacc/train_syntacc.py new file mode 100644 index 00000000..d0bb2dad --- /dev/null +++ b/recipes/multilingual/syntacc/train_syntacc.py @@ -0,0 +1,230 @@ +import os + +import torch +from trainer import Trainer, TrainerArgs + +from TTS.bin.compute_embeddings import compute_embeddings +from TTS.bin.resample import resample_files +from TTS.config.shared_configs import BaseDatasetConfig +from TTS.tts.configs.vits_config import VitsConfig +from TTS.tts.datasets import load_tts_samples +from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig, VitsDataset +from TTS.utils.downloaders import download_libri_tts +from torch.utils.data import DataLoader +from TTS.utils.samplers import PerfectBatchSampler +torch.set_num_threads(24) + +# pylint: disable=W0105 +""" + This recipe replicates the first experiment proposed in the CML-TTS paper (https://arxiv.org/abs/2306.10097). It uses the YourTTS model. + YourTTS model is based on the VITS model however it uses external speaker embeddings extracted from a pre-trained speaker encoder and has small architecture changes. +""" +CURRENT_PATH = os.path.dirname(os.path.abspath(__file__)) + +# Name of the run for the Trainer +RUN_NAME = "YourTTS-CML-TTS" + +# Path where you want to save the models outputs (configs, checkpoints and tensorboard logs) +OUT_PATH = os.path.dirname(os.path.abspath(__file__)) # "/raid/coqui/Checkpoints/original-YourTTS/" + +# If you want to do transfer learning and speedup your training you can set here the path to the CML-TTS available checkpoint that cam be downloaded here: https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p +RESTORE_PATH = None # "/raid/edresson/CML_YourTTS/checkpoints_yourtts_cml_tts_dataset/best_model.pth" # Download the checkpoint here: https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p + +# This paramter is useful to debug, it skips the training epochs and just do the evaluation and produce the test sentences +SKIP_TRAIN_EPOCH = False + +# Set here the batch size to be used in training and evaluation +BATCH_SIZE = 6 + +# Training Sampling rate and the target sampling rate for resampling the downloaded dataset (Note: If you change this you might need to redownload the dataset !!) 
diff --git a/recipes/multilingual/syntacc/train_syntacc.py b/recipes/multilingual/syntacc/train_syntacc.py
new file mode 100644
index 00000000..d0bb2dad
--- /dev/null
+++ b/recipes/multilingual/syntacc/train_syntacc.py
@@ -0,0 +1,230 @@
+import os
+
+import torch
+from trainer import Trainer, TrainerArgs
+
+from TTS.bin.compute_embeddings import compute_embeddings
+from TTS.bin.resample import resample_files
+from TTS.config.shared_configs import BaseDatasetConfig
+from TTS.tts.configs.vits_config import VitsConfig
+from TTS.tts.datasets import load_tts_samples
+from TTS.tts.models.vits import CharactersConfig, Vits, VitsArgs, VitsAudioConfig, VitsDataset
+from TTS.utils.downloaders import download_libri_tts
+from torch.utils.data import DataLoader
+from TTS.utils.samplers import PerfectBatchSampler
+
+torch.set_num_threads(24)
+
+# pylint: disable=W0105
+"""
+    This recipe trains YourTTS with the SYNTACC-style adaptive-weight text encoder. It is adapted from the first experiment of the CML-TTS paper (https://arxiv.org/abs/2306.10097), which uses the YourTTS model.
+    YourTTS is based on the VITS model, but it uses external speaker embeddings extracted from a pre-trained speaker encoder and has small architecture changes.
+"""
+CURRENT_PATH = os.path.dirname(os.path.abspath(__file__))
+
+# Name of the run for the Trainer
+RUN_NAME = "YourTTS-CML-TTS"
+
+# Path where you want to save the models outputs (configs, checkpoints and tensorboard logs)
+OUT_PATH = os.path.dirname(os.path.abspath(__file__))  # "/raid/coqui/Checkpoints/original-YourTTS/"
+
+# If you want to do transfer learning and speed up your training, you can set here the path to the available CML-TTS checkpoint, which can be downloaded here: https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p
+RESTORE_PATH = None  # "/raid/edresson/CML_YourTTS/checkpoints_yourtts_cml_tts_dataset/best_model.pth"  # Download the checkpoint here: https://drive.google.com/u/2/uc?id=1yDCSJ1pFZQTHhL09GMbOrdjcPULApa0p
+
+# This parameter is useful for debugging; it skips the training epochs and just runs the evaluation and produces the test sentences
+SKIP_TRAIN_EPOCH = False
+
+# Set here the batch size to be used in training and evaluation
+BATCH_SIZE = 6
+
+# Training sampling rate and the target sampling rate for resampling the downloaded dataset (note: if you change this you might need to redownload the dataset!)
+# Note: if you add new datasets, please make sure that the dataset sampling rate matches this parameter; otherwise, resample your audios
+SAMPLE_RATE = 24000
+
+# Max audio length in seconds to be used in training (every audio longer than this will be ignored)
+MAX_AUDIO_LEN_IN_SECONDS = float("inf")
+
+# Define the dataset configs here
+esd_train_config = BaseDatasetConfig(
+    formatter="coqui",
+    dataset_name="esd",
+    meta_file_train="metadata_with_basic_metrics.csv",  # TODO: compute emotion and d-vectors for test and evaluation splits
+    path="/raid/datasets/Emotion/ESD-44kHz-VAD-renormalized/",
+    language="en",
+)
+
+savee_config = BaseDatasetConfig(
+    formatter="coqui",
+    dataset_name="savee",
+    path="/raid/datasets/SAVEE-44khz/",
+    meta_file_train="metadata_with_basic_metrics.csv",
+    language="pt",
+)
+
+game1_config = BaseDatasetConfig(
+    formatter="coqui",
+    dataset_name="game1",
+    path="/raid/datasets/new_game_data/game1/datasetbuilder_formatted/",
+    meta_file_train="metadata_with_basic_metrics.csv",
+    language="de",
+)
+
+DATASETS_CONFIG_LIST = [esd_train_config, savee_config, game1_config]
+
+
+### Extract speaker embeddings
+SPEAKER_ENCODER_CHECKPOINT_PATH = (
+    "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/model_se.pth.tar"
+)
+SPEAKER_ENCODER_CONFIG_PATH = "https://github.com/coqui-ai/TTS/releases/download/speaker_encoder_model/config_se.json"
+
+D_VECTOR_FILES = []  # List of speaker embeddings/d-vectors to be used during the training
+
+# Iterate over all the dataset configs, checking if the speaker embeddings are already computed; if not, compute them
+for dataset_conf in DATASETS_CONFIG_LIST:
+    # Check if the embeddings were already computed; if not, compute them
+    embeddings_file = os.path.join(dataset_conf.path, "H_ASP_speaker_embeddings.pth")
+    if not os.path.isfile(embeddings_file):
+        print(f">>> Computing the speaker embeddings for the {dataset_conf.dataset_name} dataset")
+        compute_embeddings(
+            SPEAKER_ENCODER_CHECKPOINT_PATH,
+            SPEAKER_ENCODER_CONFIG_PATH,
+            embeddings_file,
+            old_speakers_file=None,
+            config_dataset_path=None,
+            formatter_name=dataset_conf.formatter,
+            dataset_name=dataset_conf.dataset_name,
+            dataset_path=dataset_conf.path,
+            meta_file_train=dataset_conf.meta_file_train,
+            meta_file_val=dataset_conf.meta_file_val,
+            disable_cuda=False,
+            no_eval=False,
+        )
+    D_VECTOR_FILES.append(embeddings_file)
+
+
+# Audio config used in training.
+audio_config = VitsAudioConfig(
+    sample_rate=SAMPLE_RATE,
+    hop_length=256,
+    win_length=1024,
+    fft_size=1024,
+    mel_fmin=0.0,
+    mel_fmax=None,
+    num_mels=80,
+)
+
+# Init VitsArgs, setting the arguments that are needed for the YourTTS model
+model_args = VitsArgs(
+    spec_segment_size=62,
+    hidden_channels=192,
+    hidden_channels_ffn_text_encoder=768,
+    num_heads_text_encoder=2,
+    num_layers_text_encoder=10,
+    kernel_size_text_encoder=3,
+    dropout_p_text_encoder=0.1,
+    d_vector_file=D_VECTOR_FILES,
+    use_d_vector_file=True,
+    d_vector_dim=512,
+    speaker_encoder_model_path=SPEAKER_ENCODER_CHECKPOINT_PATH,
+    speaker_encoder_config_path=SPEAKER_ENCODER_CONFIG_PATH,
+    resblock_type_decoder="2",  # In the paper, we accidentally trained YourTTS using ResNet blocks type 2; if you like, you can use ResNet blocks type 1 like the VITS model
+    # Useful parameters to enable the Speaker Consistency Loss (SCL) described in the paper
+    use_speaker_encoder_as_loss=False,
+    # Useful parameters to enable multilingual training
+    use_language_embedding=False,
+    embedded_language_dim=4,
+    use_adaptive_weight_text_encoder=True,
+    use_perfect_class_batch_sampler=True,
+    perfect_class_batch_sampler_key="language",
+)
+
+# General training config; here you can change the batch size and other useful parameters
+config = VitsConfig(
+    output_path=OUT_PATH,
+    model_args=model_args,
+    run_name=RUN_NAME,
+    project_name="SYNTACC",
+    run_description="""
+        - YourTTS with SYNTACC text encoder
+    """,
+    dashboard_logger="tensorboard",
+    logger_uri=None,
+    audio=audio_config,
+    batch_size=BATCH_SIZE,
+    batch_group_size=48,
+    eval_batch_size=BATCH_SIZE,
+    num_loader_workers=8,
+    eval_split_max_size=256,
+    print_step=50,
+    plot_step=100,
+    log_model_step=1000,
+    save_step=5000,
+    save_n_checkpoints=2,
+    save_checkpoints=True,
+    # target_loss="loss_1",
+    print_eval=False,
+    use_phonemes=False,
+    phonemizer="espeak",
+    phoneme_language="en",
+    compute_input_seq_cache=True,
+    add_blank=True,
+    text_cleaner="multilingual_cleaners",
+    characters=CharactersConfig(
+        characters_class="TTS.tts.models.vits.VitsCharacters",
+        pad="_",
+        eos="&",
+        bos="*",
+        blank=None,
+        characters="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\u00a1\u00a3\u00b7\u00b8\u00c0\u00c1\u00c2\u00c3\u00c4\u00c5\u00c7\u00c8\u00c9\u00ca\u00cb\u00cc\u00cd\u00ce\u00cf\u00d1\u00d2\u00d3\u00d4\u00d5\u00d6\u00d9\u00da\u00db\u00dc\u00df\u00e0\u00e1\u00e2\u00e3\u00e4\u00e5\u00e7\u00e8\u00e9\u00ea\u00eb\u00ec\u00ed\u00ee\u00ef\u00f1\u00f2\u00f3\u00f4\u00f5\u00f6\u00f9\u00fa\u00fb\u00fc\u0101\u0104\u0105\u0106\u0107\u010b\u0119\u0141\u0142\u0143\u0144\u0152\u0153\u015a\u015b\u0161\u0178\u0179\u017a\u017b\u017c\u020e\u04e7\u05c2\u1b20",
+        punctuations="\u2014!'(),-.:;?\u00bf ",
+        phonemes="iy\u0268\u0289\u026fu\u026a\u028f\u028ae\u00f8\u0258\u0259\u0275\u0264o\u025b\u0153\u025c\u025e\u028c\u0254\u00e6\u0250a\u0276\u0251\u0252\u1d7b\u0298\u0253\u01c0\u0257\u01c3\u0284\u01c2\u0260\u01c1\u029bpbtd\u0288\u0256c\u025fk\u0261q\u0262\u0294\u0274\u014b\u0272\u0273n\u0271m\u0299r\u0280\u2c71\u027e\u027d\u0278\u03b2fv\u03b8\u00f0sz\u0283\u0292\u0282\u0290\u00e7\u029dx\u0263\u03c7\u0281\u0127\u0295h\u0266\u026c\u026e\u028b\u0279\u027bj\u0270l\u026d\u028e\u029f\u02c8\u02cc\u02d0\u02d1\u028dw\u0265\u029c\u02a2\u02a1\u0255\u0291\u027a\u0267\u025a\u02de\u026b'\u0303' ",
+        is_unique=True,
+        is_sorted=True,
+    ),
+    phoneme_cache_path=None,
+    precompute_num_workers=12,
+    start_by_longest=True,
+    datasets=DATASETS_CONFIG_LIST,
+    cudnn_benchmark=False,
+    max_audio_len=SAMPLE_RATE * MAX_AUDIO_LEN_IN_SECONDS,
+    mixed_precision=False,
+    test_sentences=[
+        ["Voc\u00ea ter\u00e1 a vista do topo da montanha que voc\u00ea escalar.", "ESD_0012", None, "pt"],
+        ["Quando voc\u00ea n\u00e3o corre nenhum risco, voc\u00ea arrisca tudo.", "ESD_0012", None, "pt"],
+    ],
+    # Enable the weighted sampler
+    use_weighted_sampler=True,
+    # Ensures that all speakers are seen in the training batch equally no matter how many samples each speaker has
+    # weighted_sampler_attrs={"language": 1.0, "speaker_name": 1.0},
+    weighted_sampler_attrs={"language": 1.0},
+    weighted_sampler_multipliers={
+        # "speaker_name": {
+        # You can force the batching scheme to give a higher weight to a certain speaker; this speaker will then appear more frequently in the batch.
+        # It will speed up the speaker adaptation process. Consider the CML train dataset and "new_speaker" as the name of the speaker that you want to adapt to.
+        # The line below will make the balancer consider "new_speaker" as 106 speakers, i.e. 1/4 of the number of speakers present in the CML dataset.
+        # 'new_speaker': 106,  # (CML tot. train speakers)/4 = (424/4) = 106
+        # }
+    },
+    # It sets the Speaker Consistency Loss (SCL) α to 9, as in the YourTTS paper
+    speaker_encoder_loss_alpha=9.0,
+)
+
+# Load all the dataset samples and split the training and evaluation sets
+train_samples, eval_samples = load_tts_samples(
+    config.datasets,
+    eval_split=True,
+    eval_split_max_size=config.eval_split_max_size,
+    eval_split_size=config.eval_split_size,
+)
+
+# Init the model
+model = Vits.init_from_config(config)
+
+# Init the trainer and 🚀
+trainer = Trainer(
+    TrainerArgs(restore_path=RESTORE_PATH, skip_train_epoch=SKIP_TRAIN_EPOCH, start_with_eval=True),
+    config,
+    output_path=OUT_PATH,
+    model=model,
+    train_samples=train_samples,
+    eval_samples=eval_samples,
+)
+trainer.fit()
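Two caveats worth noting about the recipe above. First, the per-class split in the new TextEncoder.forward assumes the batch size is an exact multiple of the number of classes delivered by the perfect sampler; with the three dataset languages above (en, pt, de) and BATCH_SIZE = 6, each batch carries two utterances per language. A hypothetical early check in the same spirit, using the recipe's own variables (not part of the patch):

# hypothetical sanity check: fail early if the batch size cannot be split evenly
# across the language classes used by the perfect batch sampler
num_classes = len({d.language for d in DATASETS_CONFIG_LIST})
assert BATCH_SIZE % num_classes == 0, "BATCH_SIZE must be divisible by the number of languages"

Second, because the new branch in Vits.get_sampler returns the PerfectBatchSampler before the weighted-sampler code is reached, the use_weighted_sampler=True setting in this config appears to have no effect when use_perfect_class_batch_sampler is enabled.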