Update inference func

This commit is contained in:
Eren Gölge
2023-08-07 15:00:32 +02:00
parent 4e0f9aa83c
commit d6ee537c51

View File

@@ -102,51 +102,74 @@ def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250):
return bass_energy
class BarkHubertAudioTokenizer:
    """Encode audio into semantic tokens with a HuBERT model and its tokenizer.

    Args:
        config: Configuration object. Must expose ``LOCAL_MODEL_PATHS`` with the
            ``"hubert"`` and ``"hubert_tokenizer"`` entries and ``sample_rate``.
        lazy_load (bool): If True, the models are not loaded at construction
            time; they are loaded on the first ``encode`` call instead.
        device (str): Device used when the models are loaded eagerly.
    """

    # Sentinel meaning "the lazy path has not loaded the models yet".
    _NOT_LOADED = object()
    # Device the lazy path last loaded the models on.
    _lazy_device = _NOT_LOADED

    def __init__(self, config, lazy_load, device="cpu") -> None:
        self.config = config
        self.lazy_load = lazy_load
        if not lazy_load:
            self.load_hubert(config, device)

    def load_hubert(self, config, device):
        """Load the HuBERT model and the HuBERT tokenizer onto ``device``.

        Sets ``self.hubert_model`` and ``self.tokenizer``.

        Args:
            config: Configuration providing ``LOCAL_MODEL_PATHS``.
            device (str): Device the loaded models are placed on.
        """
        hubert_manager = HubertManager()
        # All paths are read from the ``config`` argument so that a single call
        # never mixes paths coming from two different configurations.
        hubert_manager.make_sure_tokenizer_installed(model_path=config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
        self.hubert_model = CustomHubert(checkpoint_path=config.LOCAL_MODEL_PATHS["hubert"]).to(device)
        self.tokenizer = HubertTokenizer.load_from_checkpoint(
            config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=device
        )

    def _ensure_lazy_loaded(self, device):
        """Load the models for the lazy path unless already loaded on ``device``."""
        already_loaded = self._lazy_device is not self._NOT_LOADED and self._lazy_device == device
        if already_loaded:
            return
        self.load_hubert(self.config, device)
        self._lazy_device = device

    def encode(self, audio, device):
        """Encode an audio file into a sequence of tokens.

        Args:
            audio (str or Tensor): The audio to encode. In shape (B, T).
            device (str): The device to use for encoding.

        Returns:
            Tensor: The encoded tokens.
        """
        if isinstance(audio, str):
            audio, sr = torchaudio.load(audio)
            audio = convert_audio(audio, sr, self.config.sample_rate, 1)
        audio = audio.to(device)

        if self.lazy_load:
            # Load once per device instead of re-reading both checkpoints on
            # every call.
            self._ensure_lazy_loaded(device)

        semantic_vectors = self.hubert_model.forward(audio, flatten=False, input_sample_hz=self.config.sample_rate)
        return self.tokenizer.get_token(semantic_vectors)
# NOTE(review): this span comes from a rendered diff whose +/- markers and
# indentation were stripped, so removed and added lines of `generate_voice`
# are interleaved: `output_path` is listed twice in the signature, the
# docstring is opened twice, and two alternative token pipelines follow each
# other. It is not runnable as shown.
#
# What the visible lines establish:
#   * If `audio` is a str it is read with `torchaudio.load`, passed through
#     `convert_audio` using `model.config.sample_rate` and
#     `model.encodec.channels`, then unsqueezed on dim 0 and moved to
#     `model.device`.
#   * One variant builds `codes` from `model.encodec.encode(audio)` under
#     `torch.no_grad()`, computes semantic tokens with a freshly loaded
#     `CustomHubert` / `HubertTokenizer`, and always writes them with `np.savez`.
#   * The other variant takes tokens from `model.generate_coarse_fine_tokens`
#     and `model.generate_semantic_tokens`, writes them with `np.savez` when
#     `output_path is not None`, and otherwise returns them in a dict with the
#     keys "fine_prompt", "coarse_prompt" and "semantic_prompt".
#
# Presumably each adjacent old/new pair belongs to the pre-/post-commit version
# respectively -- TODO confirm against the repository before restoring the
# function. The docstring typos "audioZ" and "clonning" should be corrected
# when doing so.
def generate_voice(
audio,
model,
output_path,
output_path=None,
):
"""Generate a new voice from a given audio and text prompt.
"""Generate a new voice from a given audioZ.
Args:
audio (np.ndarray): The audio to use as a base for the new voice.
text (str): Transcription of the audio you are clonning.
model (BarkModel): The BarkModel to use for generating the new voice.
output_path (str): The path to save the generated voice to.
output_path (str): The path to save the generated voice to. If None, return computed tokens.
"""
if isinstance(audio, str):
audio, sr = torchaudio.load(audio)
audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels)
audio = audio.unsqueeze(0).to(model.device)
with torch.no_grad():
encoded_frames = model.encodec.encode(audio)
codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]
# Coarse and fine tokens
fine_tokens, coarse_tokens = model.generate_coarse_fine_tokens(audio)
# move codes to cpu
codes = codes.cpu().numpy()
# Semantic tokens
semantic_tokens = model.generate_semantic_tokens(audio).cpu().numpy()
# generate semantic tokens
# Load the HuBERT model
hubert_manager = HubertManager()
# hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
# Load the CustomTokenizer model
tokenizer = HubertTokenizer.load_from_checkpoint(
model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device
)
# semantic_tokens = model.text_to_semantic(
# text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
# ) # not 100%
semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate)
semantic_tokens = tokenizer.get_token(semantic_vectors)
semantic_tokens = semantic_tokens.cpu().numpy()
np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
if output_path is not None:
np.savez(
output_path, fine_prompt=fine_tokens, coarse_prompt=coarse_tokens[:2, :], semantic_prompt=semantic_tokens
)
else:
return {"fine_prompt": fine_tokens, "coarse_prompt": coarse_tokens, "semantic_prompt": semantic_tokens}
def generate_text_semantic(
@@ -162,7 +185,7 @@ def generate_text_semantic(
allow_early_stop=True,
base=None,
use_kv_caching=True,
**kwargs, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
"""Generate semantic tokens from text.
@@ -242,7 +265,7 @@ def generate_text_semantic(
x_input = x[:, [-1]]
else:
x_input = x
logits, kv_cache = model.semantic_model(
logits, kv_cache = model.semantic_model.inference(
x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
)
relevant_logits = logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE]