diff --git a/bark/generation.py b/bark/generation.py index 6967675..d1c7a6b 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -3,7 +3,6 @@ import gc import hashlib import os import re -import requests from encodec import EncodecModel import funcy @@ -126,10 +125,10 @@ def _md5(fname): def _get_ckpt_path(model_type, use_small=False, path=None): model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type - model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"]) + model_name = REMOTE_MODEL_PATHS[model_key]["file_name"] if path is None: path = CACHE_DIR - return os.path.join(path, f"{model_name}.pt") + return os.path.join(path, f"{model_name}") def _grab_best_device(use_gpu=True): @@ -253,8 +252,8 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"): return model -def _load_codec_model(device, path=None): - model = EncodecModel.encodec_model_24khz(repository=path) +def _load_codec_model(device): + model = EncodecModel.encodec_model_24khz() model.set_target_bandwidth(6.0) model.eval() model.to(device) @@ -285,7 +284,7 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te return models[model_key] -def load_codec_model(use_gpu=True, force_reload=False, path=None): +def load_codec_model(use_gpu=True, force_reload=False): global models global models_devices device = _grab_best_device(use_gpu=use_gpu) @@ -298,7 +297,7 @@ def load_codec_model(use_gpu=True, force_reload=False, path=None): device = "cpu" if model_key not in models or force_reload: clean_models(model_key=model_key) - model = _load_codec_model(device, path=path) + model = _load_codec_model(device) models[model_key] = model models[model_key].to(device) return models[model_key] @@ -333,7 +332,7 @@ def preload_models( _ = load_model( model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path ) - _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload, path=path) + _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload) #### diff --git a/generate.ipynb b/generate.ipynb index f5afe6a..a135af3 100644 --- a/generate.ipynb +++ b/generate.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -20,32 +20,9 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading suno/bark to models\\343256c8e687c94554ef9f091bb93192.pt\n", - "models\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "eb3e09a8f3704a57b7ba9344d3b1b938", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Downloading text_2.pt: 0%| | 0.00/5.35G [00:00