Mirror of https://github.com/coqui-ai/TTS.git (synced 2026-02-24 20:19:54 +01:00)
Split forward and inference Bark GPT

This commit renames the original cache-aware `forward` to `inference` and adds a minimal `forward` that always processes the full sequence without a KV cache.
```diff
@@ -175,7 +175,39 @@ class GPT(nn.Module):
         n_params -= self.transformer.wpe.weight.numel()
         return n_params
 
-    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
+    def forward(self, idx):
+        device = idx.device
+        _, t = idx.size()
+        assert (
+            t <= self.config.block_size
+        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+        # forward the GPT model itself
+        tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
+
+        # no cache on this path: every layer sees an empty past
+        past_length = 0
+        past_kv = tuple([None] * len(self.transformer.h))
+
+        position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+        position_ids = position_ids.unsqueeze(0)  # shape (1, t)
+        assert position_ids.shape == (1, t)
+        pos_emb = self.transformer.wpe(position_ids)  # position embeddings of shape (1, t, n_embd)
+
+        x = self.transformer.drop(tok_emb + pos_emb)
+
+        for block, past_layer_kv in zip(self.transformer.h, past_kv):
+            x, _ = block(x, past_kv=past_layer_kv, use_cache=False)
+
+        x = self.transformer.ln_f(x)
+
+        # full-sequence logits; no last-position slicing on this path
+        logits = self.lm_head(x)
+
+        return logits
```
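With the KV-cache machinery stripped out, the new `forward` is a plain full-sequence call with no data-dependent branching, which is what you want when tracing or scripting the training path. A minimal usage sketch: the import path, the `GPTConfig` field names, and the values below are assumptions for illustration, not taken from the commit.

```python
import torch

# Path and field names assumed for illustration; GPT/GPTConfig stand in for
# the classes defined in this file.
from TTS.tts.layers.bark.model import GPT, GPTConfig

config = GPTConfig(block_size=1024, input_vocab_size=10_048, output_vocab_size=10_048,
                   n_layer=12, n_head=12, n_embd=768)
model = GPT(config).eval()

idx = torch.randint(0, config.input_vocab_size, (1, 32))  # (batch, time) token ids
with torch.no_grad():
    logits = model(idx)  # the new forward(): no past_kv, no use_cache

print(logits.shape)  # (1, 32, output_vocab_size): logits for every position
```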
The remainder of the hunk renames the old signature to `inference`, keeping the cache-related parameters; the capture cuts off inside its body:

```diff
+    def inference(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
         device = idx.device
         _, t = idx.size()
         if past_kv is not None:
```
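The split matters because the cached path has data-dependent control flow: when `past_kv` is present, only the newest token is embedded and its position id is offset by the cached length. A minimal sketch of that pattern, assuming the usual GPT KV-cache convention (each layer's past is a (key, value) pair whose size(-2) dimension is time); this illustrates the technique and is not the body of the commit's `inference`:

```python
import torch

def cached_step(model, idx, past_kv=None):
    """One decode step with an optional KV cache (illustrative, not the commit's code)."""
    if past_kv is not None:
        # Earlier positions are cached, so embed only the new token and
        # offset its position id by the cached sequence length.
        idx = idx[:, [-1]]                    # keep the time dim: shape (b, 1)
        past_length = past_kv[0][0].size(-2)  # time length of layer 0's cached keys
    else:
        past_length = 0
        past_kv = tuple([None] * len(model.transformer.h))

    b, t = idx.size()
    device = idx.device
    position_ids = torch.arange(past_length, past_length + t, dtype=torch.long, device=device)

    tok_emb = model.transformer.wte(idx)
    pos_emb = model.transformer.wpe(position_ids.unsqueeze(0))
    x = model.transformer.drop(tok_emb + pos_emb)

    new_kv = []
    for block, past_layer_kv in zip(model.transformer.h, past_kv):
        x, kv = block(x, past_kv=past_layer_kv, use_cache=True)
        new_kv.append(kv)

    x = model.transformer.ln_f(x)
    logits = model.lm_head(x[:, [-1], :])  # only the last position is needed to sample
    return logits, tuple(new_kv)
```

Keeping this branching in a separate `inference` method leaves `forward` with a static shape and graph, while autoregressive decoding still pays only one token's worth of embedding and attention-query work per step.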