From 77df5104b086bc6617b41359bf027dd4e543357d Mon Sep 17 00:00:00 2001 From: Lucain Date: Mon, 15 Apr 2024 11:51:25 +0200 Subject: [PATCH 1/4] Tweak VoiceCraft x HF integration This PR tweaks the HF integrations: - `VoiceCraft` tag is lowercased to `voicecraft` => not a hard requirement but makes it more consistent with other libraries on the Hub. - `voicecraft` is set as the `library_name` instead of simply a tag. This is better for taxonomy on the Hub. Regarding the integration, I also opened https://github.com/huggingface/huggingface.js/pull/626 to make it more official on the Hub. In particular, there will now be an official ` Use in VoiceCraft` button in all voicecraft models that display the code snippet to load the model. This should help users getting started with the model. It will also add a link to the voicecraft repo for the installation guide. cc @NielsRogge who opened https://github.com/jasonppy/VoiceCraft/pull/78 --- models/voicecraft.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/models/voicecraft.py b/models/voicecraft.py index 9bb3393..8811160 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -1416,7 +1416,13 @@ class VoiceCraft(nn.Module): return res, flatten_gen[0].unsqueeze(0) -class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", tags=["Text-to-Speech", "VoiceCraft"]): +class VoiceCraftHF( + VoiceCraft, + PyTorchModelHubMixin, + repo_url="https://github.com/jasonppy/VoiceCraft", + tags=["Text-to-Speech"], + library_name="voicecraft" + ): def __init__(self, config: dict): args = Namespace(**config) - super().__init__(args) \ No newline at end of file + super().__init__(args) From 943211d751f68895735a1c81a2bf9922eade4dc5 Mon Sep 17 00:00:00 2001 From: Lucain Date: Tue, 16 Apr 2024 08:47:49 +0200 Subject: [PATCH 2/4] Update models/voicecraft.py Co-authored-by: Julien Chaumond --- models/voicecraft.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/voicecraft.py b/models/voicecraft.py index 8811160..a68a38e 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -1420,7 +1420,7 @@ class VoiceCraftHF( VoiceCraft, PyTorchModelHubMixin, repo_url="https://github.com/jasonppy/VoiceCraft", - tags=["Text-to-Speech"], + tags=["text-to-speech"], library_name="voicecraft" ): def __init__(self, config: dict): From e550f614096a9ba098d2127a4e441ca497ba256f Mon Sep 17 00:00:00 2001 From: Wauplin Date: Tue, 16 Apr 2024 10:24:05 +0200 Subject: [PATCH 3/4] Deduplicate VoiceCraftHF <> VoiceCraft --- gradio_app.py | 2 +- models/voicecraft.py | 41 ++++++++++++++++++++++++++--------------- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/gradio_app.py b/gradio_app.py index 41f64a5..8d2dae6 100644 --- a/gradio_app.py +++ b/gradio_app.py @@ -92,7 +92,7 @@ def load_models(whisper_backend_name, whisper_model_name, alignment_model_name, transcribe_model = WhisperxModel(whisper_model_name, align_model) voicecraft_name = f"{voicecraft_model_name}.pth" - model = voicecraft.VoiceCraftHF.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") + model = voicecraft.VoiceCraft.from_pretrained(f"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}") phn2num = model.args.phn2num config = model.args model.to(device) diff --git a/models/voicecraft.py b/models/voicecraft.py index a68a38e..508e55f 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -3,6 +3,7 @@ import random import numpy as np import logging import argparse, copy +from typing import Dict, Optional import torch import torch.nn as nn import torch.nn.functional as F @@ -86,9 +87,31 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): -class VoiceCraft(nn.Module): - def __init__(self, args): +class VoiceCraft( + nn.Module, + PyTorchModelHubMixin, + library_name="voicecraft", + repo_url="https://github.com/jasonppy/VoiceCraft", + tags=["text-to-speech"], + ): + def __new__(cls, args: Optional[Namespace] = None, config: Optional[Dict] = None, **kwargs) -> "VoiceCraft": + # If initialized from Namespace args => convert to dict config for 'PyTorchModelHubMixin' to serialize it as config.json + # Won't affect instance initialization + if args is not None: + if config is not None: + raise ValueError("Cannot provide both `args` and `config`.") + config = vars(args) + return super().__new__(cls, args=args, config=config, **kwargs) + + def __init__(self, args: Optional[Namespace] = None, config: Optional[Dict] = None): super().__init__() + + # If loaded from HF Hub => convert config.json to Namespace args before initializing + if args is None: + if config is None: + raise ValueError("Either `args` or `config` must be provided.") + args = Namespace(**config) + self.args = copy.copy(args) self.pattern = DelayedPatternProvider(n_q=self.args.n_codebooks) if not getattr(self.args, "special_first", False): @@ -100,7 +123,7 @@ class VoiceCraft(nn.Module): if self.args.eos > 0: assert self.args.eos != self.args.audio_pad_token and self.args.eos != self.args.empty_token, self.args.eos self.eos = nn.Parameter(torch.full((self.args.n_codebooks, 1), self.args.eos, dtype=torch.long), requires_grad=False) # [K 1] - if type(self.args.audio_vocab_size) == str: + if isinstance(self.args.audio_vocab_size, str): self.args.audio_vocab_size = eval(self.args.audio_vocab_size) self.n_text_tokens = self.args.text_vocab_size + 1 @@ -1414,15 +1437,3 @@ class VoiceCraft(nn.Module): flatten_gen = flatten_gen - int(self.args.n_special) return res, flatten_gen[0].unsqueeze(0) - - -class VoiceCraftHF( - VoiceCraft, - PyTorchModelHubMixin, - repo_url="https://github.com/jasonppy/VoiceCraft", - tags=["text-to-speech"], - library_name="voicecraft" - ): - def __init__(self, config: dict): - args = Namespace(**config) - super().__init__(args) From 9dab23564719aac8531914574511d4b92f3d8ec2 Mon Sep 17 00:00:00 2001 From: jason-on-salt-a40 Date: Tue, 16 Apr 2024 08:55:35 -0700 Subject: [PATCH 4/4] better hf integration --- inference_speech_editing.ipynb | 4 ++-- inference_tts.ipynb | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/inference_speech_editing.ipynb b/inference_speech_editing.ipynb index 4502966..022e3ba 100644 --- a/inference_speech_editing.ipynb +++ b/inference_speech_editing.ipynb @@ -203,8 +203,8 @@ "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "\n", "# the new way of loading the model, with huggingface, recommended\n", - "from models.voicecraft import VoiceCraftHF\n", - "model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", + "from models import voicecraft\n", + "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "phn2num = model.args.phn2num\n", "config = vars(model.args)\n", "model.to(device)\n", diff --git a/inference_tts.ipynb b/inference_tts.ipynb index f9a1862..e9712ca 100644 --- a/inference_tts.ipynb +++ b/inference_tts.ipynb @@ -74,8 +74,8 @@ "voicecraft_name=\"giga330M.pth\" # or gigaHalfLibri330M_TTSEnhanced_max16s.pth, giga830M.pth\n", "\n", "# the new way of loading the model, with huggingface, recommended\n", - "from models.voicecraft import VoiceCraftHF\n", - "model = VoiceCraftHF.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", + "from models import voicecraft\n", + "model = voicecraft.VoiceCraft.from_pretrained(f\"pyp1/VoiceCraft_{voicecraft_name.replace('.pth', '')}\")\n", "phn2num = model.args.phn2num\n", "config = vars(model.args)\n", "model.to(device)\n",