Mirror of https://github.com/serp-ai/bark-with-voice-clone.git (synced 2025-12-15 03:07:58 +01:00)

fix encodec path
@@ -3,7 +3,6 @@ import gc
 import hashlib
 import os
 import re
-import requests
 
 from encodec import EncodecModel
 import funcy
@@ -126,10 +125,10 @@ def _md5(fname):
 
 
 def _get_ckpt_path(model_type, use_small=False, path=None):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
+    model_name = REMOTE_MODEL_PATHS[model_key]["file_name"]
     if path is None:
         path = CACHE_DIR
-    return os.path.join(path, f"{model_name}.pt")
+    return os.path.join(path, f"{model_name}")
 
 
 def _grab_best_device(use_gpu=True):
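The practical effect of the _get_ckpt_path change is that cached checkpoints are now addressed by their original file name rather than an md5-hashed name with an extra ".pt" suffix. A minimal sketch of the before/after behaviour, assuming _string_md5 is an md5 hex digest of the string and using hypothetical values for CACHE_DIR and the file name:

import hashlib
import os

CACHE_DIR = "models"      # hypothetical cache dir for this sketch
file_name = "text_2.pt"   # hypothetical REMOTE_MODEL_PATHS[...]["file_name"]

# Before: hashed name with a second ".pt" appended.
old_path = os.path.join(CACHE_DIR, hashlib.md5(file_name.encode()).hexdigest() + ".pt")
# After: the remote file name is used verbatim.
new_path = os.path.join(CACHE_DIR, file_name)

print(old_path)  # e.g. models/<32-char hex>.pt
print(new_path)  # e.g. models/text_2.pt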
@@ -253,8 +252,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     return model
 
 
-def _load_codec_model(device, path=None):
-    model = EncodecModel.encodec_model_24khz(repository=path)
+def _load_codec_model(device):
+    model = EncodecModel.encodec_model_24khz()
     model.set_target_bandwidth(6.0)
     model.eval()
     model.to(device)
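For reference, the updated _load_codec_model relies on the encodec package to resolve its own 24 kHz pretrained weights instead of reading them from a caller-supplied repository path. A minimal standalone sketch of that load sequence (the device selection line is an assumption for the example, not part of the diff):

import torch
from encodec import EncodecModel

device = "cuda" if torch.cuda.is_available() else "cpu"

# Default pretrained 24 kHz EnCodec; no local repository path is passed.
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)  # same bandwidth used in the function above
model.eval()
model.to(device)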
@@ -285,7 +284,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     return models[model_key]
 
 
-def load_codec_model(use_gpu=True, force_reload=False, path=None):
+def load_codec_model(use_gpu=True, force_reload=False):
     global models
     global models_devices
     device = _grab_best_device(use_gpu=use_gpu)
@@ -298,7 +297,7 @@ def load_codec_model(use_gpu=True, force_reload=False, path=None):
         device = "cpu"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
-        model = _load_codec_model(device, path=path)
+        model = _load_codec_model(device)
         models[model_key] = model
     models[model_key].to(device)
     return models[model_key]
@@ -333,7 +332,7 @@ def preload_models(
     _ = load_model(
         model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
     )
-    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
 ####
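Taken together, callers keep passing path for the text/coarse/fine checkpoints, while the codec loader no longer accepts it. A hypothetical caller-side sketch, assuming these functions are importable from bark.generation as in upstream Bark and that preload_models keeps the path keyword shown in the hunk above:

from bark.generation import preload_models, load_codec_model

# GPT-style checkpoints are cached under the given directory...
preload_models(path="models")

# ...while EnCodec weights are resolved by the encodec package itself.
codec_model = load_codec_model(use_gpu=True)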
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -20,32 +20,9 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": null,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Downloading suno/bark to models\\343256c8e687c94554ef9f091bb93192.pt\n",
-      "models\n"
-     ]
-    },
-    {
-     "data": {
-      "application/vnd.jupyter.widget-view+json": {
-       "model_id": "eb3e09a8f3704a57b7ba9344d3b1b938",
-       "version_major": 2,
-       "version_minor": 0
-      },
-      "text/plain": [
-       "Downloading text_2.pt: 0%| | 0.00/5.35G [00:00<?, ?B/s]"
-      ]
-     },
-     "metadata": {},
-     "output_type": "display_data"
-    }
-   ],
+   "outputs": [],
    "source": [
     "# download and load all models\n",
     "preload_models(\n",