mirror of https://github.com/serp-ai/bark-with-voice-clone.git
Add k/v caching for autoregressive generation
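K/V caching speeds up autoregressive decoding by storing each attention layer's key/value projections from earlier steps, so every subsequent step only projects and attends from the single newest token instead of re-running the full prefix through the model. For context, a minimal single-head sketch of the idea (toy code under assumed shapes, not this repo's model implementation; attend and step_with_cache are illustrative names):

import torch

def attend(q, k, v):
    # scaled dot-product attention of the new query over the whole key/value history
    w = torch.softmax(q @ k.transpose(-2, -1) / k.shape[-1] ** 0.5, dim=-1)
    return w @ v

def step_with_cache(x_t, w_q, w_k, w_v, kv_cache=None):
    # project only the newest token; the prefix's keys/values come from the cache
    q, k, v = x_t @ w_q, x_t @ w_k, x_t @ w_v
    if kv_cache is not None:
        past_k, past_v = kv_cache
        k = torch.cat([past_k, k], dim=0)  # (t, d): history plus this step
        v = torch.cat([past_v, v], dim=0)
    return attend(q, k, v), (k, v)  # output and the updated cache

d = 16
w_q, w_k, w_v = (torch.randn(d, d) for _ in range(3))
kv_cache = None
for x_t in torch.randn(10, 1, d):  # one token embedding per step
    out, kv_cache = step_with_cache(x_t, w_q, w_k, w_v, kv_cache)

Without the cache, step t re-encodes all t tokens of the prefix; with it, only the new token is projected while attention still reads the full cached history.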
@@ -359,6 +359,7 @@ def generate_text_semantic(
     max_gen_duration_s=None,
     allow_early_stop=True,
     model=None,
+    use_kv_caching=False
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -420,8 +421,14 @@ def generate_text_semantic(
         pbar = tqdm.tqdm(disable=silent, total=100)
         pbar_state = 0
         tot_generated_duration_s = 0
+        kv_cache = None
         for n in range(n_tot_steps):
-            logits = model(x, merge_context=True)
+            if use_kv_caching and kv_cache is not None:
+                x_input = x[:, [-1]]
+            else:
+                x_input = x
+
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache)
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
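Once the cache is warm (after the first iteration), only the newest token x[:, [-1]] is fed forward, and the model call returns both the logits and the updated past key/values. A hedged usage sketch, assuming the signatures in this repo's bark/generation.py (the text and temperature are illustrative):

from bark.generation import preload_models, generate_text_semantic

preload_models()  # download/load the model weights
semantic_tokens = generate_text_semantic(
    "Hello, world!",
    temp=0.7,
    use_kv_caching=True,  # reuse attention k/v across decoding steps
)

Defaulting use_kv_caching to False preserves the previous behavior for existing callers; the cached path should produce the same distributions as the uncached one, since the cache only avoids recomputation.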
@@ -498,6 +505,7 @@ def generate_coarse(
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
     model=None,
+    use_kv_caching=False
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -592,11 +600,18 @@ def generate_coarse(
                     x_coarse_in[:, -max_coarse_history:],
                 ]
             )
+            kv_cache = None
             for _ in range(sliding_window_len):
                 if n_step >= n_steps:
                     continue
                 is_major_step = n_step % N_COARSE_CODEBOOKS == 0
-                logits = model(x_in)
+
+                if use_kv_caching and kv_cache is not None:
+                    x_input = x_in[:, [-1]]
+                else:
+                    x_input = x_in
+
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache)
                 logit_start_idx = (
                     SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                 )
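In generate_coarse the cache is reset to None at the top of each sliding window, because x_in is rebuilt from the updated semantic/coarse history; caching therefore only spans the sliding_window_len steps within one window. An end-to-end sketch under the same assumptions as above:

from bark.generation import preload_models, generate_text_semantic, generate_coarse

preload_models()
semantic = generate_text_semantic("Hello, world!", use_kv_caching=True)
coarse = generate_coarse(semantic, use_kv_caching=True)  # coarse audio codes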