Merge pull request #32 from gitmylo/patch-1

Update voice cloning HuBERT quantizer
This commit is contained in:
Francis LaBounty
2023-05-26 17:07:50 -06:00
committed by GitHub
2 changed files with 25 additions and 7 deletions

View File

@@ -1,4 +1,8 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer """
Custom tokenizer model.
Author: https://www.github.com/gitmylo/
License: MIT
"""
import json import json
import os.path import os.path
@@ -91,7 +95,7 @@ class CustomTokenizer(nn.Module):
optimizer.step() optimizer.step()
def save(self, path): def save(self, path):
info_path = os.path.basename(path) + '/.info' info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info'
torch.save(self.state_dict(), path) torch.save(self.state_dict(), path)
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version) data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
with ZipFile(path, 'a') as model_zip: with ZipFile(path, 'a') as model_zip:
@@ -112,7 +116,9 @@ class CustomTokenizer(nn.Module):
model = CustomTokenizer() model = CustomTokenizer()
else: else:
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version) model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
model.load_state_dict(torch.load(path, map_location)) model.load_state_dict(torch.load(path))
if map_location:
model = model.to(map_location)
return model return model

View File

@@ -1,4 +1,11 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer """
Modified HuBERT model without kmeans.
Original author: https://github.com/lucidrains/
Modified by: https://www.github.com/gitmylo/
License: MIT
"""
# Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
from pathlib import Path from pathlib import Path
@@ -6,8 +13,6 @@ import torch
from torch import nn from torch import nn
from einops import pack, unpack from einops import pack, unpack
import joblib
import fairseq import fairseq
from torchaudio.functional import resample from torchaudio.functional import resample
@@ -37,13 +42,17 @@ class CustomHubert(nn.Module):
checkpoint_path, checkpoint_path,
target_sample_hz=16000, target_sample_hz=16000,
seq_len_multiple_of=None, seq_len_multiple_of=None,
output_layer=9 output_layer=9,
device=None
): ):
super().__init__() super().__init__()
self.target_sample_hz = target_sample_hz self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer self.output_layer = output_layer
if device is not None:
self.to(device)
model_path = Path(checkpoint_path) model_path = Path(checkpoint_path)
assert model_path.exists(), f'path {checkpoint_path} does not exist' assert model_path.exists(), f'path {checkpoint_path} does not exist'
@@ -52,6 +61,9 @@ class CustomHubert(nn.Module):
load_model_input = {checkpoint_path: checkpoint} load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
if device is not None:
model[0].to(device)
self.model = model[0] self.model = model[0]
self.model.eval() self.model.eval()