Mirror of https://github.com/serp-ai/bark-with-voice-clone.git, synced 2025-12-15 03:07:58 +01:00

Merge pull request #32 from gitmylo/patch-1
Update voice cloning HuBERT quantizer
@@ -1,4 +1,8 @@
 # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
+"""
+Custom tokenizer model.
+Author: https://www.github.com/gitmylo/
+License: MIT
+"""

 import json
 import os.path

@@ -91,7 +95,7 @@ class CustomTokenizer(nn.Module):
         optimizer.step()

     def save(self, path):
-        info_path = os.path.basename(path) + '/.info'
+        info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info'
         torch.save(self.state_dict(), path)
         data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
         with ZipFile(path, 'a') as model_zip:

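The save() change only affects how the name of the ".info" entry is derived from the checkpoint file name: the old expression kept the file extension, the new one strips it first. A minimal sketch of the difference, using a hypothetical checkpoint name 'model.pth' (the name is an illustration, not taken from the diff):

    import os.path

    path = 'model.pth'  # hypothetical checkpoint file name

    # Old derivation: keeps the extension -> 'model.pth/.info'
    old_info_path = os.path.basename(path) + '/.info'

    # New derivation: strips the extension -> 'model/.info'
    new_info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info'

    print(old_info_path)  # model.pth/.info
    print(new_info_path)  # model/.info

The resulting string is presumably used as the entry name for the metadata appended to the checkpoint zip opened just below (ZipFile(path, 'a')).
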
@@ -112,7 +116,9 @@ class CustomTokenizer(nn.Module):
             model = CustomTokenizer()
         else:
             model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
-        model.load_state_dict(torch.load(path, map_location))
+        model.load_state_dict(torch.load(path))
+        if map_location:
+            model = model.to(map_location)
         return model

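With the loader change above, torch.load() is now called without a map_location, so the tensors are restored as stored and the whole module is only moved afterwards when a map_location is given. A minimal usage sketch; the method name load_from_checkpoint and the import path hubert.customtokenizer follow the upstream gitmylo repository and are assumptions here, as is the checkpoint file name:

    import torch
    from hubert.customtokenizer import CustomTokenizer  # import path is an assumption

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Restores the quantizer weights, then moves the module to `device`
    # (per the map_location handling shown in the hunk above).
    tokenizer = CustomTokenizer.load_from_checkpoint('model.pth', map_location=device)

The remaining hunks are against the modified HuBERT wrapper ("Modified HuBERT model without kmeans").
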
@@ -1,4 +1,11 @@
+# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
+"""
+Modified HuBERT model without kmeans.
+Original author: https://github.com/lucidrains/
+Modified by: https://www.github.com/gitmylo/
+License: MIT
+"""

 # Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py

 from pathlib import Path

@@ -6,8 +13,6 @@ import torch
 from torch import nn
 from einops import pack, unpack

-import joblib
-
 import fairseq

 from torchaudio.functional import resample

@@ -37,13 +42,17 @@ class CustomHubert(nn.Module):
         checkpoint_path,
         target_sample_hz=16000,
         seq_len_multiple_of=None,
-        output_layer=9
+        output_layer=9,
+        device=None
     ):
         super().__init__()
         self.target_sample_hz = target_sample_hz
         self.seq_len_multiple_of = seq_len_multiple_of
         self.output_layer = output_layer
+
+        if device is not None:
+            self.to(device)

         model_path = Path(checkpoint_path)

         assert model_path.exists(), f'path {checkpoint_path} does not exist'

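A minimal construction sketch for the new device argument; the import path hubert.pre_kmeans_hubert and the checkpoint path are assumptions, not part of the diff:

    import torch
    from hubert.pre_kmeans_hubert import CustomHubert  # import path is an assumption

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # With the change above, passing device moves the module in __init__;
    # the next hunk additionally moves the loaded fairseq HuBERT checkpoint.
    hubert = CustomHubert(
        checkpoint_path='data/models/hubert/hubert.pt',  # hypothetical path
        device=device,
    )
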
@@ -52,6 +61,9 @@ class CustomHubert(nn.Module):
         load_model_input = {checkpoint_path: checkpoint}
         model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

+        if device is not None:
+            model[0].to(device)
+
         self.model = model[0]
         self.model.eval()

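Putting the two files together, the quantizer is typically used to turn a reference recording into Bark-style semantic tokens. A rough end-to-end sketch reusing hubert, tokenizer and device from the sketches above; the forward(wav, input_sample_hz=...) and get_token(...) calls follow the upstream gitmylo repository and are not shown in this diff:

    import torchaudio

    wav, sr = torchaudio.load('speaker.wav')  # hypothetical reference clip
    if wav.shape[0] > 1:                      # downmix stereo to mono
        wav = wav.mean(0, keepdim=True)
    wav = wav.to(device)

    semantic_vectors = hubert.forward(wav, input_sample_hz=sr)  # HuBERT features
    semantic_tokens = tokenizer.get_token(semantic_vectors)     # quantized tokens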