diff --git a/README.md b/README.md
index e292536..9e37326 100644
--- a/README.md
+++ b/README.md
@@ -100,6 +100,7 @@ conda install -c conda-forge montreal-forced-aligner=2.2.17 openfst=1.8.2 kaldi=
 # install MFA english dictionary and model
 mfa model download dictionary english_us_arpa
 mfa model download acoustic english_us_arpa
+pip install huggingface_hub
 # conda install pocl # above gives an warning for installing pocl, not sure if really need this
 
 # to run ipynb
diff --git a/models/voicecraft.py b/models/voicecraft.py
index ab3cf37..cda380a 100644
--- a/models/voicecraft.py
+++ b/models/voicecraft.py
@@ -18,6 +18,10 @@ from .modules.transformer import (
 )
 from .codebooks_patterns import DelayedPatternProvider
 
+from argparse import Namespace
+from huggingface_hub import PyTorchModelHubMixin
+
+
 def top_k_top_p_filtering(
     logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
 ):
@@ -1407,4 +1411,23 @@ class VoiceCraft(nn.Module):
         res = res - int(self.args.n_special)
         flatten_gen = flatten_gen - int(self.args.n_special)
 
-        return res, flatten_gen[0].unsqueeze(0)
\ No newline at end of file
+        return res, flatten_gen[0].unsqueeze(0)
+
+
+class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin):
+    """VoiceCraft combined with huggingface_hub's PyTorchModelHubMixin.
+
+    Args:
+        config: Settings as a plain dict; every key becomes an attribute
+            of the ``Namespace`` handed to ``VoiceCraft.__init__``.
+    """
+
+    # Must be a plain str: a trailing comma here would make it a 1-tuple.
+    # NOTE(review): newer huggingface_hub releases appear to take repo_url
+    # and tags as class keyword arguments instead; confirm for your version.
+    repo_url = "https://github.com/jasonppy/VoiceCraft"
+    tags = ["Text-to-Speech", "VoiceCraft"]
+
+    def __init__(self, config: dict):
+        # VoiceCraft reads its settings by attribute (e.g. args.n_special).
+        args = Namespace(**config)
+        super().__init__(args)