Add k/v caching for autoregressive generation

commit 15606ed12f
parent 874af1bae9
Author: Zygimantas Straznickas
Date:   2023-04-20 18:39:14 -07:00
4 changed files with 94 additions and 30 deletions

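Background for reviewers: without a cache, every sampled token re-runs attention over the entire prefix, so step t costs O(t^2) work in the attention layers. With a per-layer key/value cache, each step feeds in a single new token and concatenates its fresh k/v onto the cached tensors, dropping the per-step attention cost to O(t). Below is a minimal standalone sketch of that core idea; the shapes and names are illustrative and not this repo's API.

import torch

# Keys/values accumulate along the time dimension (dim=-2); each step
# computes attention only for the single newest query against the cache.
B, nh, hs = 1, 2, 8                   # batch, heads, head size (illustrative)
cache_k = torch.zeros(B, nh, 0, hs)   # empty cache, grows by one step at a time
cache_v = torch.zeros(B, nh, 0, hs)

for step in range(3):
    q     = torch.randn(B, nh, 1, hs)  # one new query per step
    k_new = torch.randn(B, nh, 1, hs)
    v_new = torch.randn(B, nh, 1, hs)
    cache_k = torch.cat((cache_k, k_new), dim=-2)  # same concat the diff performs
    cache_v = torch.cat((cache_v, v_new), dim=-2)
    att = (q @ cache_k.transpose(-2, -1)) / (hs ** 0.5)
    y = att.softmax(dim=-1) @ cache_v              # (B, nh, 1, hs)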

@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
         self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                     .view(1, 1, config.block_size, config.block_size))

-    def forward(self, x):
+    def forward(self, x, layer_past=None, use_cache=False):
         B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

         # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,14 +52,34 @@ class CausalSelfAttention(nn.Module):
         q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
         v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

+        if layer_past is not None:
+            past_key = layer_past[0]
+            past_value = layer_past[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        FULL_T = k.shape[-2]
+
+        if use_cache is True:
+            present = (k, v)
+        else:
+            present = None
+
         # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
         if self.flash:
             # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            if layer_past is not None:
+                # in theory the attention is still causal, but because we're computing it
+                # incrementally, the last query can attend to all previous keys/values,
+                # which is equivalent to non-causal attention
+                is_causal = False
+            else:
+                is_causal = True
+
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
         else:
             # manual implementation of attention
             att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
             att = F.softmax(att, dim=-1)
             att = self.attn_dropout(att)
             y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
@@ -67,7 +87,7 @@ class CausalSelfAttention(nn.Module):

         # output projection
         y = self.resid_dropout(self.c_proj(y))
-        return y
+        return (y, present)

 class MLP(nn.Module):
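Reviewer note on the two attention paths above. In the flash path, when a cache is present the incoming T is 1, and that single query is allowed to attend to all FULL_T cached keys; the last row of a causal mask is all ones, so is_causal=False computes the same result. In the slow path, the slice self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] selects the mask rows for the absolute positions FULL_T-T..FULL_T-1 that the new queries occupy, which for T == 1 is a no-op row of ones. A small sketch checking the flash-path equivalence numerically (shapes are illustrative, not repo code):

import torch
import torch.nn.functional as F

# Full causal attention vs. an incremental step over cached k/v:
# the last query of the causal pass equals an unmasked pass of that
# single query over all keys, which is what is_causal=False computes.
B, nh, T, hs = 1, 2, 5, 8
q = torch.randn(B, nh, T, hs)
k = torch.randn(B, nh, T, hs)
v = torch.randn(B, nh, T, hs)

full = F.scaled_dot_product_attention(q, k, v, is_causal=True)
step = F.scaled_dot_product_attention(q[:, :, -1:, :], k, v, is_causal=False)

assert torch.allclose(full[:, :, -1:, :], step, atol=1e-6)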
@@ -95,10 +115,11 @@ class Block(nn.Module):
         self.mlp = MLP(config)
         self.layer_idx = layer_idx

-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, layer_past=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache)
+        x = x + attn_output
         x = x + self.mlp(self.ln_2(x))
-        return x
+        return (x, prev_kvs)

 @dataclass
 class GPTConfig:
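For orientation before the GPT.forward hunk below: each Block now hands back its layer's (k, v) pair, and GPT.forward gathers them into a tuple with one entry per layer, each entry holding tensors of shape (B, n_head, T_so_far, head_size). The model then recovers the number of already-processed positions from the key tensor's time dimension, which is why the next token's position id must start at past_length rather than zero. A hypothetical shape walk-through (stand-in values, not repo code):

import torch

# past_key_values: one (k, v) pair per transformer layer.
n_layer, B, n_head, head_size, t_so_far = 4, 1, 2, 8, 10

def fake_present(t):
    # stand-in for one layer's cached (k, v) tensors
    return (torch.zeros(B, n_head, t, head_size),
            torch.zeros(B, n_head, t, head_size))

past_key_values = tuple(fake_present(t_so_far) for _ in range(n_layer))

# how GPT.forward recovers the past length, and where the next
# token's position embedding index comes from
past_length = past_key_values[0][0].size(-2)
position_ids = torch.arange(past_length, past_length + 1)  # for t == 1
assert past_length == 10 and position_ids.item() == 10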
@@ -142,33 +163,55 @@ class GPT(nn.Module):
             n_params -= self.transformer.wpe.weight.numel()
         return n_params

-    def forward(self, idx, merge_context=False):
+    def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False):
         device = idx.device
         b, t = idx.size()
-        if merge_context:
-            assert(idx.shape[1] >= 256+256+1)
-            t = idx.shape[1] - 256
-        else:
-            assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

-        # forward the GPT model itself
-        if merge_context:
-            tok_emb = torch.cat([
-                self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
-                self.transformer.wte(idx[:,256+256:])
-            ], dim=1)
-        else:
-            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        if past_key_values is not None:
+            assert t == 1
+            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        else:
+            if merge_context:
+                assert(idx.shape[1] >= 256+256+1)
+                t = idx.shape[1] - 256
+            else:
+                assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
+
+            # forward the GPT model itself
+            if merge_context:
+                tok_emb = torch.cat([
+                    self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
+                    self.transformer.wte(idx[:,256+256:])
+                ], dim=1)
+            else:
+                tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+        if past_key_values is None:
+            past_length = 0
+            past_key_values = tuple([None] * len(self.transformer.h))
+        else:
+            past_length = past_key_values[0][0].size(-2)
+
+        if position_ids is None:
+            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0) # shape (1, t)
+            assert position_ids.shape == (1, t)
+
+        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

         x = self.transformer.drop(tok_emb + pos_emb)
-        for block in self.transformer.h:
-            x = block(x)
+
+        presents = () if use_cache else None
+        for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)):
+            x, kv = block(x, layer_past=layer_past, use_cache=use_cache)
+
+            if use_cache:
+                presents = presents + (kv,)
+
         x = self.transformer.ln_f(x)

         # inference-time mini-optimization: only forward the lm_head on the very last position
         logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

-        return logits
+        return (logits, presents)
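
Finally, a sketch of how a sampling loop might drive the new interface: prime the cache once on the full prompt, then feed back one token at a time. Here model, prompt_ids, and the greedy decoding choice are assumptions for illustration, not code from this commit.

import torch

model.eval()
with torch.no_grad():
    idx = prompt_ids                            # (1, t) LongTensor, assumed given
    logits, past = model(idx, use_cache=True)   # full prompt primes the cache
    for _ in range(64):
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # greedy, for illustration
        idx = torch.cat((idx, next_id), dim=1)
        # only the newest token goes in; position_ids default to past_length
        logits, past = model(next_id, past_key_values=past, use_cache=True)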