Split forward and inference Bark GPT

This commit is contained in:
Eren Gölge
2023-08-07 15:01:14 +02:00
parent 5e0686c5f1
commit ba22d09510

View File

@@ -175,7 +175,39 @@ class GPT(nn.Module):
n_params -= self.transformer.wpe.weight.numel()
return n_params
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
def forward(self, idx):
    """Run a full, cache-free pass of the GPT over ``idx``.

    Unlike ``inference``, this path takes no past key/values and no custom
    position ids: positions always start at 0 and logits are produced for
    every time step of the input.

    Args:
        idx: 2-D tensor of token indices, ``(b, t)``, with
            ``t <= self.config.block_size``.

    Returns:
        The output of ``self.lm_head`` applied to the final hidden states
        of all ``t`` positions.

    Raises:
        AssertionError: if the sequence is longer than the block size.
    """
    device = idx.device
    _, t = idx.size()
    assert (
        t <= self.config.block_size
    ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

    # token embeddings of shape (b, t, n_embd)
    tok_emb = self.transformer.wte(idx)

    # No cache here, so positions are simply 0..t-1, shared across the batch.
    position_ids = torch.arange(t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
    pos_emb = self.transformer.wpe(position_ids)  # position embeddings of shape (1, t, n_embd)

    x = self.transformer.drop(tok_emb + pos_emb)
    for block in self.transformer.h:
        # Blocks return (hidden, kv); the kv entry is discarded because
        # caching is disabled on this path.
        x, _ = block(x, past_kv=None, use_cache=False)
    x = self.transformer.ln_f(x)

    # Project every position (not only the last one) through the LM head.
    logits = self.lm_head(x)
    return logits
def inference(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
device = idx.device
_, t = idx.size()
if past_kv is not None: