Mirror of https://github.com/serp-ai/bark-with-voice-clone.git (synced 2025-12-16 03:38:01 +01:00)
Add v1 finetune support
@@ -3,6 +3,7 @@ import gc
 import hashlib
 import os
 import re
+import json

 from encodec import EncodecModel
 import funcy
@@ -203,42 +204,81 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
         raise NotImplementedError()
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
     model_info = REMOTE_MODEL_PATHS[model_key]
-    if (
-        os.path.exists(ckpt_path) and
-        _md5(ckpt_path) != model_info["checksum"]
-    ):
-        logger.warning(f"found outdated {model_type} model, removing.")
-        os.remove(ckpt_path)
+    # if (
+    #     os.path.exists(ckpt_path) and
+    #     _md5(ckpt_path) != model_info["checksum"]
+    # ):
+    #     logger.warning(f"found outdated {model_type} model, removing.")
+    #     os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
         _download(model_info["repo_id"], model_info["file_name"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
-    model_args = checkpoint["model_args"]
+    # check if config.json is in the same directory as the checkpoint
+    # if so, load it
+    # otherwise, assume it's in the checkpoint
+    config_path = os.path.join(os.path.dirname(ckpt_path), "config.json")
+    if os.path.exists(config_path):
+        with open(config_path, "r") as f:
+            model_args = json.load(f)
+    else:
+        model_args = checkpoint["model_args"]
     if "input_vocab_size" not in model_args:
         model_args["input_vocab_size"] = model_args["vocab_size"]
         model_args["output_vocab_size"] = model_args["vocab_size"]
         del model_args["vocab_size"]
-    gptconf = ConfigClass(**checkpoint["model_args"])
+    gptconf = ConfigClass(**model_args)
     model = ModelClass(gptconf)
-    state_dict = checkpoint["model"]
+    if checkpoint.get("model", None) is not None:
+        state_dict = checkpoint["model"]
+    else:
+        state_dict = checkpoint
     # fixup checkpoint
     unwanted_prefix = "_orig_mod."
     for k, v in list(state_dict.items()):
         if k.startswith(unwanted_prefix):
             state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
+    unwanted_suffixes = [
+        "lora_right_weight",
+        "lora_left_weight",
+        "lora_right_bias",
+        "lora_left_bias",
+    ]
+    for k, v in list(state_dict.items()):
+        for suffix in unwanted_suffixes:
+            if k.endswith(suffix):
+                state_dict.pop(k)
+    # super hacky - should probably refactor this
+    if state_dict.get('lm_head.0.weight', None) is not None:
+        state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight')
+    if state_dict.get('lm_heads.0.0.weight', None) is not None:
+        state_dict['lm_heads.0.weight'] = state_dict.pop('lm_heads.0.0.weight')
+    if state_dict.get('lm_heads.1.0.weight', None) is not None:
+        state_dict['lm_heads.1.weight'] = state_dict.pop('lm_heads.1.0.weight')
+    if state_dict.get('lm_heads.2.0.weight', None) is not None:
+        state_dict['lm_heads.2.weight'] = state_dict.pop('lm_heads.2.0.weight')
+    if state_dict.get('lm_heads.3.0.weight', None) is not None:
+        state_dict['lm_heads.3.weight'] = state_dict.pop('lm_heads.3.0.weight')
+    if state_dict.get('lm_heads.4.0.weight', None) is not None:
+        state_dict['lm_heads.4.weight'] = state_dict.pop('lm_heads.4.0.weight')
+    if state_dict.get('lm_heads.5.0.weight', None) is not None:
+        state_dict['lm_heads.5.weight'] = state_dict.pop('lm_heads.5.0.weight')
+    if state_dict.get('lm_heads.6.0.weight', None) is not None:
+        state_dict['lm_heads.6.weight'] = state_dict.pop('lm_heads.6.0.weight')
     extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
     extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
     missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
     missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
     if len(extra_keys) != 0:
-        raise ValueError(f"extra keys found: {extra_keys}")
+        print(f"extra keys found: {extra_keys}")
     if len(missing_keys) != 0:
         raise ValueError(f"missing keys: {missing_keys}")
     model.load_state_dict(state_dict, strict=False)
     n_params = model.get_num_params()
-    val_loss = checkpoint["best_val_loss"].item()
-    logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
+    if checkpoint.get("best_val_loss", None) is not None:
+        val_loss = checkpoint["best_val_loss"].item()
+        logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
     model.eval()
     model.to(device)
     del checkpoint, state_dict
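The patched loader is more permissive about checkpoint layout: model_args may come from a config.json next to the weights, the file may be a full training checkpoint or a bare state_dict, merged LoRA bookkeeping tensors are dropped, and wrapped lm_head / lm_heads.N keys are renamed back to the stock layout. As a rough illustration, a fine-tuned text model could be saved in a shape this loader accepts as follows; the helper name and output path are made up for the example, and only the "model", "model_args", and "best_val_loss" keys are what the diff actually reads.

import torch

def save_finetuned_checkpoint(model, model_args, best_val_loss=None,
                              out_path="text_finetuned.pt"):
    # Hypothetical helper, not part of this commit. The loader above reads
    # "model" and "model_args" (or a config.json beside the file) and only
    # logs the loss line when "best_val_loss" is present.
    ckpt = {"model": model.state_dict(), "model_args": model_args}
    if best_val_loss is not None:
        ckpt["best_val_loss"] = torch.tensor(best_val_loss)
    torch.save(ckpt, out_path)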
@@ -273,8 +313,11 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
         models_devices[model_key] = device
         device = "cpu"
     if model_key not in models or force_reload:
-        ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
-        clean_models(model_key=model_key)
+        if path.endswith(".ckpt") or path.endswith(".pt") or path.endswith(".bin"):
+            ckpt_path = path
+        else:
+            ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
+        # clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
         if model_type == "text":
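A hedged usage sketch of the new direct-path behavior: any path ending in .ckpt, .pt, or .bin is now used verbatim instead of being resolved through _get_ckpt_path. The checkpoint file name below is a placeholder.

from bark.generation import load_model

# Placeholder checkpoint path; force_reload replaces any cached model.
_ = load_model(
    model_type="text",
    use_gpu=True,
    use_small=False,
    force_reload=True,
    path="models/text_finetuned.pt",
)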
@@ -306,10 +349,13 @@ def load_codec_model(use_gpu=True, force_reload=False):
 def preload_models(
     text_use_gpu=True,
     text_use_small=False,
+    text_model_path=None,
     coarse_use_gpu=True,
     coarse_use_small=False,
+    coarse_model_path=None,
     fine_use_gpu=True,
     fine_use_small=False,
+    fine_model_path=None,
     codec_use_gpu=True,
     force_reload=False,
     path=None,
@@ -320,17 +366,17 @@ def preload_models(
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
-        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path if text_model_path is None else text_model_path
     )
     _ = load_model(
         model_type="coarse",
         use_gpu=coarse_use_gpu,
         use_small=coarse_use_small,
         force_reload=force_reload,
-        path=path,
+        path=path if coarse_model_path is None else coarse_model_path,
     )
     _ = load_model(
-        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path if fine_model_path is None else fine_model_path
     )
     _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
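With the per-model overrides in place, preload_models can mix default and fine-tuned weights; a None override falls back to the shared path argument. A minimal sketch, with placeholder file names:

from bark.generation import preload_models

# All three checkpoint paths are placeholders for locally fine-tuned weights.
preload_models(
    text_use_gpu=True,
    text_model_path="models/text_finetuned.pt",
    coarse_use_gpu=True,
    coarse_model_path="models/coarse_finetuned.pt",
    fine_use_gpu=True,
    fine_model_path="models/fine_finetuned.pt",
    codec_use_gpu=True,
    force_reload=True,
)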