Add ability to use custom paths

This commit is contained in:
Francis LaBounty
2023-04-29 16:53:09 -06:00
parent 1818bc88da
commit 45165b7ad7
5 changed files with 73 additions and 22 deletions

View File

@@ -124,10 +124,12 @@ def _md5(fname):
return hash_md5.hexdigest()
def _get_ckpt_path(model_type, use_small=False):
def _get_ckpt_path(model_type, use_small=False, path=None):
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
return os.path.join(CACHE_DIR, f"{model_name}.pt")
if path is None:
path = CACHE_DIR
return os.path.join(path, f"{model_name}.pt")
def _grab_best_device(use_gpu=True):
@@ -141,10 +143,11 @@ def _grab_best_device(use_gpu=True):
def _download(from_hf_path, file_name, to_local_path):
os.makedirs(CACHE_DIR, exist_ok=True)
destination_file_name = to_local_path.split("/")[-1]
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
to_local_path = to_local_path.replace("\\", "/")
path = '/'.join(to_local_path.split("/")[:-1])
os.makedirs(path, exist_ok=True)
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)
os.replace(os.path.join(path, file_name), to_local_path)
class InferenceContext:
def __init__(self, benchmark=False):
@@ -250,8 +253,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
return model
def _load_codec_model(device):
model = EncodecModel.encodec_model_24khz()
def _load_codec_model(device, path=None):
model = EncodecModel.encodec_model_24khz(repository=path)
model.set_target_bandwidth(6.0)
model.eval()
model.to(device)
@@ -259,7 +262,7 @@ def _load_codec_model(device):
return model
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text", path=None):
_load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
@@ -271,7 +274,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
models_devices[model_key] = device
device = "cpu"
if model_key not in models or force_reload:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
clean_models(model_key=model_key)
model = _load_model_f(ckpt_path, device)
models[model_key] = model
@@ -282,7 +285,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
return models[model_key]
def load_codec_model(use_gpu=True, force_reload=False):
def load_codec_model(use_gpu=True, force_reload=False, path=None):
global models
global models_devices
device = _grab_best_device(use_gpu=use_gpu)
@@ -295,7 +298,7 @@ def load_codec_model(use_gpu=True, force_reload=False):
device = "cpu"
if model_key not in models or force_reload:
clean_models(model_key=model_key)
model = _load_codec_model(device)
model = _load_codec_model(device, path=path)
models[model_key] = model
models[model_key].to(device)
return models[model_key]
@@ -310,6 +313,7 @@ def preload_models(
fine_use_small=False,
codec_use_gpu=True,
force_reload=False,
path=None,
):
"""Load all the necessary models for the pipeline."""
if _grab_best_device() == "cpu" and (
@@ -317,18 +321,19 @@ def preload_models(
):
logger.warning("No GPU being used. Careful, inference might be very slow!")
_ = load_model(
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
)
_ = load_model(
model_type="coarse",
use_gpu=coarse_use_gpu,
use_small=coarse_use_small,
force_reload=force_reload,
path=path,
)
_ = load_model(
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path)
####