Mirror of https://github.com/serp-ai/bark-with-voice-clone.git, synced 2025-12-16 03:38:01 +01:00
Add ability to use custom paths
This commit threads an optional path argument through the model-loading helpers so that checkpoints can be stored outside the default CACHE_DIR.
@@ -124,10 +124,12 @@ def _md5(fname):
     return hash_md5.hexdigest()


-def _get_ckpt_path(model_type, use_small=False):
+def _get_ckpt_path(model_type, use_small=False, path=None):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
     model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
-    return os.path.join(CACHE_DIR, f"{model_name}.pt")
+    if path is None:
+        path = CACHE_DIR
+    return os.path.join(path, f"{model_name}.pt")


 def _grab_best_device(use_gpu=True):
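The helper now falls back to CACHE_DIR only when no directory is given. A minimal sketch of the resulting behavior, assuming upstream's bark.generation module layout ("./my_ckpts" is a hypothetical directory):

    # Sketch only: file names are still md5-derived; only the base directory
    # changes. _get_ckpt_path is a private helper, imported here for illustration.
    from bark.generation import _get_ckpt_path

    print(_get_ckpt_path("text"))                     # <CACHE_DIR>/<md5>.pt
    print(_get_ckpt_path("text", path="./my_ckpts"))  # ./my_ckpts/<md5>.pt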
@@ -141,10 +143,11 @@ def _grab_best_device(use_gpu=True):


 def _download(from_hf_path, file_name, to_local_path):
-    os.makedirs(CACHE_DIR, exist_ok=True)
-    destination_file_name = to_local_path.split("/")[-1]
-    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
-    os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
+    to_local_path = to_local_path.replace("\\", "/")
+    path = '/'.join(to_local_path.split("/")[:-1])
+    os.makedirs(path, exist_ok=True)
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)
+    os.replace(os.path.join(path, file_name), to_local_path)

 class InferenceContext:
     def __init__(self, benchmark=False):
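_download now derives the destination directory from to_local_path itself, after normalizing Windows back-slashes, instead of always staging downloads in CACHE_DIR. The derivation is plain string logic and can be checked in isolation:

    # Runnable check of the new directory derivation; the path is illustrative.
    to_local_path = "C:\\bark_ckpts\\text_2.pt".replace("\\", "/")
    path = '/'.join(to_local_path.split("/")[:-1])
    print(path)  # -> C:/bark_ckpts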
@@ -250,8 +253,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     return model


-def _load_codec_model(device):
-    model = EncodecModel.encodec_model_24khz()
+def _load_codec_model(device, path=None):
+    model = EncodecModel.encodec_model_24khz(repository=path)
     model.set_target_bandwidth(6.0)
     model.eval()
     model.to(device)
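Here path is handed to EnCodec's checkpoint lookup: encodec's factory accepts an optional repository argument naming a local directory of checkpoints, and None preserves the default download behavior. A sketch under those assumptions ("./encodec_ckpts" is hypothetical and must already exist):

    # Sketch: load EnCodec weights from a local directory. The library types
    # repository as an optional pathlib.Path, so prefer Path over a plain string.
    from pathlib import Path
    from encodec import EncodecModel

    model = EncodecModel.encodec_model_24khz(repository=Path("./encodec_ckpts"))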
@@ -259,7 +262,7 @@ def _load_codec_model(device):
     return model


-def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text", path=None):
     _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
@@ -271,7 +274,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
         models_devices[model_key] = device
         device = "cpu"
     if model_key not in models or force_reload:
-        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
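With the keyword threaded through to _get_ckpt_path, load_model can be pointed at any directory that holds (or should receive) the checkpoints. A minimal usage sketch, assuming the fork keeps upstream's public API; the directory is illustrative:

    # Sketch: load the text model from a custom checkpoint directory.
    from bark.generation import load_model

    text_model = load_model(model_type="text", path="./bark_ckpts")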
@@ -282,7 +285,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     return models[model_key]


-def load_codec_model(use_gpu=True, force_reload=False):
+def load_codec_model(use_gpu=True, force_reload=False, path=None):
     global models
     global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
@@ -295,7 +298,7 @@ def load_codec_model(use_gpu=True, force_reload=False):
         device = "cpu"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
-        model = _load_codec_model(device)
+        model = _load_codec_model(device, path=path)
         models[model_key] = model
     models[model_key].to(device)
     return models[model_key]
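Because the codec's path ends up as EnCodec's repository argument, it should name a directory containing the EnCodec checkpoint rather than the checkpoint file itself. A hedged sketch:

    # Sketch: reuse a local EnCodec checkpoint directory ("./encodec_ckpts" is
    # hypothetical); omitting path keeps the default download behavior.
    from pathlib import Path
    from bark.generation import load_codec_model

    codec = load_codec_model(use_gpu=True, path=Path("./encodec_ckpts"))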
@@ -310,6 +313,7 @@ def preload_models(
     fine_use_small=False,
     codec_use_gpu=True,
     force_reload=False,
+    path=None,
 ):
     """Load all the necessary models for the pipeline."""
     if _grab_best_device() == "cpu" and (
@@ -317,18 +321,19 @@ def preload_models(
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
-        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
     )
     _ = load_model(
         model_type="coarse",
         use_gpu=coarse_use_gpu,
         use_small=coarse_use_small,
         force_reload=force_reload,
+        path=path,
     )
     _ = load_model(
-        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
     )
-    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path)


 ####
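Taken together, one call can now stage all four models from a single location. A minimal sketch, assuming the fork keeps upstream's preload_models entry point; the directory is illustrative:

    # Sketch: preload text, coarse, fine, and codec models from one directory.
    # A pathlib.Path works for os.path.join and EnCodec's repository argument alike.
    from pathlib import Path
    from bark.generation import preload_models

    preload_models(path=Path("./bark_ckpts"))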