Mirror of https://github.com/coqui-ai/TTS.git

Update inference func
@@ -102,51 +102,74 @@ def compute_average_bass_energy(audio_data, sample_rate, max_bass_freq=250):
     return bass_energy


+class BarkHubertAudioTokenizer():
+    def __init__(self, config, lazy_load, device='cpu') -> None:
+        self.config = config
+        self.lazy_load = lazy_load
+        if not lazy_load:
+            self.load_hubert(config, device)
+
+    def load_hubert(self, config, device):
+        hubert_manager = HubertManager()
+        hubert_manager.make_sure_tokenizer_installed(model_path=self.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
+        self.hubert_model = CustomHubert(checkpoint_path=self.config.LOCAL_MODEL_PATHS["hubert"]).to(device)
+        self.tokenizer = HubertTokenizer.load_from_checkpoint(
+            config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=device
+        )
+
+    def encode(self, audio, device):
+        """Encode an audio file into a sequence of tokens.
+
+        Args:
+            audio (str or Tensor): The audio to encode, in shape (B, T).
+            device (str): The device to use for encoding.
+
+        Returns:
+            Tensor: The encoded tokens.
+        """
+        if isinstance(audio, str):
+            audio, sr = torchaudio.load(audio)
+            audio = convert_audio(audio, sr, self.config.sample_rate, 1)
+        audio = audio.to(device)
+
+        if self.lazy_load:
+            self.load_hubert(self.config, device)
+
+        semantic_vectors = self.hubert_model.forward(audio, flatten=False, input_sample_hz=self.config.sample_rate)
+        semantic_tokens = self.tokenizer.get_token(semantic_vectors)
+        return semantic_tokens
+
+
 def generate_voice(
     audio,
     model,
-    output_path,
+    output_path=None,
 ):
-    """Generate a new voice from a given audio and text prompt.
+    """Generate a new voice from a given audio.

     Args:
         audio (np.ndarray): The audio to use as a base for the new voice.
-        text (str): Transcription of the audio you are cloning.
         model (BarkModel): The BarkModel to use for generating the new voice.
-        output_path (str): The path to save the generated voice to.
+        output_path (str): The path to save the generated voice to. If None, return the computed tokens.
     """
     if isinstance(audio, str):
         audio, sr = torchaudio.load(audio)
         audio = convert_audio(audio, sr, model.config.sample_rate, model.encodec.channels)
         audio = audio.unsqueeze(0).to(model.device)

     with torch.no_grad():
         encoded_frames = model.encodec.encode(audio)
-    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]
+    # Coarse and fine tokens
+    fine_tokens, coarse_tokens = model.generate_coarse_fine_tokens(audio)

-    # move codes to cpu
-    codes = codes.cpu().numpy()
+    # Semantic tokens
+    semantic_tokens = model.generate_semantic_tokens(audio).cpu().numpy()

-    # generate semantic tokens
-    # Load the HuBERT model
-    hubert_manager = HubertManager()
-    # hubert_manager.make_sure_hubert_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert"])
-    hubert_manager.make_sure_tokenizer_installed(model_path=model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"])
-
-    hubert_model = CustomHubert(checkpoint_path=model.config.LOCAL_MODEL_PATHS["hubert"]).to(model.device)
-
-    # Load the CustomTokenizer model
-    tokenizer = HubertTokenizer.load_from_checkpoint(
-        model.config.LOCAL_MODEL_PATHS["hubert_tokenizer"], map_location=model.device
-    )
-    # semantic_tokens = model.text_to_semantic(
-    #     text, max_gen_duration_s=seconds, top_k=50, top_p=0.95, temp=0.7
-    # ) # not 100%
-    semantic_vectors = hubert_model.forward(audio[0], input_sample_hz=model.config.sample_rate)
-    semantic_tokens = tokenizer.get_token(semantic_vectors)
-    semantic_tokens = semantic_tokens.cpu().numpy()
-
-    np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
+    if output_path is not None:
+        np.savez(
+            output_path, fine_prompt=fine_tokens, coarse_prompt=coarse_tokens[:2, :], semantic_prompt=semantic_tokens
+        )
+    else:
+        return {"fine_prompt": fine_tokens, "coarse_prompt": coarse_tokens, "semantic_prompt": semantic_tokens}
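To see how the updated interface fits together, here is a minimal usage sketch. It assumes an already-loaded Bark model (bark_model is a hypothetical placeholder) exposing the config, generate_coarse_fine_tokens, and generate_semantic_tokens members used in the hunk above, plus hypothetical audio and output paths; the import path is likewise an assumption based on the file being patched.

    # Minimal usage sketch; names and paths below are hypothetical.
    from TTS.tts.layers.bark.inference_funcs import (  # assumed module path
        BarkHubertAudioTokenizer,
        generate_voice,
    )

    # With an output path, the voice prompt is saved as an .npz holding
    # fine_prompt [n_q, T], coarse_prompt (first two codebooks), and semantic_prompt.
    generate_voice("speaker.wav", bark_model, output_path="speaker_voice.npz")

    # With output_path=None, the same tokens come back as a dict instead.
    prompts = generate_voice("speaker.wav", bark_model, output_path=None)

    # The new tokenizer class also works standalone; lazy_load=True defers
    # loading the HuBERT weights until the first encode() call.
    tokenizer = BarkHubertAudioTokenizer(bark_model.config, lazy_load=True)
    semantic_tokens = tokenizer.encode("speaker.wav", device="cpu")

Making output_path optional lets callers keep the computed prompts in memory rather than round-tripping through an .npz file.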
@@ -162,7 +185,7 @@ def generate_text_semantic(
     allow_early_stop=True,
     base=None,
     use_kv_caching=True,
-    **kwargs,  # pylint: disable=unused-argument
+    **kwargs,  # pylint: disable=unused-argument
 ):
     """Generate semantic tokens from text.
@@ -242,7 +265,7 @@ def generate_text_semantic(
                 x_input = x[:, [-1]]
             else:
                 x_input = x
-            logits, kv_cache = model.semantic_model(
+            logits, kv_cache = model.semantic_model.inference(
                 x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
            )
            relevant_logits = logits[0, 0, : model.config.SEMANTIC_VOCAB_SIZE]
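The last hunk routes the semantic model call through an explicit inference entry point while keeping the key/value caching pattern: once a cache exists, only the newest token is fed forward and past_kv supplies the cached attention history. Below is a self-contained sketch of that pattern with a stand-in model, not Bark's actual semantic model.

    # Illustrative KV-caching decode loop; model is a hypothetical callable
    # returning (logits, kv_cache).
    import torch

    def greedy_decode(model, x, n_steps, use_kv_caching=True):
        kv_cache = None
        for _ in range(n_steps):
            if use_kv_caching and kv_cache is not None:
                x_input = x[:, [-1]]  # only the newest token; history is in the cache
            else:
                x_input = x  # no cache yet: process the full prefix once
            logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
            next_token = logits[0, -1].argmax().view(1, 1)
            x = torch.cat([x, next_token], dim=1)
        return x

Feeding only the last token keeps each step's attention cost proportional to the cached prefix rather than recomputing the whole sequence every iteration.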