Files
TTS/notebooks/ExtractTTSpectrogram.ipynb
Enno Hermann 3c2d5a9e03 Remove duplicate AudioProcessor code and fix ExtractTTSpectrogram.ipynb (#3230)
* chore: remove unused argument

* refactor(audio.processor): remove duplicate stft+griffin_lim

* chore(audio.processor): remove unused compute_stft_paddings

Same function available in numpy_transforms

* refactor(audio.processor): remove duplicate db_to_amp

* refactor(audio.processor): remove duplicate amp_to_db

* refactor(audio.processor): remove duplicate linear_to_mel

* refactor(audio.processor): remove duplicate mel_to_linear

* refactor(audio.processor): remove duplicate build_mel_basis

* refactor(audio.processor): remove duplicate stft_parameters

* refactor(audio.processor): use pre-/deemphasis from numpy_transforms

* refactor(audio.processor): use rms_volume_norm from numpy_transforms

* chore(audio.processor): remove duplicate assert

Already checked in numpy_transforms.compute_f0

* refactor(audio.processor): use find_endpoint from numpy_transforms

* refactor(audio.processor): use trim_silence from numpy_transforms

* refactor(audio.processor): use volume_norm from numpy_transforms

* refactor(audio.processor): use load_wav from numpy_transforms

* fix(bin.extract_tts_spectrograms): set quantization bits

* fix(ExtractTTSpectrogram.ipynb): adapt to current TTS code

Fixes #2447, #2574

* refactor(audio.processor): remove duplicate quantization methods
2023-11-16 10:57:06 +01:00

358 lines
12 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"This notebook generates mel spectrograms with a trained TTS model so they can be used for vocoder training."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import importlib\n",
"import os\n",
"import pickle\n",
"\n",
"import numpy as np\n",
"import soundfile as sf\n",
"import torch\n",
"from matplotlib import pylab as plt\n",
"from torch.utils.data import DataLoader\n",
"from tqdm import tqdm\n",
"\n",
"from TTS.config import load_config\n",
"from TTS.tts.configs.shared_configs import BaseDatasetConfig\n",
"from TTS.tts.datasets import load_tts_samples\n",
"from TTS.tts.datasets.dataset import TTSDataset\n",
"from TTS.tts.layers.losses import L1LossMasked\n",
"from TTS.tts.models import setup_model\n",
"from TTS.tts.utils.helpers import sequence_mask\n",
"from TTS.tts.utils.text.tokenizer import TTSTokenizer\n",
"from TTS.tts.utils.visual import plot_spectrogram\n",
"from TTS.utils.audio import AudioProcessor\n",
"from TTS.utils.audio.numpy_transforms import quantize\n",
"\n",
"%matplotlib inline\n",
"\n",
"# Configure CUDA visibility\n",
"os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Function to create directories and file names\n",
"def set_filename(wav_path, out_path):\n",
"    \"\"\"Return the output names for one wav file and make sure the output sub-folders exist.\n",
"\n",
"    Args:\n",
"        wav_path: path to the source wav file.\n",
"        out_path: root output folder; `quant/` and `mel/` are created inside it.\n",
"\n",
"    Returns:\n",
"        (file_name, wavq_path, mel_path): `file_name` is the wav's base name without\n",
"        its extension; both paths are given without an extension (the caller saves\n",
"        them with `np.save`, and the metadata writer appends `.npy` itself).\n",
"    \"\"\"\n",
"    # splitext strips only the final extension, so names containing extra dots\n",
"    # (e.g. `spk.01.wav`) stay unique instead of being cut at the first dot.\n",
"    file_name = os.path.splitext(os.path.basename(wav_path))[0]\n",
"    os.makedirs(os.path.join(out_path, \"quant\"), exist_ok=True)\n",
"    os.makedirs(os.path.join(out_path, \"mel\"), exist_ok=True)\n",
"    wavq_path = os.path.join(out_path, \"quant\", file_name)\n",
"    mel_path = os.path.join(out_path, \"mel\", file_name)\n",
"    return file_name, wavq_path, mel_path"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Paths and configurations\n",
"# NOTE: the absolute paths below are machine-specific; adjust them before running.\n",
"OUT_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/specs2/\"\n",
"DATA_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/\"\n",
"PHONEME_CACHE_PATH = \"/home/ubuntu/TTS/recipes/ljspeech/LJSpeech-1.1/phoneme_cache\"\n",
"DATASET = \"ljspeech\"\n",
"METADATA_FILE = \"metadata.csv\"\n",
"CONFIG_PATH = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/config.json\"\n",
"MODEL_FILE = \"/home/ubuntu/.local/share/tts/tts_models--en--ljspeech--tacotron2-DDC_ph/model_file.pth\"\n",
"BATCH_SIZE = 32\n",
"\n",
"QUANTIZE_BITS = 0  # if non-zero, also save the wav files quantized to this number of bits\n",
"DRY_RUN = False  # if True, skip saving the mel / quantized-wav outputs; only compute losses and visuals\n",
"\n",
"# Check CUDA availability\n",
"use_cuda = torch.cuda.is_available()\n",
"print(\" > CUDA enabled: \", use_cuda)\n",
"\n",
"# Load the configuration\n",
"dataset_config = BaseDatasetConfig(formatter=DATASET, meta_file_train=METADATA_FILE, path=DATA_PATH)\n",
"C = load_config(CONFIG_PATH)\n",
"# IMPORTANT: keep silence trimming disabled, otherwise the generated mel specs\n",
"# would no longer be aligned with the original wav files.\n",
"C.audio['do_trim_silence'] = False\n",
"ap = AudioProcessor(**C.audio)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize the tokenizer; init_from_config also returns the (possibly updated) config\n",
"tokenizer, C = TTSTokenizer.init_from_config(C)\n",
"\n",
"# Load the model\n",
"# TODO: multiple speakers\n",
"model = setup_model(C)\n",
"model.load_checkpoint(C, MODEL_FILE, eval=True)\n",
"# The extraction loop moves every batch to the GPU when `use_cuda` is set, so the\n",
"# model must live there too; nothing else in this notebook moves it.\n",
"if use_cuda:\n",
"    model.cuda()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Merge the train and eval splits: spectrograms are extracted for every sample.\n",
"train_samples, eval_samples = load_tts_samples(dataset_config)\n",
"all_samples = train_samples + eval_samples\n",
"\n",
"dataset = TTSDataset(\n",
"    ap=ap,\n",
"    samples=all_samples,\n",
"    tokenizer=tokenizer,\n",
"    outputs_per_step=C[\"r\"],\n",
"    compute_linear_spec=False,\n",
"    phoneme_cache_path=PHONEME_CACHE_PATH,\n",
")\n",
"# Deterministic order and no dropped final batch, so the loader skips no sample.\n",
"loader = DataLoader(\n",
"    dataset,\n",
"    batch_size=BATCH_SIZE,\n",
"    shuffle=False,\n",
"    drop_last=False,\n",
"    num_workers=4,\n",
"    collate_fn=dataset.collate_fn,\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Generate model outputs "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Initialize lists for storing results\n",
"file_idxs = []  # output file names (no extension), pickled for WaveRNN\n",
"metadata = []  # [wav path, mel path without .npy] pairs, written out for PWGAN\n",
"losses = []\n",
"postnet_losses = []\n",
"num_failed_batches = 0\n",
"criterion = L1LossMasked(seq_len_norm=C.seq_len_norm)\n",
"\n",
"# The log file lives directly in OUT_PATH, which may not exist yet:\n",
"# set_filename only creates its sub-folders, and only later inside the loop.\n",
"os.makedirs(OUT_PATH, exist_ok=True)\n",
"log_file_path = os.path.join(OUT_PATH, \"log.txt\")\n",
"\n",
"# The context managers must be comma-separated: `with a and b:` evaluates\n",
"# `a and b` to `b` and enters only that one, so gradients would stay enabled.\n",
"with torch.no_grad(), open(log_file_path, \"w\", encoding=\"utf-8\") as log_file:\n",
"    for data in tqdm(loader, desc=\"Processing\"):\n",
"        try:\n",
"            # dispatch data to GPU\n",
"            if use_cuda:\n",
"                for key in (\"token_id\", \"token_id_lengths\", \"mel\", \"mel_lengths\"):\n",
"                    data[key] = data[key].cuda()\n",
"\n",
"            outputs = model.forward(data[\"token_id\"], data[\"token_id_lengths\"], data[\"mel\"])\n",
"            mel_outputs = outputs[\"decoder_outputs\"]\n",
"            postnet_outputs = outputs[\"model_outputs\"]\n",
"\n",
"            # compute loss\n",
"            loss = criterion(mel_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
"            loss_postnet = criterion(postnet_outputs, data[\"mel\"], data[\"mel_lengths\"])\n",
"            losses.append(loss.item())\n",
"            postnet_losses.append(loss_postnet.item())\n",
"\n",
"            # Bring the outputs to CPU numpy for every model type, so np.save and the\n",
"            # sanity-check cells below work regardless of the device.\n",
"            postnet_outputs = postnet_outputs.detach().cpu().numpy()\n",
"            if C.model == \"Tacotron\":\n",
"                # compute mel specs from the linear specs predicted by Tacotron\n",
"                postnet_outputs = np.stack([ap.out_linear_to_mel(spec.T).T for spec in postnet_outputs])\n",
"\n",
"            if not DRY_RUN:\n",
"                for idx in range(data[\"token_id\"].shape[0]):\n",
"                    wav_file_path = data[\"item_idxs\"][idx]\n",
"                    wav = ap.load_wav(wav_file_path)\n",
"                    file_name, wavq_path, mel_path = set_filename(wav_file_path, OUT_PATH)\n",
"                    file_idxs.append(file_name)\n",
"\n",
"                    # quantize and save wav\n",
"                    if QUANTIZE_BITS > 0:\n",
"                        wavq = quantize(wav, QUANTIZE_BITS)\n",
"                        np.save(wavq_path, wavq)\n",
"\n",
"                    # save TTS mel, cut to this sample's unpadded length and transposed\n",
"                    mel_length = int(data[\"mel_lengths\"][idx])\n",
"                    mel = postnet_outputs[idx][:mel_length, :].T\n",
"                    np.save(mel_path, mel)\n",
"\n",
"                    metadata.append([wav_file_path, mel_path])\n",
"        except Exception as e:  # best effort: log the failure and move on to the next batch\n",
"            num_failed_batches += 1\n",
"            log_file.write(f\"Error processing data: {e!r}\\n\")\n",
"\n",
"    # Calculate and log mean losses (NaN when no batch succeeded)\n",
"    mean_loss = np.mean(losses) if losses else float(\"nan\")\n",
"    mean_postnet_loss = np.mean(postnet_losses) if postnet_losses else float(\"nan\")\n",
"    log_file.write(f\"Mean Loss: {mean_loss}\\n\")\n",
"    log_file.write(f\"Mean Postnet Loss: {mean_postnet_loss}\\n\")\n",
"\n",
"if not DRY_RUN:\n",
"    # For wavernn\n",
"    with open(os.path.join(OUT_PATH, \"dataset_ids.pkl\"), \"wb\") as f:\n",
"        pickle.dump(file_idxs, f)\n",
"\n",
"    # For pwgan: one `<wav path>|<mel path>.npy` line per sample. Kept inside the\n",
"    # DRY_RUN guard so a dry run cannot overwrite an earlier metadata.txt with an empty one.\n",
"    with open(os.path.join(OUT_PATH, \"metadata.txt\"), \"w\", encoding=\"utf-8\") as f:\n",
"        for wav_file_path, mel_path in metadata:\n",
"            f.write(f\"{wav_file_path}|{mel_path}.npy\\n\")\n",
"\n",
"if num_failed_batches:\n",
"    print(f\" ! {num_failed_batches} batch(es) failed, see {log_file_path}\")\n",
"\n",
"# Print mean losses\n",
"print(f\"Mean Loss: {mean_loss}\")\n",
"print(f\"Mean Postnet Loss: {mean_postnet_loss}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Sanity Check"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pick one sample of the last processed batch and show the shape of its\n",
"# ground-truth mel spectrogram.\n",
"idx = 1\n",
"sample_wav_path = data[\"item_idxs\"][idx]\n",
"ap.melspectrogram(ap.load_wav(sample_wav_path)).shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the ground-truth audio through the AudioProcessor, the same path used by the\n",
"# dataset and the extraction loop. A raw `sf.read` would bypass the loading options\n",
"# configured in `C.audio`, so the ground truth might not match what the model was scored on.\n",
"wav = ap.load_wav(data[\"item_idxs\"][idx])\n",
"mel_length = int(data[\"mel_lengths\"][idx])\n",
"mel_postnet = postnet_outputs[idx][:mel_length, :]\n",
"mel_decoder = mel_outputs[idx][:mel_length, :].detach().cpu().numpy()\n",
"mel_truth = ap.melspectrogram(wav)\n",
"print(mel_truth.shape)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot postnet output for the selected sample\n",
"postnet_frames = data[\"mel_lengths\"][idx]\n",
"print(mel_postnet[:postnet_frames, :].shape)\n",
"plot_spectrogram(mel_postnet, ap=ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decoder output (taken before the postnet) for the same sample\n",
"print(mel_decoder.shape)\n",
"plot_spectrogram(mel_decoder, ap=ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot GT spectrogram; it is laid out transposed relative to the model outputs\n",
"# (its frame axis is axis 1), so transpose it before plotting.\n",
"gt_spec = mel_truth.T\n",
"print(mel_truth.shape)\n",
"plot_spectrogram(gt_spec, ap=ap)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Absolute difference between decoder and postnet outputs\n",
"n_frames = data[\"mel_lengths\"][idx]\n",
"mel_diff = mel_decoder - mel_postnet\n",
"abs_diff = abs(mel_diff[:n_frames, :])\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs_diff.T, aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ground truth vs. decoder output. Compare only the frames both have, so a small\n",
"# length mismatch does not raise a broadcasting error (same idea as the GT/postnet cell).\n",
"n_common = min(mel_truth.shape[1], mel_decoder.shape[0])\n",
"mel_diff2 = mel_truth.T[:n_common] - mel_decoder[:n_common]\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T, aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Ground truth vs. the (batch-padded) postnet output, cut to the ground-truth length\n",
"mel = postnet_outputs[idx]\n",
"n_truth_frames = mel_truth.shape[1]\n",
"mel_diff2 = mel_truth.T - mel[:n_truth_frames]\n",
"plt.figure(figsize=(16, 10))\n",
"plt.imshow(abs(mel_diff2).T, aspect=\"auto\", origin=\"lower\")\n",
"plt.colorbar()\n",
"plt.tight_layout()"
]
}
],
"metadata": {
"interpreter": {
"hash": "822ce188d9bce5372c4adbb11364eeb49293228c2224eb55307f4664778e7f56"
},
"kernelspec": {
"display_name": "Python 3.9.7 64-bit ('base': conda)",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 4
}