Update pre_kmeans_hubert.py to latest version

This commit is contained in:
Mylo
2023-05-26 22:40:09 +02:00
committed by GitHub
parent 6e4d02bfbf
commit 05b9106f00

View File

@@ -1,4 +1,11 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
"""
Modified HuBERT model without kmeans.
Original author: https://github.com/lucidrains/
Modified by: https://www.github.com/gitmylo/
License: MIT
"""
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
from pathlib import Path
@@ -6,8 +13,6 @@ import torch
from torch import nn
from einops import pack, unpack
import joblib
import fairseq
from torchaudio.functional import resample
@@ -37,13 +42,17 @@ class CustomHubert(nn.Module):
checkpoint_path,
target_sample_hz=16000,
seq_len_multiple_of=None,
output_layer=9
output_layer=9,
device=None
):
super().__init__()
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer
if device is not None:
self.to(device)
model_path = Path(checkpoint_path)
assert model_path.exists(), f'path {checkpoint_path} does not exist'
@@ -52,6 +61,9 @@ class CustomHubert(nn.Module):
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
if device is not None:
model[0].to(device)
self.model = model[0]
self.model.eval()