diff --git a/gradio_app.py b/gradio_app.py index 41f64a5..8d2dae6 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, transcribe_model = WhisperxModel(whisper_model_name, align_model) voicecraft_name = f"{voicecraft_model_name}.pth" - model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") + model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") phn2num = model.args.phn2num config = model.args model.to(device) diff --git a/inference_speech_editing.ipynb b/inference_speech_editing.ipynb index 4502966..022e3ba 100644 --- a/inference_speech_editing.ipynb +++ b/inference_speech_editing.ipynb @@ -203,8 +203,8 @@ "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "\n", "# the new way of loading the model, with huggingface, recommended\n", - "from models.voicecraft import VoiceCraftHF\n", - "model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", + "from models import voicecraft\n", + "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "phn2num = model.args.phn2num\n", "config = vars(model.args)\n", "model.to(device)\n", diff --git a/inference_tts.ipynb b/inference_tts.ipynb index f9a1862..e9712ca 100644 --- a/inference_tts.ipynb +++ b/inference_tts.ipynb @@ -74,8 +74,8 @@ "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "\n", "# the new way of loading the model, with huggingface, recommended\n", - "from models.voicecraft import VoiceCraftHF\n", - "model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", + "from models import voicecraft\n", + "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "phn2num = model.args.phn2num\n", "config = vars(model.args)\n", "model.to(device)\n", diff --git a/models/voicecraft.py b/models/voicecraft.py index 9bb3393..508e55f 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -3,6 +3,7 @@ import random import numpy as np import logging import argparse, copy +from typing import Dict, Optional import torch import torch.nn as nn import torch.nn.functional as F @@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): -class VoiceCraft(nn.Module): - def __init__(self, args): +class VoiceCraft( + nn.Module, + PyTorchModelHubMixin, + library_name="voicecraft", + repo_url="https://github.com/jasonppy/VoiceCraft", + tags=["text-to-speech"], + ): + def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft": + # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json + # Won't affect instance initialization + if args is not None: + if config is not None: + raise ValueError("Cannot provide both `args` and `config`.") + config = vars(args) + return super().__new__(cls, args=args, config=config, **kwargs) + + def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None): super().__init__() + + # If loaded from HF Hub => convert config.json to Namespace args before initializing + if args is None: + if config is None: + raise ValueError("Either `args` or `config` must be provided.") + args = Namespace(**config) + self.args = copy.copy(args) self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks) if not getattr(self.args, "special_first", False): @@ -100,7 +123,7 @@ class VoiceCraft(nn.Module): if self.args.eos > 0: assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1] - if type(self.args.audio_vocab_size) == str: + if isinstance(self.args.audio_vocab_size, str): self.args.audio_vocab_size = eval(self.args.audio_vocab_size) self.n_text_tokens = self.args.text_vocab_size + 1 @@ -1414,9 +1437,3 @@ class VoiceCraft(nn.Module): flatten_gen = flatten_gen - int(self.args.n_special) return res, flatten_gen[0].unsqueeze(0) - - -class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]): - def __init__(self, config: dict): - args = Namespace(**config) - super().__init__(args) \ No newline at end of file