mirror of
https://github.com/jasonppy/VoiceCraft.git
synced 2026-04-03 01:36:55 +02:00
Add class
This commit is contained in:
@@ -85,7 +85,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
||||
|
||||
|
||||
|
||||
class VoiceCraft(nn.Module, PyTorchModelHubMixin):
|
||||
class VoiceCraft(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.args = copy.copy(args)
|
||||
@@ -1411,3 +1411,8 @@ class VoiceCraft(nn.Module, PyTorchModelHubMixin):
|
||||
flatten_gen = flatten_gen - int(self.args.n_special)
|
||||
|
||||
return res, flatten_gen[0].unsqueeze(0)
|
||||
|
||||
|
||||
class VoiceCraftHF(VoiceCraft, PyTorchModelHubMixin):
|
||||
def __init__(self, config: dict):
|
||||
super().__init__(config)
|
||||
Reference in New Issue
Block a user