fix encodec path

This commit is contained in:
Francis LaBounty
2023-04-30 05:30:56 -06:00
parent 45165b7ad7
commit 976ca8fb55
2 changed files with 10 additions and 34 deletions

View File

@@ -3,7 +3,6 @@ import gc
import hashlib import hashlib
import os import os
import re import re
import requests
from encodec import EncodecModel from encodec import EncodecModel
import funcy import funcy
@@ -126,10 +125,10 @@ def _md5(fname):
def _get_ckpt_path(model_type, use_small=False, path=None): def _get_ckpt_path(model_type, use_small=False, path=None):
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"]) model_name = REMOTE_MODEL_PATHS[model_key]["file_name"]
if path is None: if path is None:
path = CACHE_DIR path = CACHE_DIR
return os.path.join(path, f"{model_name}.pt") return os.path.join(path, f"{model_name}")
def _grab_best_device(use_gpu=True): def _grab_best_device(use_gpu=True):
@@ -253,8 +252,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
return model return model
def _load_codec_model(device, path=None): def _load_codec_model(device):
model = EncodecModel.encodec_model_24khz(repository=path) model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0) model.set_target_bandwidth(6.0)
model.eval() model.eval()
model.to(device) model.to(device)
@@ -285,7 +284,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
return models[model_key] return models[model_key]
def load_codec_model(use_gpu=True, force_reload=False, path=None): def load_codec_model(use_gpu=True, force_reload=False):
global models global models
global models_devices global models_devices
device = _grab_best_device(use_gpu=use_gpu) device = _grab_best_device(use_gpu=use_gpu)
@@ -298,7 +297,7 @@ def load_codec_model(use_gpu=True, force_reload=False, path=None):
device = "cpu" device = "cpu"
if model_key not in models or force_reload: if model_key not in models or force_reload:
clean_models(model_key=model_key) clean_models(model_key=model_key)
model = _load_codec_model(device, path=path) model = _load_codec_model(device)
models[model_key] = model models[model_key] = model
models[model_key].to(device) models[model_key].to(device)
return models[model_key] return models[model_key]
@@ -333,7 +332,7 @@ def preload_models(
_ = load_model( _ = load_model(
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
) )
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path) _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
#### ####

View File

@@ -2,7 +2,7 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
@@ -20,32 +20,9 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading suno/bark to models\\343256c8e687c94554ef9f091bb93192.pt\n",
"models\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "eb3e09a8f3704a57b7ba9344d3b1b938",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Downloading text_2.pt: 0%| | 0.00/5.35G [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [ "source": [
"# download and load all models\n", "# download and load all models\n",
"preload_models(\n", "preload_models(\n",