Add v1 finetune support

This commit is contained in:
Francis LaBounty
2023-06-29 21:48:18 -06:00
parent 572de4b707
commit 50927298a0
10 changed files with 3243 additions and 45 deletions

View File

@@ -3,6 +3,7 @@ import gc
import hashlib
import os
import re
import json
from encodec import EncodecModel
import funcy
@@ -203,42 +204,81 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if (
os.path.exists(ckpt_path) and
_md5(ckpt_path) != model_info["checksum"]
):
logger.warning(f"found outdated {model_type} model, removing.")
os.remove(ckpt_path)
# if (
# os.path.exists(ckpt_path) and
# _md5(ckpt_path) != model_info["checksum"]
# ):
# logger.warning(f"found outdated {model_type} model, removing.")
# os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]
# check if config.json is in the same directory as the checkpoint
# if so, load it
# otherwise, assume it's in the checkpoint
config_path = os.path.join(os.path.dirname(ckpt_path), "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
model_args = json.load(f)
else:
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
gptconf = ConfigClass(**checkpoint["model_args"])
gptconf = ConfigClass(**model_args)
model = ModelClass(gptconf)
state_dict = checkpoint["model"]
if checkpoint.get("model", None) is not None:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
# fixup checkpoint
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
unwanted_suffixes = [
"lora_right_weight",
"lora_left_weight",
"lora_right_bias",
"lora_left_bias",
]
for k, v in list(state_dict.items()):
for suffix in unwanted_suffixes:
if k.endswith(suffix):
state_dict.pop(k)
# super hacky - should probably refactor this
if state_dict.get('lm_head.0.weight', None) is not None:
state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight')
if state_dict.get('lm_heads.0.0.weight', None) is not None:
state_dict['lm_heads.0.weight'] = state_dict.pop('lm_heads.0.0.weight')
if state_dict.get('lm_heads.1.0.weight', None) is not None:
state_dict['lm_heads.1.weight'] = state_dict.pop('lm_heads.1.0.weight')
if state_dict.get('lm_heads.2.0.weight', None) is not None:
state_dict['lm_heads.2.weight'] = state_dict.pop('lm_heads.2.0.weight')
if state_dict.get('lm_heads.3.0.weight', None) is not None:
state_dict['lm_heads.3.weight'] = state_dict.pop('lm_heads.3.0.weight')
if state_dict.get('lm_heads.4.0.weight', None) is not None:
state_dict['lm_heads.4.weight'] = state_dict.pop('lm_heads.4.0.weight')
if state_dict.get('lm_heads.5.0.weight', None) is not None:
state_dict['lm_heads.5.weight'] = state_dict.pop('lm_heads.5.0.weight')
if state_dict.get('lm_heads.6.0.weight', None) is not None:
state_dict['lm_heads.6.weight'] = state_dict.pop('lm_heads.6.0.weight')
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
print(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
if checkpoint.get("best_val_loss", None) is not None:
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
@@ -273,8 +313,11 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
models_devices[model_key] = device
device = "cpu"
if model_key not in models or force_reload:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
clean_models(model_key=model_key)
if path.endswith(".ckpt") or path.endswith(".pt") or path.endswith(".bin"):
ckpt_path = path
else:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
# clean_models(model_key=model_key)
model = _load_model_f(ckpt_path, device)
models[model_key] = model
if model_type == "text":
@@ -306,10 +349,13 @@ def load_codec_model(use_gpu=True, force_reload=False):
def preload_models(
text_use_gpu=True,
text_use_small=False,
text_model_path=None,
coarse_use_gpu=True,
coarse_use_small=False,
coarse_model_path=None,
fine_use_gpu=True,
fine_use_small=False,
fine_model_path=None,
codec_use_gpu=True,
force_reload=False,
path=None,
@@ -320,17 +366,17 @@ def preload_models(
):
logger.warning("No GPU being used. Careful, inference might be very slow!")
_ = load_model(
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path if text_model_path is None else text_model_path
)
_ = load_model(
model_type="coarse",
use_gpu=coarse_use_gpu,
use_small=coarse_use_small,
force_reload=force_reload,
path=path,
path=path if coarse_model_path is None else coarse_model_path,
)
_ = load_model(
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path if fine_model_path is None else fine_model_path
)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)

View File

@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat, reduce
SEMANTIC_PAD_TOKEN = 10_000
class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -167,7 +167,7 @@ class GPT(nn.Module):
n_params -= self.transformer.wpe.weight.numel()
return n_params
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, training=False):
device = idx.device
b, t = idx.size()
if past_kv is not None:
@@ -215,19 +215,9 @@ class GPT(nn.Module):
x = self.transformer.ln_f(x)
if labels is not None:
if training:
logits = self.lm_head(x)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return logits, loss
return logits
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim