mirror of https://github.com/serp-ai/bark-with-voice-clone.git (synced 2025-12-16 11:48:09 +01:00)

Commit: Add k/v caching for autoregressive generation
.gitignore (vendored, 2 changed lines)

@@ -1,2 +1,2 @@
 __pycache__/
+.venv

bark/api.py

@@ -10,6 +10,7 @@ def text_to_semantic(
     history_prompt: Optional[str] = None,
     temp: float = 0.7,
     silent: bool = False,
+    use_kv_caching = False,
 ):
     """Generate semantic array from text.
 
@@ -27,6 +28,7 @@ def text_to_semantic(
         history_prompt=history_prompt,
         temp=temp,
         silent=silent,
+        use_kv_caching=use_kv_caching
     )
     return x_semantic
 
@@ -37,6 +39,7 @@ def semantic_to_waveform(
     temp: float = 0.7,
     silent: bool = False,
     output_full: bool = False,
+    use_kv_caching = False
 ):
     """Generate audio array from semantic input.
 
@@ -55,6 +58,7 @@ def semantic_to_waveform(
         history_prompt=history_prompt,
         temp=temp,
         silent=silent,
+        use_kv_caching=use_kv_caching
     )
     fine_tokens = generate_fine(
         coarse_tokens,
@@ -88,6 +92,7 @@ def generate_audio(
     waveform_temp: float = 0.7,
     silent: bool = False,
     output_full: bool = False,
+    use_kv_caching = False
 ):
     """Generate audio array from input text.
 
@@ -103,7 +108,7 @@ def generate_audio(
         numpy audio array at sample frequency 24khz
     """
     semantic_tokens = text_to_semantic(
-        text, history_prompt=history_prompt, temp=text_temp, silent=silent,
+        text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching
     )
     out = semantic_to_waveform(
         semantic_tokens,
@@ -111,6 +116,7 @@ def generate_audio(
         temp=waveform_temp,
         silent=silent,
         output_full=output_full,
+        use_kv_caching=use_kv_caching
     )
     if output_full:
         full_generation, audio_arr = out
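
The api.py changes simply thread the new use_kv_caching flag from the public entry points down to the low-level generators; behavior is unchanged when it is left at False. A minimal usage sketch (the sample text and output filename are illustrative, not from the commit):

from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

preload_models()

# use_kv_caching is the flag this commit adds; it defaults to False everywhere.
audio_array = generate_audio(
    "Hello, this is a test of cached autoregressive generation.",
    use_kv_caching=True,
)
write_wav("bark_kv_cache_demo.wav", SAMPLE_RATE, audio_array)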

bark/generation.py

@@ -359,6 +359,7 @@ def generate_text_semantic(
     max_gen_duration_s=None,
     allow_early_stop=True,
     model=None,
+    use_kv_caching=False
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -420,8 +421,14 @@ def generate_text_semantic(
         pbar = tqdm.tqdm(disable=silent, total=100)
         pbar_state = 0
         tot_generated_duration_s = 0
+        kv_cache = None
         for n in range(n_tot_steps):
-            logits = model(x, merge_context=True)
+            if use_kv_caching and kv_cache is not None:
+                x_input = x[:, [-1]]
+            else:
+                x_input = x
+
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache)
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
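
With caching enabled, only the first iteration runs the model over the whole prompt; every later iteration feeds just the newest token (x[:, [-1]]) plus the cached keys/values, so each step stops re-encoding the entire growing sequence. A quick way to see the effect (a rough sketch; model loading and device selection are whatever preload_models sets up, and the printed timings will vary by hardware):

import time
from bark.generation import preload_models, generate_text_semantic

preload_models()
text = "Caching keys and values makes autoregressive decoding much faster."

for flag in (False, True):
    start = time.time()
    generate_text_semantic(text, use_kv_caching=flag)
    print(f"use_kv_caching={flag}: {time.time() - start:.1f}s")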
@@ -498,6 +505,7 @@ def generate_coarse(
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
     model=None,
+    use_kv_caching=False
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -592,11 +600,18 @@ def generate_coarse(
                     x_coarse_in[:, -max_coarse_history:],
                 ]
             )
+            kv_cache = None
             for _ in range(sliding_window_len):
                 if n_step >= n_steps:
                     continue
                 is_major_step = n_step % N_COARSE_CODEBOOKS == 0
-                logits = model(x_in)
+
+                if use_kv_caching and kv_cache is not None:
+                    x_input = x_in[:, [-1]]
+                else:
+                    x_input = x_in
+
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache)
                 logit_start_idx = (
                     SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                 )
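
Note that kv_cache is re-initialized to None just before the inner sliding_window_len loop rather than once per call: each pass of the outer loop rebuilds x_in from a re-padded slice of the semantic context, so keys/values cached against the previous window's input would no longer line up and have to be discarded. Within a window, though, every step after the first feeds only the newest token, exactly as in generate_text_semantic above.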

bark/model.py

@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))
 
-    def forward(self, x):
+    def forward(self, x, layer_past=None, use_cache=False):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
 
         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,14 +52,34 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
 
+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        FULL_T = k.shape[-2]
+
+        if use_cache is True:
+            present = (k, v)
+        else:
+            present = None
+
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            if layer_past is not None:
+                # in theory the attention is still causal but because we're computing it incrementally,
+                # the last query can attend on all previous keys/values, which is equivalent to non-causal
+                is_causal = False
+            else:
+                is_causal = True
+
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
             att = F.softmax(att, dim=-1)
             att = self.attn_dropout(att)
             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
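
The core of the change is visible here: cached keys/values from earlier positions are concatenated in front of the current step's k and v, and the causal mask is sliced at rows FULL_T-T..FULL_T so the new queries sit at the right offset. A self-contained sanity check (plain tensors, not the bark module) that single-step attention against a cache reproduces full causal attention for the last position:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, nh, T, hs = 1, 2, 5, 4
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)

# Full-sequence causal attention.
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Incremental step: keep k/v of the first T-1 positions as the "cache",
# then feed only the last position, concatenating as the patched forward does.
past_k, past_v = k[..., : T - 1, :], v[..., : T - 1, :]
k_step = torch.cat((past_k, k[..., -1:, :]), dim=-2)
v_step = torch.cat((past_v, v[..., -1:, :]), dim=-2)
# A single query attending over every cached position needs no causal mask,
# which is why the diff sets is_causal=False when layer_past is not None.
step = F.scaled_dot_product_attention(q[..., -1:, :], k_step, v_step, is_causal=False)

print(torch.allclose(full[..., -1:, :], step, atol=1e-6))  # True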
@@ -67,7 +87,7 @@ class CausalSelfAttention(nn.Module):
 
         # output projection
         y = self.resid_dropout(self.c_proj(y))
-        return y
+        return (y, present)
 
 class MLP(nn.Module):
 
@@ -95,10 +115,11 @@ class Block(nn.Module):
         self.mlp = MLP(config)
         self.layer_idx = layer_idx
 
-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, layer_past=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache)
+        x = x + attn_output
         x = x + self.mlp(self.ln_2(x))
-        return x
+        return (x, prev_kvs)
 
 @dataclass
 class GPTConfig:
@@ -142,9 +163,13 @@ class GPT(nn.Module):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params
 
-    def forward(self, idx, merge_context=False):
+    def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False):
         device = idx.device
         b, t = idx.size()
-        if merge_context:
-            assert(idx.shape[1] >= 256+256+1)
-            t = idx.shape[1] - 256
+        if past_key_values is not None:
+            assert t == 1
+            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        else:
+            if merge_context:
+                assert(idx.shape[1] >= 256+256+1)
+                t = idx.shape[1] - 256
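
When a cache is supplied, forward() now asserts that exactly one new token is coming in and skips the merge_context branch entirely: merging the text and semantic-history halves of the prompt only matters on the first, full-context pass, and once a cache exists that merged prefix is already represented by the stored keys/values.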
@@ -160,15 +185,33 @@ class GPT(nn.Module):
         else:
             tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
 
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.transformer.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0) # shape (1, t)
+            assert position_ids.shape == (1, t)
+
+        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
+
 
         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
+
+        presents = () if use_cache else None
+
+        for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
+            x, kv = block(x, layer_past=layer_past, use_cache=use_cache)
+
+            if use_cache:
+                presents = presents + (kv,)
+
         x = self.transformer.ln_f(x)
 
         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
 
-        return logits
+        return (logits, presents)
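
The position handling is the subtle part of this hunk: with a cache present only one token arrives per call, so its position embedding must be looked up at index past_length instead of 0. A small self-contained illustration of the bookkeeping, using the cache layout the diff implies (one (k, v) pair per block, with k and v shaped (B, n_head, past_len, head_size); the concrete sizes below are made up):

import torch

n_layer, B, n_head, past_len, head_size = 2, 1, 4, 7, 16
past_key_values = tuple(
    (torch.zeros(B, n_head, past_len, head_size),
     torch.zeros(B, n_head, past_len, head_size))
    for _ in range(n_layer)
)

t = 1  # forward() asserts exactly one new token whenever past_key_values is not None
past_length = past_key_values[0][0].size(-2)
position_ids = torch.arange(past_length, t + past_length, dtype=torch.long).unsqueeze(0)
print(past_length, position_ids)  # 7 tensor([[7]]) -> the new token is embedded at position 7

On the output side, forward() now returns (logits, presents), where presents is the tuple of per-block (k, v) pairs that the generation loops pass back in as past_key_values on the next step (or None when use_cache is False).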