Mirror of https://github.com/serp-ai/bark-with-voice-clone.git
Merge pull request #62 from suno-ai/test_model_control
Simplify small model and gpu/cpu choice
README.md

@@ -21,14 +21,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a
 ## 🤖 Usage
 
 ```python
-from bark import SAMPLE_RATE, generate_audio
+from bark import SAMPLE_RATE, generate_audio, preload_models
 from IPython.display import Audio
 
+# download and load all models
+preload_models()
+
+# generate audio from text
 text_prompt = """
 Hello, my name is Suno. And, uh — and I like pizza. [laughs]
 But I also have other interests such as playing tic tac toe.
 """
 audio_array = generate_audio(text_prompt)
+
+# play text in notebook
 Audio(audio_array, rate=SAMPLE_RATE)
 ```
 
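The added lines change the quick-start so that `preload_models()` downloads and caches every model up front instead of lazily on first use. Outside a notebook there is no `Audio` widget to play the result; a minimal sketch of an alternative, assuming `scipy` is installed (the diff itself does not use it), is to write the array to a WAV file:

```python
from scipy.io.wavfile import write as write_wav
from bark import SAMPLE_RATE, generate_audio, preload_models

# download and load all models up front (the new README flow)
preload_models()

# generate a waveform at Bark's native sample rate
audio_array = generate_audio("Hello, my name is Suno.")

# saving with scipy is an assumption for this sketch, not part of the diff
write_wav("bark_out.wav", SAMPLE_RATE, audio_array)
```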
bark/generation.py

@@ -1,4 +1,5 @@
 import contextlib
+import gc
 import hashlib
 import os
 import re
@@ -84,36 +85,33 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
 
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
-if USE_SMALL_MODELS:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
-            "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
-            "checksum": "5fe964825e3b0321f9d5f3857b89194d",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
-            "checksum": "5428d1befe05be2ba32195496e58dc90",
-        },
-    }
-else:
-    REMOTE_MODEL_PATHS = {
-        "text": {
-            "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
-            "checksum": "54afa89d65e318d4f5f80e8e8799026a",
-        },
-        "coarse": {
-            "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
-            "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
-        },
-        "fine": {
-            "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
-            "checksum": "59d184ed44e3650774a2f0503a48a97b",
-        },
-    }
+REMOTE_MODEL_PATHS = {
+    "text_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+    },
+    "coarse_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+    },
+    "fine_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "checksum": "5428d1befe05be2ba32195496e58dc90",
+    },
+    "text": {
+        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
+    },
+    "coarse": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
+    },
+    "fine": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "checksum": "59d184ed44e3650774a2f0503a48a97b",
+    },
+}
 
 
 if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
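Before this change the module committed to either the small or the full checkpoints at import time, based solely on the `SUNO_USE_SMALL_MODELS` environment variable. Flattening both sets into one dict with `_small`-suffixed keys moves the choice to load time. A minimal sketch of the key selection (the helper name `select_model_key` is illustrative; the real code inlines this expression):

```python
import os

# mirrors the module-level flag shown above
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)

def select_model_key(model_type: str, use_small: bool = False) -> str:
    # the per-call flag wins; the env var acts as a global default
    return f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type

assert select_model_key("coarse") == "coarse"
assert select_model_key("coarse", use_small=True) == "coarse_small"
```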
@@ -137,8 +135,9 @@ def _md5(fname):
     return hash_md5.hexdigest()
 
 
-def _get_ckpt_path(model_type):
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
+def _get_ckpt_path(model_type, use_small=False):
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
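The cache filename is derived from the md5 of the remote URL, so the small and full checkpoints of the same model type map to different files and cannot clobber each other on disk. A sketch, assuming `_string_md5` is a plain md5 hexdigest of the string (its body is not shown in this diff):

```python
import hashlib

def _string_md5(text: str) -> str:
    # assumed behaviour of bark's helper, which this diff does not show
    return hashlib.md5(text.encode("utf-8")).hexdigest()

# distinct URLs yield distinct cache filenames under CACHE_DIR
small = _string_md5("https://dl.suno-models.io/bark/models/v0/text.pt")
full = _string_md5("https://dl.suno-models.io/bark/models/v0/text_2.pt")
assert small != full
```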
@@ -204,9 +203,10 @@ def clean_models(model_key=None):
         if k in models:
             del models[k]
     _clear_cuda_cache()
+    gc.collect()
 
 
-def _load_model(ckpt_path, device, model_type="text"):
+def _load_model(ckpt_path, device, use_small=False, model_type="text"):
     if "cuda" not in device:
         logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
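Deleting the dict entry only drops the Python reference; the newly added `gc.collect()` forces the collector to actually release those objects so the CUDA cache clear can hand device memory back. A sketch of the same pattern in isolation (assuming `_clear_cuda_cache` wraps `torch.cuda.empty_cache()`, which this diff does not show):

```python
import gc
import torch

def free_model(models: dict, model_key: str) -> None:
    if model_key in models:
        del models[model_key]      # drop the last Python reference
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # release cached device blocks
    gc.collect()                   # reclaim host-side objects, as added by this PR
```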
@@ -220,15 +220,17 @@ def _load_model(ckpt_path, device, model_type="text"):
         ModelClass = FineGPT
     else:
         raise NotImplementedError()
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_info = REMOTE_MODEL_PATHS[model_key]
     if (
         os.path.exists(ckpt_path) and
-        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
+        _md5(ckpt_path) != model_info["checksum"]
     ):
         logger.warning(f"found outdated {model_type} model, removing.")
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
+        _download(model_info["path"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
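Resolving `model_key` and `model_info` once means the checksum test and the download URL can no longer disagree about which variant is in play. The flow itself is a standard verify-or-refetch pattern; a self-contained sketch (`ensure_checkpoint` and the `download` callable are illustrative names, and the chunked body of `_md5` is an assumption, since its definition is not in this hunk):

```python
import hashlib
import os

def _md5(fname: str) -> str:
    # assumed shape of the helper named in the hunk header above
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(8192), b""):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()

def ensure_checkpoint(ckpt_path: str, model_info: dict, download) -> None:
    # wrong checksum means an outdated checkpoint: remove it
    if os.path.exists(ckpt_path) and _md5(ckpt_path) != model_info["checksum"]:
        os.remove(ckpt_path)
    # missing (or just removed): fetch from the recorded URL
    if not os.path.exists(ckpt_path):
        download(model_info["path"], ckpt_path)
```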
@@ -278,8 +280,8 @@ def _load_codec_model(device):
     return model
 
 
-def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
-    _load_model_f = funcy.partial(_load_model, model_type=model_type)
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
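From the caller's side, `load_model` no longer takes a checkpoint path at all: you name a model type plus device and size flags, and path resolution happens internally. A usage sketch (assuming the module is importable as `bark.generation`):

```python
from bark.generation import load_model

# download (if needed) and load the small coarse model on CPU
coarse = load_model(model_type="coarse", use_gpu=False, use_small=True)
```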
@@ -289,8 +291,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
         device = "cuda"
     model_key = str(device) + f"__{model_type}"
     if model_key not in models or force_reload:
-        if ckpt_path is None:
-            ckpt_path = _get_ckpt_path(model_type)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
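Loaded models are cached in-process under a key that combines device and model type, so CPU and GPU copies of the same model coexist instead of overwriting each other. A minimal illustration of the keying (the `cache_key` helper is illustrative; the real code inlines the expression):

```python
models = {}  # in-process cache, as in the module above

def cache_key(device: str, model_type: str) -> str:
    return str(device) + f"__{model_type}"

# separate entries per device, so switching use_gpu never hands back
# a model that lives on the wrong device
assert cache_key("cuda", "text") != cache_key("cpu", "text")
```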
@@ -311,17 +312,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
     return models[model_key]
 
 
-def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
+def preload_models(
+    text_use_gpu=True,
+    text_use_small=False,
+    coarse_use_gpu=True,
+    coarse_use_small=False,
+    fine_use_gpu=True,
+    fine_use_small=False,
+    codec_use_gpu=True,
+    force_reload=False,
+):
     _ = load_model(
-        ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
     _ = load_model(
-        ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
+        model_type="coarse",
+        use_gpu=coarse_use_gpu,
+        use_small=coarse_use_small,
+        force_reload=force_reload,
     )
     _ = load_model(
-        ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
     )
-    _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
 ####
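The rewritten `preload_models` exposes one gpu flag and one small-model flag per model, plus a flag for the codec model, instead of a single `use_gpu` for everything; it also stops forcing a reload on every call. A usage sketch mixing variants, using the import path shown in the README hunk above:

```python
from bark import preload_models

# small text and coarse models to save memory, full-size fine model,
# codec on GPU; nothing is reloaded if it is already cached
preload_models(
    text_use_gpu=True,
    text_use_small=True,
    coarse_use_gpu=True,
    coarse_use_small=True,
    fine_use_gpu=True,
    fine_use_small=False,
    codec_use_gpu=True,
    force_reload=False,
)
```

Setting the `SUNO_USE_SMALL_MODELS` environment variable before import still flips every model to its small variant by default.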