fix encodec path

This commit is contained in:
Francis LaBounty
2023-04-30 05:30:56 -06:00
parent 45165b7ad7
commit 976ca8fb55
2 changed files with 10 additions and 34 deletions

View File

@@ -3,7 +3,6 @@ import gc
import hashlib
import os
import re
import requests
from encodec import EncodecModel
import funcy
@@ -126,10 +125,10 @@ def _md5(fname):
def _get_ckpt_path(model_type, use_small=False, path=None):
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
model_name = REMOTE_MODEL_PATHS[model_key]["file_name"]
if path is None:
path = CACHE_DIR
return os.path.join(path, f"{model_name}.pt")
return os.path.join(path, f"{model_name}")
def _grab_best_device(use_gpu=True):
@@ -253,8 +252,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
return model
def _load_codec_model(device, path=None):
model = EncodecModel.encodec_model_24khz(repository=path)
def _load_codec_model(device):
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
model.eval()
model.to(device)
@@ -285,7 +284,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
return models[model_key]
def load_codec_model(use_gpu=True, force_reload=False, path=None):
def load_codec_model(use_gpu=True, force_reload=False):
global models
global models_devices
device = _grab_best_device(use_gpu=use_gpu)
@@ -298,7 +297,7 @@ def load_codec_model(use_gpu=True, force_reload=False, path=None):
device = "cpu"
if model_key not in models or force_reload:
clean_models(model_key=model_key)
model = _load_codec_model(device, path=path)
model = _load_codec_model(device)
models[model_key] = model
models[model_key].to(device)
return models[model_key]
@@ -333,7 +332,7 @@ def preload_models(
_ = load_model(
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
####