Merge branch 'suno-ai:main' into main

Francis LaBounty
2023-04-29 16:27:34 -06:00
committed by GitHub
2 changed files with 130 additions and 70 deletions

bark/generation.py

@@ -14,6 +14,7 @@ import torch
 import torch.nn.functional as F
 import tqdm
 from transformers import BertTokenizer
+from huggingface_hub import hf_hub_download
 
 from .model import GPTConfig, GPT
 from .model_fine import FineGPT, FineGPTConfig
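
Model downloads now go through the Hugging Face Hub client instead of a hand-rolled HTTP fetch. A minimal sketch of how hf_hub_download behaves (the repo and filename mirror the REMOTE_MODEL_PATHS entries added below; the local_dir value is a hypothetical target):

from huggingface_hub import hf_hub_download

# Downloads the file (reusing the shared HF cache where possible) and
# returns the local path of the resulting file.
local_path = hf_hub_download(
    repo_id="suno/bark",    # Hub repo, as listed in REMOTE_MODEL_PATHS
    filename="text.pt",     # file within that repo
    local_dir="/tmp/bark",  # hypothetical destination directory
)
print(local_path)  # e.g. /tmp/bark/text.pt
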
@@ -36,6 +37,9 @@ else:
 
 global models
 models = {}
 
+global models_devices
+models_devices = {}
+
 CONTEXT_WINDOW_SIZE = 1024
@@ -61,41 +65,48 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
+GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
+OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
 
-REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
-
 REMOTE_MODEL_PATHS = {
     "text_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "text.pt",
         "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
     },
     "coarse_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "coarse.pt",
         "checksum": "5fe964825e3b0321f9d5f3857b89194d",
     },
     "fine_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "fine.pt",
        "checksum": "5428d1befe05be2ba32195496e58dc90",
     },
     "text": {
-        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "text_2.pt",
         "checksum": "54afa89d65e318d4f5f80e8e8799026a",
     },
     "coarse": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "coarse_2.pt",
         "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
     },
     "fine": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "fine_2.pt",
         "checksum": "59d184ed44e3650774a2f0503a48a97b",
     },
 }
 
 
-if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
     logger.warning(
-        "torch version does not support flash attention. You will get significantly faster" +
-        " inference speed by upgrade torch to newest version / nightly."
+        "torch version does not support flash attention. You will get faster" +
+        " inference speed by upgrade torch to newest nightly version."
     )
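
Both new flags are read once at import time with os.environ.get(..., False), so any non-empty string enables them and they must be set before the module is imported. A hypothetical way to turn them on:

import os

# Any non-empty value counts as "on"; set these before importing bark.
os.environ["SUNO_ENABLE_MPS"] = "1"   # allow the Apple-silicon "mps" device
os.environ["SUNO_OFFLOAD_CPU"] = "1"  # keep idle models parked on the CPU

from bark import generation  # the flags are read here, at import
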
@@ -115,33 +126,25 @@ def _md5(fname):
 def _get_ckpt_path(model_type, use_small=False):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
-S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
-
-
-def _parse_s3_filepath(s3_filepath):
-    bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1)
-    rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath)
-    return bucket_name, rel_s3_filepath
+def _grab_best_device(use_gpu=True):
+    if torch.cuda.device_count() > 0 and use_gpu:
+        device = "cuda"
+    elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
+        device = "mps"
+    else:
+        device = "cpu"
+    return device
 
 
-def _download(from_s3_path, to_local_path):
+def _download(from_hf_path, file_name, to_local_path):
     os.makedirs(CACHE_DIR, exist_ok=True)
-    response = requests.get(from_s3_path, stream=True)
-    total_size_in_bytes = int(response.headers.get("content-length", 0))
-    block_size = 1024
-    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-    with open(to_local_path, "wb") as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-        raise ValueError("ERROR, something went wrong")
+    destination_file_name = to_local_path.split("/")[-1]
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
+    os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
 
 
 class InferenceContext:
     def __init__(self, benchmark=False):
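
The new _grab_best_device helper centralizes device selection: CUDA first, then MPS (only when SUNO_ENABLE_MPS is set), then CPU. A standalone sketch of the same priority order:

import torch

def pick_device(use_gpu: bool = True, enable_mps: bool = False) -> str:
    # Mirrors _grab_best_device: CUDA wins, MPS is opt-in, CPU is the fallback.
    if use_gpu and torch.cuda.device_count() > 0:
        return "cuda"
    if use_gpu and enable_mps and torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(pick_device())  # "cuda" on a GPU machine, otherwise "cpu"
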
@@ -185,8 +188,6 @@ def clean_models(model_key=None):
 def _load_model(ckpt_path, device, use_small=False, model_type="text"):
-    if "cuda" not in device:
-        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
         ConfigClass = GPTConfig
         ModelClass = GPT
@@ -208,7 +209,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(model_info["path"], ckpt_path)
+        _download(model_info["repo_id"], model_info["file_name"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
@@ -263,30 +264,40 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
-        device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__{model_type}"
+    global models_devices
+    device = _grab_best_device(use_gpu=use_gpu)
+    model_key = f"{model_type}"
+    if OFFLOAD_CPU:
+        models_devices[model_key] = device
+        device = "cpu"
     if model_key not in models or force_reload:
         ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
+    if model_type == "text":
+        models[model_key]["model"].to(device)
+    else:
+        models[model_key].to(device)
     return models[model_key]
 
 
 def load_codec_model(use_gpu=True, force_reload=False):
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
-        device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__codec"
+    global models_devices
+    device = _grab_best_device(use_gpu=use_gpu)
+    if device == "mps":
+        # encodec doesn't support mps
+        device = "cpu"
+    model_key = "codec"
+    if OFFLOAD_CPU:
+        models_devices[model_key] = device
+        device = "cpu"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
         model = _load_codec_model(device)
         models[model_key] = model
+    models[model_key].to(device)
     return models[model_key]
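
With SUNO_OFFLOAD_CPU set, load_model records the intended compute device in models_devices but materializes the weights on the CPU; each generation function below then moves the model onto its recorded device just for the forward passes and parks it back afterwards. A simplified sketch of that round trip (the names here are illustrative, not the module's API):

import torch

models, models_devices = {}, {}
OFFLOAD_CPU = True

def load(key, build_fn, target_device):
    # Remember where the model should run, but keep the weights on CPU.
    if OFFLOAD_CPU:
        models_devices[key] = target_device
        target_device = "cpu"
    models[key] = build_fn().to(target_device)

def run(key, x):
    model = models[key]
    if OFFLOAD_CPU:
        model.to(models_devices[key])  # move in only for inference
    try:
        return model(x.to(next(model.parameters()).device))
    finally:
        if OFFLOAD_CPU:
            model.to("cpu")  # free the accelerator for the next stage
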
@@ -300,6 +311,11 @@ def preload_models(
     codec_use_gpu=True,
     force_reload=False,
 ):
+    """Load all the necessary models for the pipeline."""
+    if _grab_best_device() == "cpu" and (
+        text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
+    ):
+        logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
         model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
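
Because the generate_* functions below now call preload_models() lazily when a model is missing, explicit preloading is optional, but it still front-loads the download and load cost. Typical usage, matching the package's public API:

from bark import preload_models, generate_audio

preload_models()  # fetches/loads text, coarse, fine and codec models once
audio = generate_audio("Hello world")  # later calls reuse the module cache
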
@@ -344,13 +360,11 @@ def generate_text_semantic(
     temp=0.7,
     top_k=None,
     top_p=None,
-    use_gpu=True,
     silent=False,
     min_eos_p=0.2,
     max_gen_duration_s=None,
     allow_early_stop=True,
-    model=None,
-    use_kv_caching=False
+    use_kv_caching=False,
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -372,12 +386,18 @@
         )
     else:
         semantic_history = None
-    model_container = load_model(use_gpu=use_gpu, model_type="text")
-    if model is None:
-        model = model_container["model"]
+    # load models if not yet exist
+    global models
+    global models_devices
+    if "text" not in models:
+        preload_models()
+    model_container = models["text"]
+    model = model_container["model"]
     tokenizer = model_container["tokenizer"]
     encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    if OFFLOAD_CPU:
+        model.to(models_devices["text"])
+    device = next(model.parameters()).device
     if len(encoded_text) > 256:
         p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
         logger.warning(f"warning, text too long, lopping of last {p}%")
@@ -401,7 +421,9 @@
     else:
         semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
     x = torch.from_numpy(
-        np.hstack([encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])]).astype(np.int64)
+        np.hstack([
+            encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
+        ]).astype(np.int64)
     )[None]
     assert x.shape[1] == 256 + 256 + 1
     with _inference_mode():
@@ -417,8 +439,9 @@
                 x_input = x[:, [-1]]
             else:
                 x_input = x
-            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
+            logits, kv_cache = model(
+                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
+            )
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
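
With use_kv_caching enabled, only the newest token is fed through the model and the attention keys/values from earlier steps are passed back in via past_kv, so each decode step costs a single-token forward instead of re-encoding the whole prefix. A generic sketch of the pattern (model stands for any callable with this cache contract, not Bark's exact class):

import torch

def decode(model, x, n_steps, use_kv_caching=True):
    kv_cache = None
    for _ in range(n_steps):
        if use_kv_caching and kv_cache is not None:
            x_input = x[:, [-1]]  # only the newest token
        else:
            x_input = x           # full prefix: first step, or caching off
        logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
        next_tok = logits[:, -1].argmax(-1, keepdim=True)  # greedy, for brevity
        x = torch.cat([x, next_tok], dim=1)
    return x
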
@@ -442,7 +465,13 @@
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
+                # multinomial bugged on mps: shuttle to cpu if necessary
+                inf_device = probs.device
+                if probs.device.type == "mps":
+                    probs = probs.to("cpu")
                 item_next = torch.multinomial(probs, num_samples=1)
+                probs = probs.to(inf_device)
+                item_next = item_next.to(inf_device)
                 if allow_early_stop and (
                     item_next == SEMANTIC_VOCAB_SIZE
                     or (min_eos_p is not None and probs[-1] >= min_eos_p)
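
The same CPU shuttle recurs in generate_coarse and generate_fine below: at the time of this change, torch.multinomial could return incorrect samples on the "mps" backend, so sampling happens on the CPU and the result is moved back. The pattern in isolation:

import torch
import torch.nn.functional as F

def sample(logits: torch.Tensor) -> torch.Tensor:
    probs = F.softmax(logits, dim=-1)
    inf_device = probs.device
    if probs.device.type == "mps":
        probs = probs.to("cpu")  # multinomial is unreliable on mps
    item = torch.multinomial(probs, num_samples=1)
    return item.to(inf_device)   # resume the loop on the original device
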
@@ -465,6 +494,8 @@
             pbar_state = req_pbar_state
         pbar.close()
         out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
+    if OFFLOAD_CPU:
+        model.to("cpu")
     assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
     _clear_cuda_cache()
     return out
@@ -490,12 +521,10 @@ def generate_coarse(
     temp=0.7,
     top_k=None,
     top_p=None,
-    use_gpu=True,
     silent=False,
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
-    model=None,
-    use_kv_caching=False
+    use_kv_caching=False,
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -552,9 +581,15 @@
     else:
         x_semantic_history = np.array([], dtype=np.int32)
         x_coarse_history = np.array([], dtype=np.int32)
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="coarse")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    global models_devices
+    if "coarse" not in models:
+        preload_models()
+    model = models["coarse"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["coarse"])
+    device = next(model.parameters()).device
     # start loop
     n_steps = int(
         round(
@@ -626,7 +661,13 @@
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
+                # multinomial bugged on mps: shuttle to cpu if necessary
+                inf_device = probs.device
+                if probs.device.type == "mps":
+                    probs = probs.to("cpu")
                 item_next = torch.multinomial(probs, num_samples=1)
+                probs = probs.to(inf_device)
+                item_next = item_next.to(inf_device)
                 item_next += logit_start_idx
                 x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                 x_in = torch.cat((x_in, item_next[None]), dim=1)
@@ -634,6 +675,8 @@
                 n_step += 1
             del x_in
         del x_semantic_in
+    if OFFLOAD_CPU:
+        model.to("cpu")
     gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
     del x_coarse_in
     assert len(gen_coarse_arr) == n_steps
@@ -648,9 +691,7 @@ def generate_fine(
     x_coarse_gen,
     history_prompt=None,
     temp=0.5,
-    use_gpu=True,
     silent=True,
-    model=None,
 ):
     """Generate full audio codes from coarse audio codes."""
     assert (
@@ -679,9 +720,15 @@
     else:
         x_fine_history = None
     n_coarse = x_coarse_gen.shape[0]
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="fine")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    global models_devices
+    if "fine" not in models:
+        preload_models()
+    model = models["fine"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["fine"])
+    device = next(model.parameters()).device
     # make input arr
     in_arr = np.vstack(
         [
@@ -729,10 +776,14 @@
                 else:
                     relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
                     probs = F.softmax(relevant_logits, dim=-1)
+                    # multinomial bugged on mps: shuttle to cpu if necessary
+                    inf_device = probs.device
+                    if probs.device.type == "mps":
+                        probs = probs.to("cpu")
                     codebook_preds = torch.hstack(
                         [
-                            torch.multinomial(probs[n], num_samples=1)
-                            for n in range(rel_start_fill_idx, 1024)
+                            torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
+                            for nnn in range(rel_start_fill_idx, 1024)
                         ]
                     )
                 in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
@@ -745,6 +796,8 @@
             del in_buffer
     gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
     del in_arr
+    if OFFLOAD_CPU:
+        model.to("cpu")
     gen_fine_arr = gen_fine_arr[:, n_history:]
     if n_remove_from_end > 0:
         gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
@@ -753,11 +806,17 @@
     return gen_fine_arr
 
 
-def codec_decode(fine_tokens, model=None, use_gpu=True):
+def codec_decode(fine_tokens):
     """Turn quantized audio codes into audio array using encodec."""
-    if model is None:
-        model = load_codec_model(use_gpu=use_gpu)
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    global models_devices
+    if "codec" not in models:
+        preload_models()
+    model = models["codec"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["codec"])
+    device = next(model.parameters()).device
     arr = torch.from_numpy(fine_tokens)[None]
     arr = arr.to(device)
     arr = arr.transpose(0, 1)
@@ -765,4 +824,6 @@ def codec_decode(fine_tokens, model=None, use_gpu=True):
     out = model.decoder(emb)
     audio_arr = out.detach().cpu().numpy().squeeze()
     del arr, emb, out
+    if OFFLOAD_CPU:
+        model.to("cpu")
     return audio_arr
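
Taken together, the four stages now share one module-level model cache plus the models_devices bookkeeping. A sketch of the low-level pipeline these functions form (the package's generate_audio wrapper does the equivalent):

from bark.generation import (
    codec_decode, generate_coarse, generate_fine, generate_text_semantic,
)

semantic = generate_text_semantic("Hello world", use_kv_caching=True)
coarse = generate_coarse(semantic, use_kv_caching=True)
fine = generate_fine(coarse)
audio = codec_decode(fine)  # numpy waveform at the codec's 24 kHz rate
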

bark/model.py

@@ -200,7 +200,6 @@ class GPT(nn.Module):
         pos_emb = self.transformer.wpe(position_ids)  # position embeddings of shape (1, t, n_embd)
         x = self.transformer.drop(tok_emb + pos_emb)
-
         new_kv = () if use_cache else None