Add k/v caching for autoregressive generation

This commit is contained in:
Zygimantas Straznickas
2023-04-20 18:39:14 -07:00
parent 874af1bae9
commit 15606ed12f
4 changed files with 94 additions and 30 deletions

View File

@@ -10,6 +10,7 @@ def text_to_semantic(
history_prompt: Optional[str] = None,
temp: float = 0.7,
silent: bool = False,
use_kv_caching = False,
):
"""Generate semantic array from text.
@@ -27,6 +28,7 @@ def text_to_semantic(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=use_kv_caching
)
return x_semantic
@@ -37,6 +39,7 @@ def semantic_to_waveform(
temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
use_kv_caching = False
):
"""Generate audio array from semantic input.
@@ -55,6 +58,7 @@ def semantic_to_waveform(
history_prompt=history_prompt,
temp=temp,
silent=silent,
use_kv_caching=use_kv_caching
)
fine_tokens = generate_fine(
coarse_tokens,
@@ -88,6 +92,7 @@ def generate_audio(
waveform_temp: float = 0.7,
silent: bool = False,
output_full: bool = False,
use_kv_caching = False
):
"""Generate audio array from input text.
@@ -103,7 +108,7 @@ def generate_audio(
numpy audio array at sample frequency 24khz
"""
semantic_tokens = text_to_semantic(
text, history_prompt=history_prompt, temp=text_temp, silent=silent,
text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
)
out = semantic_to_waveform(
semantic_tokens,
@@ -111,6 +116,7 @@ def generate_audio(
temp=waveform_temp,
silent=silent,
output_full=output_full,
use_kv_caching=use_kv_caching
)
if output_full:
full_generation, audio_arr = out