diff --git a/inference_tts.ipynb b/inference_tts.ipynb index f18270c..e54368c 100644 --- a/inference_tts.ipynb +++ b/inference_tts.ipynb @@ -32,6 +32,7 @@ "import torchaudio\n", "import numpy as np\n", "import random\n", + "from argparse import Namespace\n", "\n", "from data.tokenizer import (\n", " AudioTokenizer,\n", @@ -45,7 +46,7 @@ "metadata": {}, "outputs": [], "source": [ - "# install MFA models and dictionaries if you haven't done so already\n", + "# install MFA models and dictionaries if you haven't done so already, already done in the dockerfile or envrionment setup\n", "!source ~/.bashrc && \\\n", " conda activate voicecraft && \\\n", " mfa model download dictionary english_us_arpa && \\\n", @@ -61,28 +62,38 @@ "# load model, encodec, and phn2num\n", "# # load model, tokenizer, and other necessary files\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", + "\n", + "# the old way of loading the model\n", "from models import voicecraft\n", - "#import models.voicecraft as voicecraft\n", - "voicecraft_name=\"gigaHalfLibri330M_TTSEnhanced_max16s.pth\" # or giga330M.pth, giga830M.pth\n", "ckpt_fn =f\"./pretrained_models/{voicecraft_name}\"\n", - "encodec_fn = \"./pretrained_models/encodec_4cb2048_giga.th\"\n", "if not os.path.exists(ckpt_fn):\n", " os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/{voicecraft_name}\\?download\\=true\")\n", " os.system(f\"mv {voicecraft_name}\\?download\\=true ./pretrained_models/{voicecraft_name}\")\n", - "if not os.path.exists(encodec_fn):\n", - " os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th\")\n", - " os.system(f\"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th\")\n", - "\n", "ckpt = torch.load(ckpt_fn, map_location=\"cpu\")\n", "model = voicecraft.VoiceCraft(ckpt[\"config\"])\n", "model.load_state_dict(ckpt[\"model\"])\n", + "phn2num = ckpt['phn2num']\n", + "config = vars(ckpt['config'])\n", "model.to(device)\n", "model.eval()\n", "\n", - "phn2num = ckpt['phn2num']\n", + "# # the new way of loading the model, with huggingface, this doesn't work yet\n", + "# from models.voicecraft import VoiceCraftHF\n", + "# model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", + "# phn2num = model.args.phn2num # or model.args['phn2num']?\n", + "# config = model.config\n", + "# model.to(device)\n", + "# model.eval()\n", "\n", - "text_tokenizer = TextTokenizer(backend=\"espeak\")\n", - "audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu\n" + "\n", + "encodec_fn = \"./pretrained_models/encodec_4cb2048_giga.th\"\n", + "if not os.path.exists(encodec_fn):\n", + " os.system(f\"wget https://huggingface.co/pyp1/VoiceCraft/resolve/main/encodec_4cb2048_giga.th\")\n", + " os.system(f\"mv encodec_4cb2048_giga.th ./pretrained_models/encodec_4cb2048_giga.th\")\n", + "audio_tokenizer = AudioTokenizer(signature=encodec_fn, device=device) # will also put the neural codec model on gpu\n", + "\n", + "text_tokenizer = TextTokenizer(backend=\"espeak\")\n" ] }, { @@ -148,7 +159,7 @@ "\n", "# NOTE adjust the below three arguments if the generation is not as good\n", "stop_repetition = 3 # NOTE if the model generate long silence, reduce the stop_repetition to 3, 2 or even 1\n", - "sample_batch_size = 2 # for gigaHalfLibri330M_TTSEnhanced_max16s.pth, 1 or 2 should be fine since the model is trained to do TTS, for the other two models, might need a higher number. NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number.\n", + "sample_batch_size = 4 # NOTE: if the if there are long silence or unnaturally strecthed words, increase sample_batch_size to 5 or higher. What this will do to the model is that the model will run sample_batch_size examples of the same audio, and pick the one that's the shortest. So if the speech rate of the generated is too fast change it to a smaller number.\n", "seed = 1 # change seed if you are still unhappy with the result\n", "\n", "def seed_everything(seed):\n", @@ -163,7 +174,7 @@ "\n", "decode_config = {'top_k': top_k, 'top_p': top_p, 'temperature': temperature, 'stop_repetition': stop_repetition, 'kvcache': kvcache, \"codec_audio_sr\": codec_audio_sr, \"codec_sr\": codec_sr, \"silence_tokens\": silence_tokens, \"sample_batch_size\": sample_batch_size}\n", "from inference_tts_scale import inference_one_sample\n", - "concated_audio, gen_audio = inference_one_sample(model, ckpt[\"config\"], phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n", + "concated_audio, gen_audio = inference_one_sample(model, Namespace(**config), phn2num, text_tokenizer, audio_tokenizer, audio_fn, target_transcript, device, decode_config, prompt_end_frame)\n", " \n", "# save segments for comparison\n", "concated_audio, gen_audio = concated_audio[0].cpu(), gen_audio[0].cpu()\n", @@ -190,6 +201,13 @@ "\n", "# you are might get warnings like WARNING:phonemizer:words count mismatch on 300.0% of the lines (3/1), this can be safely ignored" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/models/voicecraft.py b/models/voicecraft.py index cda380a..19b0a25 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -462,6 +462,8 @@ class VoiceCraft(nn.Module): before padding. """ x, x_lens, y, y_lens = batch["x"], batch["x_lens"], batch["y"], batch["y_lens"] + if len(x) == 0: + return None x = x[:, :x_lens.max()] # this deal with gradient accumulation, where x_lens.max() might not be longer than the length of the current slice of x y = y[:, :, :y_lens.max()] assert x.ndim == 2, x.shape diff --git a/steps/trainer.py b/steps/trainer.py index 0cce5de..b312258 100644 --- a/steps/trainer.py +++ b/steps/trainer.py @@ -90,6 +90,8 @@ class Trainer: cur_batch = {key: batch[key][cur_ind] for key in batch} with torch.cuda.amp.autocast(dtype=torch.float16 if self.args.precision=="float16" else torch.float32): out = self.model(cur_batch) + if out == None: + continue record_loss = out['loss'].detach().to(self.rank) top10acc = out['top10acc'].to(self.rank)