Add v1 finetune support

Francis LaBounty
2023-06-29 21:48:18 -06:00
parent 572de4b707
commit 50927298a0
10 changed files with 3243 additions and 45 deletions


@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat, reduce

SEMANTIC_PAD_TOKEN = 10_000


class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -167,7 +167,7 @@ class GPT(nn.Module):
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, training=False):
        device = idx.device
        b, t = idx.size()
        if past_kv is not None:
@@ -215,19 +215,9 @@ class GPT(nn.Module):
        x = self.transformer.ln_f(x)

        if labels is not None:
        if training:
            logits = self.lm_head(x)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)
            return logits, loss
            return logits

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
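With the in-model loss path removed, forward(..., training=True) returns full-sequence logits and the loss is presumably computed in the fine-tuning loop itself. A minimal sketch of such a training step, mirroring the removed shift-and-flatten cross-entropy; the training_step wrapper, the optimizer handling, and the use of SEMANTIC_PAD_TOKEN as ignore_index are illustrative assumptions, not taken from this diff:

import torch.nn.functional as F

def training_step(model, idx, labels, optimizer, pad_token=10_000):
    # training=True makes the model return logits for every position
    # instead of only the last one (see the forward() change above).
    logits = model(idx, training=True)

    # Shift so that tokens < n predict token n, as the removed in-model code did.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Assumption: padded label positions carry SEMANTIC_PAD_TOKEN and are ignored.
    loss = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=pad_token,
    )

    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()
    return loss.item()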