From 45165b7ad7efc1444a6c10217943e721354dd236 Mon Sep 17 00:00:00 2001
From: Francis LaBounty <73464335+francislabountyjr@users.noreply.github.com>
Date: Sat, 29 Apr 2023 16:53:09 -0600
Subject: [PATCH] Add ability to use custom paths

---
 .gitignore             |  3 ++-
 bark/generation.py     | 35 ++++++++++++++++++++---------------
 clone_voice.ipynb      |  5 +++--
 generate.ipynb         | 32 ++++++++++++++++++++++++++++----
 generate_chunked.ipynb | 20 ++++++++++++++++++++
 5 files changed, 73 insertions(+), 22 deletions(-)

diff --git a/.gitignore b/.gitignore
index c4672f7..1cf5bdf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 __pycache__/
 *.wav
-_temp/
\ No newline at end of file
+_temp/
+models/
\ No newline at end of file
diff --git a/bark/generation.py b/bark/generation.py
index 790307b..6967675 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -124,10 +124,12 @@ def _md5(fname):
     return hash_md5.hexdigest()


-def _get_ckpt_path(model_type, use_small=False):
+def _get_ckpt_path(model_type, use_small=False, path=None):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
     model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
-    return os.path.join(CACHE_DIR, f"{model_name}.pt")
+    if path is None:
+        path = CACHE_DIR
+    return os.path.join(path, f"{model_name}.pt")


 def _grab_best_device(use_gpu=True):
@@ -141,10 +143,11 @@ def _grab_best_device(use_gpu=True):


 def _download(from_hf_path, file_name, to_local_path):
-    os.makedirs(CACHE_DIR, exist_ok=True)
-    destination_file_name = to_local_path.split("/")[-1]
-    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
-    os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
+    to_local_path = to_local_path.replace("\\", "/")
+    path = '/'.join(to_local_path.split("/")[:-1])
+    os.makedirs(path, exist_ok=True)
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)
+    os.replace(os.path.join(path, file_name), to_local_path)

 class InferenceContext:
     def __init__(self, benchmark=False):
@@ -250,8 +253,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     return model


-def _load_codec_model(device):
-    model = EncodecModel.encodec_model_24khz()
+def _load_codec_model(device, path=None):
+    model = EncodecModel.encodec_model_24khz(repository=path)
     model.set_target_bandwidth(6.0)
     model.eval()
     model.to(device)
@@ -259,7 +262,7 @@ def _load_codec_model(device):
     return model


-def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text", path=None):
     _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
@@ -271,7 +274,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
         models_devices[model_key] = device
         device = "cpu"
     if model_key not in models or force_reload:
-        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
@@ -282,7 +285,7 @@
     return models[model_key]


-def load_codec_model(use_gpu=True, force_reload=False):
+def load_codec_model(use_gpu=True, force_reload=False, path=None):
     global models
     global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
@@ -295,7 +298,7 @@ def load_codec_model(use_gpu=True, force_reload=False):
         device = "cpu"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
-        model = _load_codec_model(device)
+        model = _load_codec_model(device, path=path)
         models[model_key] = model
     models[model_key].to(device)
     return models[model_key]
@@ -310,6 +313,7 @@ def preload_models(
     fine_use_small=False,
     codec_use_gpu=True,
     force_reload=False,
+    path=None,
 ):
     """Load all the necessary models for the pipeline."""
     if _grab_best_device() == "cpu" and (
@@ -317,18 +321,19 @@
     ):
         logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
-        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
     )
     _ = load_model(
         model_type="coarse",
         use_gpu=coarse_use_gpu,
         use_small=coarse_use_small,
         force_reload=force_reload,
+        path=path,
     )
     _ = load_model(
-        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
     )
-    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path)


 ####
diff --git a/clone_voice.ipynb b/clone_voice.ipynb
index 7c687bc..7ab0dbe 100644
--- a/clone_voice.ipynb
+++ b/clone_voice.ipynb
@@ -59,7 +59,7 @@
     "# get seconds of audio\n",
     "seconds = wav.shape[-1] / model.sample_rate\n",
     "# generate semantic tokens\n",
-    "semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7)"
+    "semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) # not 100% sure on this part"
    ]
   },
   {
@@ -142,7 +142,8 @@
     "    fine_use_gpu=True,\n",
     "    fine_use_small=False,\n",
     "    codec_use_gpu=True,\n",
-    "    force_reload=False\n",
+    "    force_reload=False,\n",
+    "    path=\"models\"\n",
     ")"
    ]
   },
diff --git a/generate.ipynb b/generate.ipynb
index 7029697..f5afe6a 100644
--- a/generate.ipynb
+++ b/generate.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -20,9 +20,32 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 2,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Downloading suno/bark to models\\343256c8e687c94554ef9f091bb93192.pt\n",
+      "models\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "eb3e09a8f3704a57b7ba9344d3b1b938",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Downloading text_2.pt: 0%| | 0.00/5.35G [00:00
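
A minimal usage sketch of the new `path` argument, for anyone applying this patch (my illustration, not part of the commit). It assumes the patched bark.generation module is importable; "models" is just an example directory, matching the value used in the notebooks and the new models/ entry in .gitignore.

# Sketch only: exercises the patched preload_models() signature above.
from bark.generation import preload_models

# Download (on first run) and cache all checkpoints under ./models
# instead of the default CACHE_DIR.
preload_models(
    text_use_gpu=True,
    coarse_use_gpu=True,
    fine_use_gpu=True,
    codec_use_gpu=True,
    force_reload=False,
    path="models",
)

Because _get_ckpt_path falls back to CACHE_DIR when path is None, and _load_codec_model passes repository=None through to EnCodec's default download location, existing callers that omit the argument keep the old behavior.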