Merge pull request #62 from suno-ai/test_model_control

Simplify small model and gpu/cpu choice
Georg Kucsko authored on 2023-04-22 17:13:27 -04:00 · committed by GitHub
2 changed files with 64 additions and 45 deletions

README.md

@@ -21,14 +21,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.ai)
 
 ## 🤖 Usage
 
 ```python
-from bark import SAMPLE_RATE, generate_audio
+from bark import SAMPLE_RATE, generate_audio, preload_models
 from IPython.display import Audio
 
+# download and load all models
+preload_models()
+
+# generate audio from text
 text_prompt = """
     Hello, my name is Suno. And, uh — and I like pizza. [laughs]
     But I also have other interests such as playing tic tac toe.
 """
 audio_array = generate_audio(text_prompt)
+
+# play text in notebook
 Audio(audio_array, rate=SAMPLE_RATE)
 ```
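Reviewer note: the new per-model flags introduced in this PR can be passed straight through the README entry point. A sketch of opting into the small models, with flag names taken from the new `preload_models` signature further down in this diff (the codec model has no small variant, so it is only controlled by `codec_use_gpu`):

```python
from bark import SAMPLE_RATE, generate_audio, preload_models
from IPython.display import Audio

# load the small variants of the text, coarse, and fine models
preload_models(text_use_small=True, coarse_use_small=True, fine_use_small=True)

audio_array = generate_audio("Hello from the small models.")
Audio(audio_array, rate=SAMPLE_RATE)
```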

bark/generation.py

@@ -1,4 +1,5 @@
 import contextlib
+import gc
 import hashlib
 import os
 import re
@@ -84,23 +85,20 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
 
-if USE_SMALL_MODELS:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
-            "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
-            "checksum": "5fe964825e3b0321f9d5f3857b89194d",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
-            "checksum": "5428d1befe05be2ba32195496e58dc90",
-        },
-    }
-else:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
-            "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+REMOTE_MODEL_PATHS = {
+    "text_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+    },
+    "coarse_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+    },
+    "fine_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "checksum": "5428d1befe05be2ba32195496e58dc90",
+    },
+    "text": {
+        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
@@ -113,7 +111,7 @@ else:
-            "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
-            "checksum": "59d184ed44e3650774a2f0503a48a97b",
-        },
-    }
+        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "checksum": "59d184ed44e3650774a2f0503a48a97b",
+    },
+}
 
 if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
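Reviewer note: the old code picked one of two dicts at import time based on `SUNO_USE_SMALL_MODELS`; the new code keeps a single flat registry keyed by `text`/`text_small`, `coarse`/`coarse_small`, `fine`/`fine_small` and resolves the key per call. One gotcha worth knowing: `os.environ.get(...)` returns a string whenever the variable is set, so any non-empty value, including `"0"`, makes `USE_SMALL_MODELS` truthy. A minimal standalone sketch of the key resolution (function name is illustrative):

```python
import os

# mirrors USE_SMALL_MODELS in generation.py: any non-empty string is truthy
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)

def resolve_model_key(model_type: str, use_small: bool = False) -> str:
    # the per-call flag wins; the env var acts as a global fallback
    return f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type

assert resolve_model_key("coarse", use_small=True) == "coarse_small"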
@@ -137,8 +135,9 @@ def _md5(fname):
     return hash_md5.hexdigest()
 
 
-def _get_ckpt_path(model_type):
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
+def _get_ckpt_path(model_type, use_small=False):
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
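Because the cache filename is derived from a hash of the remote path, the small and full checkpoints land in distinct cache files and can coexist on disk. A sketch, assuming `_string_md5` (not shown in this diff) is a plain MD5 over the string, as its name suggests:

```python
import hashlib

def string_md5_sketch(text: str) -> str:
    # assumed behavior of the repo's _string_md5 helper
    return hashlib.md5(text.encode("utf-8")).hexdigest()

base = "https://dl.suno-models.io/bark/models/v0/"
# "text.pt" (small) and "text_2.pt" (full) hash to different cache filenames
assert string_md5_sketch(base + "text.pt") != string_md5_sketch(base + "text_2.pt")
```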
@@ -204,9 +203,10 @@ def clean_models(model_key=None):
         if k in models:
             del models[k]
     _clear_cuda_cache()
+    gc.collect()
 
 
-def _load_model(ckpt_path, device, model_type="text"):
+def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     if "cuda" not in device:
         logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
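The added `gc.collect()` matters because emptying the CUDA cache (presumably what `_clear_cuda_cache` wraps) can only return memory for tensors whose Python objects have actually been collected; `del` alone may leave them alive until the next GC pass. An illustrative standalone version of the pattern, with GC run before the cache is emptied:

```python
import gc

import torch

def drop_model(models: dict, key: str) -> None:
    # remove the reference, reclaim the Python object, then release
    # the freed blocks held by the CUDA caching allocator
    models.pop(key, None)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
```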
@@ -220,15 +220,17 @@ def _load_model(ckpt_path, device, model_type="text"):
         ModelClass = FineGPT
     else:
         raise NotImplementedError()
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_info = REMOTE_MODEL_PATHS[model_key]
     if (
         os.path.exists(ckpt_path) and
-        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
+        _md5(ckpt_path) != model_info["checksum"]
     ):
         logger.warning(f"found outdated {model_type} model, removing.")
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
+        _download(model_info["path"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
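`_md5` itself sits outside this hunk; for reference, a standard chunked implementation of the same idea (illustrative, not necessarily the repo's exact code) avoids reading a multi-gigabyte checkpoint into memory at once:

```python
import hashlib

def md5_of_file(path: str, chunk_size: int = 1 << 20) -> str:
    # stream the file in 1 MiB chunks so large checkpoints stay off the heap
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()
```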
@@ -278,8 +280,8 @@ def _load_codec_model(device):
     return model
 
 
-def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
-    _load_model_f = funcy.partial(_load_model, model_type=model_type)
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
@@ -289,8 +291,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
         device = "cuda"
     model_key = str(device) + f"__{model_type}"
     if model_key not in models or force_reload:
-        if ckpt_path is None:
-            ckpt_path = _get_ckpt_path(model_type)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
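Note that `ckpt_path` is gone from the public `load_model` signature, so callers who previously passed custom checkpoint paths will break; the path is now always derived via `_get_ckpt_path`. Calls follow the same pattern `preload_models` uses below, e.g. (assuming this file is bark/generation.py, as its contents suggest):

```python
from bark.generation import load_model

# download (if needed) and load the small coarse model on GPU
_ = load_model(model_type="coarse", use_gpu=True, use_small=True, force_reload=False)
```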
@@ -311,17 +312,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
     return models[model_key]
 
 
-def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
+def preload_models(
+    text_use_gpu=True,
+    text_use_small=False,
+    coarse_use_gpu=True,
+    coarse_use_small=False,
+    fine_use_gpu=True,
+    fine_use_small=False,
+    codec_use_gpu=True,
+    force_reload=False,
+):
     _ = load_model(
-        ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
     _ = load_model(
-        ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
+        model_type="coarse",
+        use_gpu=coarse_use_gpu,
+        use_small=coarse_use_small,
+        force_reload=force_reload,
     )
     _ = load_model(
-        ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
     )
-    _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
 ####
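With the per-model flags, mixed configurations become a single call. For example (flag values illustrative, names taken from the new signature above):

```python
from bark import preload_models

# small text model, full-size coarse/fine models, codec on CPU:
# each stage is now controlled independently instead of by one use_gpu flag
preload_models(
    text_use_small=True,
    coarse_use_small=False,
    fine_use_small=False,
    codec_use_gpu=False,
)
```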