mirror of https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-15 03:07:58 +01:00

Commit: Add ability to use custom paths
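In practice, the new `path` keyword surfaces on the public loaders (see the generation.py hunks below). A minimal usage sketch, with the import path assumed from this repo's notebooks:

# Minimal usage sketch; import path assumed from this repo's notebooks.
from bark.generation import preload_models

# Download and cache every checkpoint under ./models instead of the default
# CACHE_DIR; path=None preserves the pre-commit behavior.
preload_models(
    text_use_gpu=True,
    coarse_use_gpu=True,
    fine_use_gpu=True,
    codec_use_gpu=True,
    force_reload=False,
    path="models",  # new in this commit
)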
.gitignore (vendored), 3 lines changed
@@ -1,3 +1,4 @@
 __pycache__/
 *.wav
 _temp/
+models/
bark/generation.py

@@ -124,10 +124,12 @@ def _md5(fname):
     return hash_md5.hexdigest()


-def _get_ckpt_path(model_type, use_small=False):
+def _get_ckpt_path(model_type, use_small=False, path=None):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
     model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
-    return os.path.join(CACHE_DIR, f"{model_name}.pt")
+    if path is None:
+        path = CACHE_DIR
+    return os.path.join(path, f"{model_name}.pt")


 def _grab_best_device(use_gpu=True):
@@ -141,10 +143,11 @@ def _grab_best_device(use_gpu=True):

 def _download(from_hf_path, file_name, to_local_path):
-    os.makedirs(CACHE_DIR, exist_ok=True)
-    destination_file_name = to_local_path.split("/")[-1]
-    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
-    os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
+    to_local_path = to_local_path.replace("\\", "/")
+    path = '/'.join(to_local_path.split("/")[:-1])
+    os.makedirs(path, exist_ok=True)
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)
+    os.replace(os.path.join(path, file_name), to_local_path)


 class InferenceContext:
     def __init__(self, benchmark=False):
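The rewritten `_download` no longer hard-codes `CACHE_DIR`; it derives the destination directory from `to_local_path` itself. A self-contained sketch of just that string handling (hypothetical input, no network access):

import os

def target_dir(to_local_path):
    # Mirror the new _download: normalize Windows separators, then drop the
    # file name to recover the directory that must exist before downloading.
    to_local_path = to_local_path.replace("\\", "/")
    return "/".join(to_local_path.split("/")[:-1])

print(target_dir("models\\343256c8e687c94554ef9f091bb93192.pt"))  # prints: models
# _download then calls os.makedirs("models", exist_ok=True) before fetching.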
@@ -250,8 +253,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     return model


-def _load_codec_model(device):
-    model = EncodecModel.encodec_model_24khz()
+def _load_codec_model(device, path=None):
+    model = EncodecModel.encodec_model_24khz(repository=path)
     model.set_target_bandwidth(6.0)
     model.eval()
     model.to(device)
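For the codec, the custom location is forwarded to encodec's `repository` parameter. To the best of my knowledge, the encodec source types `repository` as an Optional[Path] naming a local directory of pretrained checkpoints, so `repository=None` (the default here) keeps the usual download behavior; a hedged sketch:

from pathlib import Path
from encodec import EncodecModel

# repository=None falls back to encodec's default checkpoint download.
model = EncodecModel.encodec_model_24khz(repository=None)
# With a local folder (must already contain the checkpoints); since the
# parameter is typed as a Path, wrapping a plain string is safest:
# model = EncodecModel.encodec_model_24khz(repository=Path("models"))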
@@ -259,7 +262,7 @@ def _load_codec_model(device):
     return model


-def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text", path=None):
     _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
@@ -271,7 +274,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
         models_devices[model_key] = device
         device = "cpu"
     if model_key not in models or force_reload:
-        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
@@ -282,7 +285,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     return models[model_key]


-def load_codec_model(use_gpu=True, force_reload=False):
+def load_codec_model(use_gpu=True, force_reload=False, path=None):
     global models
     global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
@@ -295,7 +298,7 @@ def load_codec_model(use_gpu=True, force_reload=False):
         device = "cpu"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
-        model = _load_codec_model(device)
+        model = _load_codec_model(device, path=path)
         models[model_key] = model
         models[model_key].to(device)
     return models[model_key]
@@ -310,6 +313,7 @@ def preload_models(
     fine_use_small=False,
     codec_use_gpu=True,
     force_reload=False,
+    path=None,
 ):
     """Load all the necessary models for the pipeline."""
     if _grab_best_device() == "cpu" and (
@@ -317,18 +321,19 @@ def preload_models(
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
-        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
     )
     _ = load_model(
         model_type="coarse",
         use_gpu=coarse_use_gpu,
         use_small=coarse_use_small,
         force_reload=force_reload,
+        path=path,
     )
     _ = load_model(
-        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
     )
-    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path)


 ####
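End to end, `path` threads from `preload_models` through `load_model` into `_get_ckpt_path`, which only changes where the checkpoint file is looked up. A minimal standalone sketch of that resolution logic (the default cache dir shown is illustrative, not quoted from the repo):

import os

DEFAULT_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "suno", "bark_v0")  # illustrative

def resolve_ckpt_path(model_name, path=None):
    # Same fallback as the patched _get_ckpt_path: custom dir if given, else cache.
    if path is None:
        path = DEFAULT_CACHE_DIR
    return os.path.join(path, f"{model_name}.pt")

print(resolve_ckpt_path("343256c8e687c94554ef9f091bb93192", path="models"))
# -> models/343256c8e687c94554ef9f091bb93192.pt (models\... on Windows)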
notebook (file name not captured)

@@ -59,7 +59,7 @@
     "# get seconds of audio\n",
     "seconds = wav.shape[-1] / model.sample_rate\n",
     "# generate semantic tokens\n",
-    "semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)"
+    "semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) # not 100% sure on this part"
    ]
   },
   {
@@ -142,7 +142,8 @@
     "    fine_use_gpu=True,\n",
     "    fine_use_small=False,\n",
     "    codec_use_gpu=True,\n",
-    "    force_reload=False\n",
+    "    force_reload=False,\n",
+    "    path=\"models\"\n",
     ")"
    ]
   },
second notebook (file name not captured)

@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -20,9 +20,32 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Downloading suno/bark to models\\343256c8e687c94554ef9f091bb93192.pt\n",
+      "models\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb3e09a8f3704a57b7ba9344d3b1b938",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading text_2.pt:   0%|          | 0.00/5.35G [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "# download and load all models\n",
     "preload_models(\n",
@@ -33,7 +56,8 @@
     "    fine_use_gpu=True,\n",
     "    fine_use_small=False,\n",
     "    codec_use_gpu=True,\n",
-    "    force_reload=False\n",
+    "    force_reload=False,\n",
+    "    path=\"models\"\n",
     ")"
    ]
   },
@@ -180,6 +180,26 @@
     "In conclusion, the human journey is one of discovery, driven by our innate curiosity and desire to understand the world around us. From the dawn of our species to the present day, we have continued to explore, learn, and adapt, pushing the boundaries of what is known and possible. As we continue to unravel the mysteries of the cosmos, our spirit.\"\"\""
    ]
   },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# download and load all models\n",
+    "preload_models(\n",
+    "    text_use_gpu=True,\n",
+    "    text_use_small=False,\n",
+    "    coarse_use_gpu=True,\n",
+    "    coarse_use_small=False,\n",
+    "    fine_use_gpu=True,\n",
+    "    fine_use_small=False,\n",
+    "    codec_use_gpu=True,\n",
+    "    force_reload=False,\n",
+    "    path=\"models\"\n",
+    ")"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
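The captured cell output above ("Downloading suno/bark to models\343256c8e687c94554ef9f091bb93192.pt") confirms that with path="models" the checkpoint lands in the custom directory, named by the MD5 of the remote file name. A quick hypothetical post-run check:

import os

# After preload_models(path="models"), each checkpoint sits in ./models under
# its MD5-derived name, e.g. 343256c8e687c94554ef9f091bb93192.pt for text_2.pt.
print(sorted(os.listdir("models")))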