mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-15 19:27:57 +01:00
Update pre_kmeans_hubert.py to latest version
This commit is contained in:
@@ -1,4 +1,11 @@
|
|||||||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
|
"""
|
||||||
|
Modified HuBERT model without kmeans.
|
||||||
|
Original author: https://github.com/lucidrains/
|
||||||
|
Modified by: https://www.github.com/gitmylo/
|
||||||
|
License: MIT
|
||||||
|
"""
|
||||||
|
|
||||||
|
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -6,8 +13,6 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from einops import pack, unpack
|
from einops import pack, unpack
|
||||||
|
|
||||||
import joblib
|
|
||||||
|
|
||||||
import fairseq
|
import fairseq
|
||||||
|
|
||||||
from torchaudio.functional import resample
|
from torchaudio.functional import resample
|
||||||
@@ -37,13 +42,17 @@ class CustomHubert(nn.Module):
|
|||||||
checkpoint_path,
|
checkpoint_path,
|
||||||
target_sample_hz=16000,
|
target_sample_hz=16000,
|
||||||
seq_len_multiple_of=None,
|
seq_len_multiple_of=None,
|
||||||
output_layer=9
|
output_layer=9,
|
||||||
|
device=None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.target_sample_hz = target_sample_hz
|
self.target_sample_hz = target_sample_hz
|
||||||
self.seq_len_multiple_of = seq_len_multiple_of
|
self.seq_len_multiple_of = seq_len_multiple_of
|
||||||
self.output_layer = output_layer
|
self.output_layer = output_layer
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
self.to(device)
|
||||||
|
|
||||||
model_path = Path(checkpoint_path)
|
model_path = Path(checkpoint_path)
|
||||||
|
|
||||||
assert model_path.exists(), f'path {checkpoint_path} does not exist'
|
assert model_path.exists(), f'path {checkpoint_path} does not exist'
|
||||||
@@ -52,6 +61,9 @@ class CustomHubert(nn.Module):
|
|||||||
load_model_input = {checkpoint_path: checkpoint}
|
load_model_input = {checkpoint_path: checkpoint}
|
||||||
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
|
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
|
||||||
|
|
||||||
|
if device is not None:
|
||||||
|
model[0].to(device)
|
||||||
|
|
||||||
self.model = model[0]
|
self.model = model[0]
|
||||||
self.model.eval()
|
self.model.eval()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user