Add k/v caching for autoregressive generation

commit 15606ed12f
parent 874af1bae9
Author: Zygimantas Straznickas
Date: 2023-04-20 18:39:14 -07:00

4 changed files with 94 additions and 30 deletions
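
For context: key/value caching speeds up autoregressive decoding by saving each layer's attention keys and values from earlier steps, so each new step only runs the model on the single newest token instead of the whole sequence. A minimal sketch of the calling pattern this commit adopts (the `model` interface here is hypothetical, mirroring the calls in the diff below, not Bark's actual module):

import torch

def decode(model, x, n_steps, use_kv_caching=True):
    # x: (batch, prompt_len) token ids
    kv_cache = None
    for _ in range(n_steps):
        if use_kv_caching and kv_cache is not None:
            # cache is warm: feed only the newest token
            x_input = x[:, [-1]]
        else:
            # first step (or caching disabled): feed the full sequence
            x_input = x
        # model is assumed to return (logits, updated_cache)
        logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache)
        next_token = torch.argmax(logits[:, -1, :], dim=-1, keepdim=True)
        x = torch.cat([x, next_token], dim=1)
    return x

Each step then costs one token's forward pass plus attention against the cached keys/values, rather than re-running the model over the entire prefix.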


@@ -359,6 +359,7 @@ def generate_text_semantic(
     max_gen_duration_s=None,
     allow_early_stop=True,
     model=None,
+    use_kv_caching=False
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -420,8 +421,14 @@ def generate_text_semantic(
         pbar = tqdm.tqdm(disable=silent, total=100)
         pbar_state = 0
         tot_generated_duration_s = 0
+        kv_cache = None
         for n in range(n_tot_steps):
-            logits = model(x, merge_context=True)
+            if use_kv_caching and kv_cache is not None:
+                x_input = x[:, [-1]]
+            else:
+                x_input = x
+
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache)
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
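
The model-side half of this change lives in the other changed files, which are not shown here. Presumably each attention block concatenates the cached keys/values with the new token's and returns the updated pair when use_cache is set; a rough sketch of that standard pattern (hypothetical names and shapes, not the commit's actual code):

import torch

def cached_attention(q, k, v, past_kv=None, use_cache=False):
    # q, k, v: (batch, n_head, new_len, head_dim)
    if past_kv is not None:
        past_k, past_v = past_kv
        k = torch.cat([past_k, k], dim=2)  # extend keys along the time axis
        v = torch.cat([past_v, v], dim=2)
    present = (k, v) if use_cache else None
    att = (q @ k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
    # NOTE: when q covers more than one new position, a causal mask offset
    # by the cached length is still required; omitted here for brevity.
    out = torch.softmax(att, dim=-1) @ v
    return out, present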
@@ -498,6 +505,7 @@ def generate_coarse(
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
     model=None,
+    use_kv_caching=False
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -592,11 +600,18 @@ def generate_coarse(
                     x_coarse_in[:, -max_coarse_history:],
                 ]
             )
+            kv_cache = None
             for _ in range(sliding_window_len):
                 if n_step >= n_steps:
                     continue
                 is_major_step = n_step % N_COARSE_CODEBOOKS == 0
-                logits = model(x_in)
+
+                if use_kv_caching and kv_cache is not None:
+                    x_input = x_in[:, [-1]]
+                else:
+                    x_input = x_in
+
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache)
                 logit_start_idx = (
                     SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                 )
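
Caller-side, the new flag is opt-in and off by default, so existing behavior is unchanged. Usage would look roughly like this (everything beyond the use_kv_caching parameter is assumed from the surrounding Bark API):

semantic_tokens = generate_text_semantic(
    "Look, I can talk much faster now!",
    use_kv_caching=True,  # reuse cached keys/values across decoding steps
)
coarse_tokens = generate_coarse(
    semantic_tokens,
    use_kv_caching=True,
)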