mirror of https://github.com/serp-ai/bark-with-voice-clone.git (synced 2025-12-16 03:38:01 +01:00)

Rename variables and add comments
.gitignore (vendored)
@@ -1,2 +1 @@
 __pycache__/
-.venv
@@ -428,7 +428,7 @@ def generate_text_semantic(
             else:
                 x_input = x

-            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache)
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
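The renamed past_kv keyword is what lets this loop feed only the newest token once a cache exists (the x_input = x[:, [-1]] branch shown as context above). Below is a minimal sketch of that caller-side pattern, not the repository's actual sampling loop: toy_model, prompt_tokens, and the greedy argmax step are placeholder assumptions; only the (logits, past_kv) in/out contract mirrors the diff.

import torch

def decode_with_cache(toy_model, prompt_tokens, steps=8, use_kv_caching=True):
    # prompt_tokens: (1, T) LongTensor; toy_model(x, use_cache=..., past_kv=...)
    # is assumed to return (logits, past_kv), like the GPT.forward further below.
    x = prompt_tokens
    kv_cache = None
    for _ in range(steps):
        if use_kv_caching and kv_cache is not None:
            x_input = x[:, [-1]]   # with a warm cache, only the newest token is fed
        else:
            x_input = x            # first step (or caching disabled): full sequence
        logits, kv_cache = toy_model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
        next_token = logits[:, -1].argmax(-1, keepdim=True)  # greedy pick, for brevity
        x = torch.cat((x, next_token), dim=1)
    return x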
@@ -611,7 +611,7 @@ def generate_coarse(
                 else:
                     x_input = x_in

-                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache)
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                 logit_start_idx = (
                     SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                 )
@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
             self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                         .view(1, 1, config.block_size, config.block_size))

-    def forward(self, x, layer_past=None, use_cache=False):
+    def forward(self, x, past_kv=None, use_cache=False):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,9 +52,9 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

-        if layer_past is not None:
-            past_key = layer_past[0]
-            past_value = layer_past[1]
+        if past_kv is not None:
+            past_key = past_kv[0]
+            past_value = past_kv[1]
             k = torch.cat((past_key, k), dim=-2)
             v = torch.cat((past_value, v), dim=-2)

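The cached past_kv entry for a layer is just the (key, value) pair from earlier steps, and the two torch.cat calls above extend it along the time axis. A shape-only sketch with made-up sizes (batch 1, 4 heads, 16-dim heads, a 6-token cache):

import torch

B, nh, hs = 1, 4, 16
past_key = torch.randn(B, nh, 6, hs)    # keys cached from 6 earlier tokens
past_value = torch.randn(B, nh, 6, hs)
k = torch.randn(B, nh, 1, hs)           # key/value projected from the one new token
v = torch.randn(B, nh, 1, hs)

k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)
print(k.shape, v.shape)  # torch.Size([1, 4, 7, 16]) twice: the cache grew by one step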
@@ -68,9 +68,11 @@ class CausalSelfAttention(nn.Module):
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            if layer_past is not None:
-                # in theory the attention is still causal but because we're computing it incrementally,
-                # the last query can attend on all previous keys/values, which which is equivalent to non-causal
+            if past_kv is not None:
+                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
+                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
+                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
+                # to work around this we set is_causal=False.
                 is_causal = False
             else:
                 is_causal = True
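The added comment is the crux of the cached flash-attention path: with a single-token query, is_causal=True would treat that query as position 0 and mask it against the cached keys, so the flag is turned off. The snippet below is only a sanity check of that reasoning (random tensors, not repo code); it shows that an incremental step with is_causal=False reproduces the last row of a full causal pass.

import torch
import torch.nn.functional as F

B, nh, T, hs = 1, 2, 5, 8
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)

# Full-sequence causal attention: the last position attends to keys 0..T-1.
full = F.scaled_dot_product_attention(q, k, v, is_causal=True)

# Incremental step: only the last query, but all (cached) keys/values.
# is_causal must be False here, otherwise the lone query would be treated as
# the first position and masked against nearly every key.
step = F.scaled_dot_product_attention(q[:, :, [-1]], k, v, is_causal=False)

print(torch.allclose(full[:, :, [-1]], step, atol=1e-5))  # True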
@@ -115,8 +117,8 @@ class Block(nn.Module):
         self.mlp = MLP(config)
         self.layer_idx = layer_idx

-    def forward(self, x, layer_past=None, use_cache=False):
-        attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache)
+    def forward(self, x, past_kv=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
         x = x + attn_output
         x = x + self.mlp(self.ln_2(x))
         return (x, prev_kvs)
@@ -163,10 +165,10 @@ class GPT(nn.Module):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params

-    def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False):
+    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         b, t = idx.size()
-        if past_key_values is not None:
+        if past_kv is not None:
             assert t == 1
             tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
         else:
@@ -185,11 +187,11 @@ class GPT(nn.Module):
         else:
             tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

-        if past_key_values is None:
+        if past_kv is None:
             past_length = 0
-            past_key_values = tuple([None] * len(self.transformer.h))
+            past_kv = tuple([None] * len(self.transformer.h))
         else:
-            past_length = past_key_values[0][0].size(-2)
+            past_length = past_kv[0][0].size(-2)

         if position_ids is None:
             position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
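past_length is read straight off the cached key tensor's time dimension, which keeps the positional embeddings aligned when only a single new token is fed. A toy illustration with an assumed two-layer cache holding 7 decoded tokens (all sizes are made up):

import torch

n_layer, B, nh, hs = 2, 1, 4, 16
# One (key, value) pair per layer, each of shape (B, nh, T_past, hs).
past_kv = tuple(
    (torch.zeros(B, nh, 7, hs), torch.zeros(B, nh, 7, hs)) for _ in range(n_layer)
)

t = 1  # incremental decoding feeds one new token
past_length = past_kv[0][0].size(-2)                        # 7
position_ids = torch.arange(past_length, t + past_length)   # tensor([7])
print(past_length, position_ids)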
@@ -201,17 +203,17 @@ class GPT(nn.Module):

         x = self.transformer.drop(tok_emb + pos_emb)

-        presents = () if use_cache else None
+        new_kv = () if use_cache else None

-        for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
-            x, kv = block(x, layer_past=layer_past, use_cache=use_cache)
+        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
+            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)

             if use_cache:
-                presents = presents + (kv,)
+                new_kv = new_kv + (kv,)

         x = self.transformer.ln_f(x)

         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

-        return (logits, presents)
+        return (logits, new_kv)
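Two small notes on this last hunk. The list index [-1] in the lm_head call keeps the time dimension, so the returned logits always have shape (b, 1, vocab); and with use_cache=True the second element of the return value is now new_kv, a tuple with one (k, v) pair per block (it stays None when caching is off). A quick stand-alone check of the indexing detail, with made-up sizes:

import torch

x = torch.randn(2, 5, 8)        # (b, t, n_embd) after the final layer norm
print(x[:, -1, :].shape)        # torch.Size([2, 8])    -- plain -1 drops the time dim
print(x[:, [-1], :].shape)      # torch.Size([2, 1, 8]) -- list [-1] keeps it, as the comment notes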