diff --git a/hubert/pre_kmeans_hubert.py b/hubert/pre_kmeans_hubert.py index 93f82fe..e7a76a4 100644 --- a/hubert/pre_kmeans_hubert.py +++ b/hubert/pre_kmeans_hubert.py @@ -1,4 +1,11 @@ -# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer +""" +Modified HuBERT model without kmeans. +Original author: https://github.com/lucidrains/ +Modified by: https://www.github.com/gitmylo/ +License: MIT +""" + +# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py from pathlib import Path @@ -6,8 +13,6 @@ import torch from torch import nn from einops import pack, unpack -import joblib - import fairseq from torchaudio.functional import resample @@ -37,13 +42,17 @@ class CustomHubert(nn.Module): checkpoint_path, target_sample_hz=16000, seq_len_multiple_of=None, - output_layer=9 + output_layer=9, + device=None ): super().__init__() self.target_sample_hz = target_sample_hz self.seq_len_multiple_of = seq_len_multiple_of self.output_layer = output_layer + if device is not None: + self.to(device) + model_path = Path(checkpoint_path) assert model_path.exists(), f'path {checkpoint_path} does not exist' @@ -52,6 +61,9 @@ class CustomHubert(nn.Module): load_model_input = {checkpoint_path: checkpoint} model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) + if device is not None: + model[0].to(device) + self.model = model[0] self.model.eval()