Add v1 finetune support

Francis LaBounty
2023-06-29 21:48:18 -06:00
parent 572de4b707
commit 50927298a0
10 changed files with 3243 additions and 45 deletions

@@ -3,6 +3,7 @@ import gc
import hashlib
import os
import re
+import json
from encodec import EncodecModel
import funcy
@@ -203,42 +204,81 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
-if (
-os.path.exists(ckpt_path) and
-_md5(ckpt_path) != model_info["checksum"]
-):
-logger.warning(f"found outdated {model_type} model, removing.")
-os.remove(ckpt_path)
+# if (
+# os.path.exists(ckpt_path) and
+# _md5(ckpt_path) != model_info["checksum"]
+# ):
+# logger.warning(f"found outdated {model_type} model, removing.")
+# os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
-# this is a hack
-model_args = checkpoint["model_args"]
+# check if config.json is in the same directory as the checkpoint
+# if so, load it
+# otherwise, assume it's in the checkpoint
+config_path = os.path.join(os.path.dirname(ckpt_path), "config.json")
+if os.path.exists(config_path):
+with open(config_path, "r") as f:
+model_args = json.load(f)
+else:
+model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
-gptconf = ConfigClass(**checkpoint["model_args"])
+gptconf = ConfigClass(**model_args)
model = ModelClass(gptconf)
-state_dict = checkpoint["model"]
+if checkpoint.get("model", None) is not None:
+state_dict = checkpoint["model"]
+else:
+state_dict = checkpoint
# fixup checkpoint
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+unwanted_suffixes = [
+"lora_right_weight",
+"lora_left_weight",
+"lora_right_bias",
+"lora_left_bias",
+]
+for k, v in list(state_dict.items()):
+for suffix in unwanted_suffixes:
+if k.endswith(suffix):
+state_dict.pop(k)
+# super hacky - should probably refactor this
+if state_dict.get('lm_head.0.weight', None) is not None:
+state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight')
+if state_dict.get('lm_heads.0.0.weight', None) is not None:
+state_dict['lm_heads.0.weight'] = state_dict.pop('lm_heads.0.0.weight')
+if state_dict.get('lm_heads.1.0.weight', None) is not None:
+state_dict['lm_heads.1.weight'] = state_dict.pop('lm_heads.1.0.weight')
+if state_dict.get('lm_heads.2.0.weight', None) is not None:
+state_dict['lm_heads.2.weight'] = state_dict.pop('lm_heads.2.0.weight')
+if state_dict.get('lm_heads.3.0.weight', None) is not None:
+state_dict['lm_heads.3.weight'] = state_dict.pop('lm_heads.3.0.weight')
+if state_dict.get('lm_heads.4.0.weight', None) is not None:
+state_dict['lm_heads.4.weight'] = state_dict.pop('lm_heads.4.0.weight')
+if state_dict.get('lm_heads.5.0.weight', None) is not None:
+state_dict['lm_heads.5.weight'] = state_dict.pop('lm_heads.5.0.weight')
+if state_dict.get('lm_heads.6.0.weight', None) is not None:
+state_dict['lm_heads.6.weight'] = state_dict.pop('lm_heads.6.0.weight')
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
print(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
-val_loss = checkpoint["best_val_loss"].item()
-logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
+if checkpoint.get("best_val_loss", None) is not None:
+val_loss = checkpoint["best_val_loss"].item()
+logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
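
The loader above now checks for a config.json next to the checkpoint and, if present, uses it as the model arguments instead of checkpoint["model_args"]. A minimal sketch of producing such a file; only input_vocab_size and output_vocab_size appear in this diff, so the remaining field names and all values here are placeholders/assumptions:

import json

# hypothetical model args for a finetuned checkpoint (values are placeholders)
model_args = {
    "n_layer": 12,                # assumed field name and value
    "n_head": 12,                 # assumed
    "n_embd": 768,                # assumed
    "block_size": 1024,           # assumed
    "dropout": 0.0,               # assumed
    "bias": False,                # assumed
    "input_vocab_size": 10048,    # key read by the loader above; value is a placeholder
    "output_vocab_size": 10048,   # key read by the loader above; value is a placeholder
}

# write config.json into the same directory as the checkpoint so _load_model picks it up
with open("checkpoints/config.json", "w") as f:  # hypothetical path
    json.dump(model_args, f, indent=2)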
@@ -273,8 +313,11 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
models_devices[model_key] = device
device = "cpu"
if model_key not in models or force_reload:
-ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
-clean_models(model_key=model_key)
+if path.endswith(".ckpt") or path.endswith(".pt") or path.endswith(".bin"):
+ckpt_path = path
+else:
+ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
+# clean_models(model_key=model_key)
model = _load_model_f(ckpt_path, device)
models[model_key] = model
if model_type == "text":
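
With the change above, load_model treats a path ending in .ckpt, .pt, or .bin as a checkpoint file to load directly, instead of a directory handed to _get_ckpt_path. A short usage sketch; the import path and checkpoint location are assumptions, not confirmed by this diff:

from bark.generation import load_model  # import path assumed from this repo's layout

# point `path` straight at a finetuned checkpoint file (hypothetical location)
_ = load_model(
    model_type="text",
    use_gpu=True,
    use_small=False,
    force_reload=True,
    path="checkpoints/text_finetuned.pt",
)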
@@ -306,10 +349,13 @@ def load_codec_model(use_gpu=True, force_reload=False):
def preload_models(
text_use_gpu=True,
text_use_small=False,
+text_model_path=None,
coarse_use_gpu=True,
coarse_use_small=False,
+coarse_model_path=None,
fine_use_gpu=True,
fine_use_small=False,
+fine_model_path=None,
codec_use_gpu=True,
force_reload=False,
path=None,
@@ -320,17 +366,17 @@ def preload_models(
):
logger.warning("No GPU being used. Careful, inference might be very slow!")
_ = load_model(
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path if text_model_path is None else text_model_path
)
_ = load_model(
model_type="coarse",
use_gpu=coarse_use_gpu,
use_small=coarse_use_small,
force_reload=force_reload,
-path=path,
+path=path if coarse_model_path is None else coarse_model_path,
)
_ = load_model(
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path if fine_model_path is None else fine_model_path
)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
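
preload_models now accepts per-model checkpoint overrides (text_model_path, coarse_model_path, fine_model_path), each falling back to the shared path argument when left as None. A minimal usage sketch; the import path and checkpoint locations are assumptions:

from bark import preload_models  # import path assumed

# load a finetuned text model while keeping the default coarse and fine models
preload_models(
    text_use_gpu=True,
    text_model_path="checkpoints/text_finetuned.pt",  # hypothetical path
    coarse_use_gpu=True,
    coarse_model_path=None,  # None falls back to the shared `path`
    fine_use_gpu=True,
    fine_model_path=None,
    codec_use_gpu=True,
    force_reload=True,
)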