From 92b283c741d4110ff989e679405d805128299331 Mon Sep 17 00:00:00 2001 From: Niels Date: Sun, 7 Apr 2024 20:17:52 +0200 Subject: [PATCH] Add class --- models/voicecraft.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/models/voicecraft.py b/models/voicecraft.py index f090c66..8f87264 100644 --- a/models/voicecraft.py +++ b/models/voicecraft.py @@ -85,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0): -class VoiceCraft(nn.Module, PyTorchModelHubMixin): +class VoiceCraft(nn.Module): def __init__(self, args): super().__init__() self.args = copy.copy(args) @@ -1410,4 +1410,9 @@ class VoiceCraft(nn.Module, PyTorchModelHubMixin): res = res - int(self.args.n_special) flatten_gen = flatten_gen - int(self.args.n_special) - return res, flatten_gen[0].unsqueeze(0) \ No newline at end of file + return res, flatten_gen[0].unsqueeze(0) + + +class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin): + def __init__(self, config: dict): + super().__init__(config) \ No newline at end of file