Mirror of https://github.com/serp-ai/bark-with-voice-clone.git
Add better voice clones and prepare for finetuning
hubert/pre_kmeans_hubert.py (new file, 94 lines)
@@ -0,0 +1,94 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer

from pathlib import Path

import torch
from torch import nn
from einops import pack, unpack

import joblib  # kept for loading a kmeans model if the quantization step below is re-enabled

import fairseq

from torchaudio.functional import resample

from audiolm_pytorch.utils import curtail_to_multiple

import logging
logging.root.setLevel(logging.ERROR)


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


class CustomHubert(nn.Module):
    """
    Wraps a pretrained fairseq HuBERT model and returns hidden-layer features for a waveform.
    Checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own.
    """

    def __init__(
        self,
        checkpoint_path,
        target_sample_hz=16000,
        seq_len_multiple_of=None,
        output_layer=9
    ):
        super().__init__()
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        model_path = Path(checkpoint_path)
        assert model_path.exists(), f'path {checkpoint_path} does not exist'

        checkpoint = torch.load(checkpoint_path)
        load_model_input = {checkpoint_path: checkpoint}
        models, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        self.model = models[0]
        self.model.eval()

    @property
    def groups(self):
        return 1

    @torch.no_grad()
    def forward(
        self,
        wav_input,
        flatten=True,
        input_sample_hz=None
    ):
        device = wav_input.device

        # resample to the rate HuBERT was trained on (16 kHz by default)
        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        if exists(self.seq_len_multiple_of):
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model(
            wav_input,
            features_only=True,
            mask=False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
            output_layer=self.output_layer
        )

        # flatten (batch, frames, dim) to (batch * frames, dim)
        embed, packed_shape = pack([embed['x']], '* d')

        # the kmeans quantization step is skipped here; this module returns the raw
        # layer features, and a separate quantizer maps them to indices downstream
        # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())

        codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)  # .long()

        if flatten:
            return codebook_indices

        codebook_indices, = unpack(codebook_indices, packed_shape, '*')
        return codebook_indices
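For context, a minimal usage sketch of the module added above, assuming the HuBERT base checkpoint from the fairseq repository; the checkpoint and wav paths are placeholders and not part of this commit:

    import torchaudio

    from hubert.pre_kmeans_hubert import CustomHubert

    # placeholder paths: point these at your own checkpoint and reference clip
    hubert = CustomHubert(checkpoint_path='data/models/hubert/hubert_base_ls960.pt')

    wav, sr = torchaudio.load('speaker.wav')  # (channels, samples)
    wav = wav.mean(dim=0, keepdim=True)       # mix down to mono, keeping a batch dim of 1

    # per-frame features from layer 9, flattened to (frames, dim)
    semantic_features = hubert(wav, input_sample_hz=sr)
    print(semantic_features.shape)

Note that despite the `codebook_indices` variable name, `forward` returns continuous HuBERT features here: the kmeans line is commented out, and the quantization into semantic tokens is presumably handled by the separate HuBERT-quantizer linked at the top of the file.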