Update pre_kmeans_hubert.py to latest version

This commit is contained in:
Mylo
2023-05-26 22:40:09 +02:00
committed by GitHub
parent 6e4d02bfbf
commit 05b9106f00

View File

@@ -1,4 +1,11 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer """
Modified HuBERT model without kmeans.
Original author: https://github.com/lucidrains/
Modified by: https://www.github.com/gitmylo/
License: MIT
"""
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
from pathlib import Path from pathlib import Path
@@ -6,8 +13,6 @@ import torch
from torch import nn from torch import nn
from einops import pack, unpack from einops import pack, unpack
import joblib
import fairseq import fairseq
from torchaudio.functional import resample from torchaudio.functional import resample
@@ -37,13 +42,17 @@ class CustomHubert(nn.Module):
checkpoint_path, checkpoint_path,
target_sample_hz=16000, target_sample_hz=16000,
seq_len_multiple_of=None, seq_len_multiple_of=None,
output_layer=9 output_layer=9,
device=None
): ):
super().__init__() super().__init__()
self.target_sample_hz = target_sample_hz self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer self.output_layer = output_layer
if device is not None:
self.to(device)
model_path = Path(checkpoint_path) model_path = Path(checkpoint_path)
assert model_path.exists(), f'path {checkpoint_path} does not exist' assert model_path.exists(), f'path {checkpoint_path} does not exist'
@@ -52,6 +61,9 @@ class CustomHubert(nn.Module):
load_model_input = {checkpoint_path: checkpoint} load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
if device is not None:
model[0].to(device)
self.model = model[0] self.model = model[0]
self.model.eval() self.model.eval()