diff --git a/bark/generation.py b/bark/generation.py
index 67e1932..790307b 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -14,6 +14,7 @@ import torch
 import torch.nn.functional as F
 import tqdm
 from transformers import BertTokenizer
+from huggingface_hub import hf_hub_download
 
 from .model import GPTConfig, GPT
 from .model_fine import FineGPT, FineGPTConfig
@@ -36,6 +37,9 @@ else:
 
 global models
 models = {}
+global models_devices
+models_devices = {}
+
 
 
 CONTEXT_WINDOW_SIZE = 1024
@@ -61,41 +65,48 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 
 
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
+GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
+OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
 
-REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
 
 REMOTE_MODEL_PATHS = {
     "text_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "text.pt",
         "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
     },
     "coarse_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "coarse.pt",
         "checksum": "5fe964825e3b0321f9d5f3857b89194d",
     },
     "fine_small": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "fine.pt",
         "checksum": "5428d1befe05be2ba32195496e58dc90",
     },
     "text": {
-        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "text_2.pt",
         "checksum": "54afa89d65e318d4f5f80e8e8799026a",
     },
     "coarse": {
-        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "coarse_2.pt",
         "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
     },
     "fine": {
-        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
+        "repo_id": "suno/bark",
+        "file_name": "fine_2.pt",
         "checksum": "59d184ed44e3650774a2f0503a48a97b",
     },
 }
 
 
-if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
     logger.warning(
-        "torch version does not support flash attention. You will get significantly faster" +
-        " inference speed by upgrade torch to newest version / nightly."
+        "torch version does not support flash attention. You will get faster" +
+        " inference speed by upgrading torch to the newest nightly version."
     )
 
 
@@ -115,33 +126,25 @@ def _md5(fname):
 
 def _get_ckpt_path(model_type, use_small=False):
     model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
-S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
+def _grab_best_device(use_gpu=True):
+    if torch.cuda.device_count() > 0 and use_gpu:
+        device = "cuda"
+    elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
+        device = "mps"
+    else:
+        device = "cpu"
+    return device
 
 
-def _parse_s3_filepath(s3_filepath):
-    bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1)
-    rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath)
-    return bucket_name, rel_s3_filepath
-
-
-def _download(from_s3_path, to_local_path):
+def _download(from_hf_path, file_name, to_local_path):
     os.makedirs(CACHE_DIR, exist_ok=True)
-    response = requests.get(from_s3_path, stream=True)
-    total_size_in_bytes = int(response.headers.get("content-length", 0))
-    block_size = 1024
-    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
-    with open(to_local_path, "wb") as file:
-        for data in response.iter_content(block_size):
-            progress_bar.update(len(data))
-            file.write(data)
-    progress_bar.close()
-    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
-        raise ValueError("ERROR, something went wrong")
-
+    destination_file_name = to_local_path.split("/")[-1]
+    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
+    os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
 
 class InferenceContext:
     def __init__(self, benchmark=False):
@@ -185,8 +188,6 @@ def clean_models(model_key=None):
 
 
 def _load_model(ckpt_path, device, use_small=False, model_type="text"):
-    if "cuda" not in device:
-        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
         ConfigClass = GPTConfig
         ModelClass = GPT
@@ -208,7 +209,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
         os.remove(ckpt_path)
     if not os.path.exists(ckpt_path):
         logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
-        _download(model_info["path"], ckpt_path)
+        _download(model_info["repo_id"], model_info["file_name"], ckpt_path)
     checkpoint = torch.load(ckpt_path, map_location=device)
     # this is a hack
     model_args = checkpoint["model_args"]
@@ -263,30 +264,40 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
+    global models_devices
+    device = _grab_best_device(use_gpu=use_gpu)
+    model_key = f"{model_type}"
+    if OFFLOAD_CPU:
+        models_devices[model_key] = device
         device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__{model_type}"
     if model_key not in models or force_reload:
         ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
        model = _load_model_f(ckpt_path, device)
         models[model_key] = model
+    if model_type == "text":
+        models[model_key]["model"].to(device)
+    else:
+        models[model_key].to(device)
     return models[model_key]
 
 
 def load_codec_model(use_gpu=True, force_reload=False):
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
+    global models_devices
+    device = _grab_best_device(use_gpu=use_gpu)
+    if device == "mps":
+        # encodec doesn't support mps
+        device = "cpu"
+    model_key = "codec"
+    if OFFLOAD_CPU:
+        models_devices[model_key] = device
         device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__codec"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
         model = _load_codec_model(device)
         models[model_key] = model
+    models[model_key].to(device)
     return models[model_key]
 
 
@@ -300,6 +311,11 @@ def preload_models(
     codec_use_gpu=True,
     force_reload=False,
 ):
+    """Load all the necessary models for the pipeline."""
+    if _grab_best_device() == "cpu" and (
+        text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
+    ):
+        logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
         model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
@@ -344,13 +360,11 @@ def generate_text_semantic(
     temp=0.7,
     top_k=None,
     top_p=None,
-    use_gpu=True,
     silent=False,
     min_eos_p=0.2,
     max_gen_duration_s=None,
     allow_early_stop=True,
-    model=None,
-    use_kv_caching=False
+    use_kv_caching=False,
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -372,12 +386,18 @@ def generate_text_semantic(
         )
     else:
         semantic_history = None
-    model_container = load_model(use_gpu=use_gpu, model_type="text")
-    if model is None:
-        model = model_container["model"]
+    # load models if they don't yet exist
+    global models
+    global models_devices
+    if "text" not in models:
+        preload_models()
+    model_container = models["text"]
+    model = model_container["model"]
     tokenizer = model_container["tokenizer"]
     encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    if OFFLOAD_CPU:
+        model.to(models_devices["text"])
+    device = next(model.parameters()).device
     if len(encoded_text) > 256:
         p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
         logger.warning(f"warning, text too long, lopping of last {p}%")
@@ -401,7 +421,9 @@ def generate_text_semantic(
     else:
         semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
     x = torch.from_numpy(
-        np.hstack([encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])]).astype(np.int64)
+        np.hstack([
+            encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
+        ]).astype(np.int64)
     )[None]
     assert x.shape[1] == 256 + 256 + 1
     with _inference_mode():
@@ -417,8 +439,9 @@ def generate_text_semantic(
                 x_input = x[:, [-1]]
             else:
                 x_input = x
-
-            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
+            logits, kv_cache = model(
+                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
+            )
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
@@ -442,7 +465,13 @@ def generate_text_semantic(
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
             probs = F.softmax(relevant_logits / temp, dim=-1)
+            # multinomial bugged on mps: shuttle to cpu if necessary
+            inf_device = probs.device
+            if probs.device.type == "mps":
+                probs = probs.to("cpu")
             item_next = torch.multinomial(probs, num_samples=1)
+            probs = probs.to(inf_device)
+            item_next = item_next.to(inf_device)
             if allow_early_stop and (
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
@@ -465,6 +494,8 @@ def generate_text_semantic(
             pbar_state = req_pbar_state
         pbar.close()
         out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
+    if OFFLOAD_CPU:
+        model.to("cpu")
     assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
     _clear_cuda_cache()
     return out
@@ -490,12 +521,10 @@ def generate_coarse(
     temp=0.7,
     top_k=None,
     top_p=None,
-    use_gpu=True,
     silent=False,
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
-    model=None,
-    use_kv_caching=False
+    use_kv_caching=False,
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -552,9 +581,15 @@ def generate_coarse(
     else:
         x_semantic_history = np.array([], dtype=np.int32)
         x_coarse_history = np.array([], dtype=np.int32)
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="coarse")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if they don't yet exist
+    global models
+    global models_devices
+    if "coarse" not in models:
+        preload_models()
+    model = models["coarse"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["coarse"])
+    device = next(model.parameters()).device
     # start loop
     n_steps = int(
         round(
@@ -626,7 +661,13 @@ def generate_coarse(
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
+                # multinomial bugged on mps: shuttle to cpu if necessary
+                inf_device = probs.device
+                if probs.device.type == "mps":
+                    probs = probs.to("cpu")
                 item_next = torch.multinomial(probs, num_samples=1)
+                probs = probs.to(inf_device)
+                item_next = item_next.to(inf_device)
                 item_next += logit_start_idx
                 x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                 x_in = torch.cat((x_in, item_next[None]), dim=1)
@@ -634,6 +675,8 @@ def generate_coarse(
                 n_step += 1
             del x_in
         del x_semantic_in
+    if OFFLOAD_CPU:
+        model.to("cpu")
     gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
     del x_coarse_in
     assert len(gen_coarse_arr) == n_steps
@@ -648,9 +691,7 @@ def generate_fine(
     x_coarse_gen,
     history_prompt=None,
     temp=0.5,
-    use_gpu=True,
     silent=True,
-    model=None,
 ):
     """Generate full audio codes from coarse audio codes."""
     assert (
@@ -679,9 +720,15 @@ def generate_fine(
     else:
         x_fine_history = None
     n_coarse = x_coarse_gen.shape[0]
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="fine")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if they don't yet exist
+    global models
+    global models_devices
+    if "fine" not in models:
+        preload_models()
+    model = models["fine"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["fine"])
+    device = next(model.parameters()).device
     # make input arr
     in_arr = np.vstack(
         [
@@ -729,10 +776,14 @@ def generate_fine(
                 else:
                     relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
                     probs = F.softmax(relevant_logits, dim=-1)
+                    # multinomial bugged on mps: shuttle to cpu if necessary
+                    inf_device = probs.device
+                    if probs.device.type == "mps":
+                        probs = probs.to("cpu")
                     codebook_preds = torch.hstack(
                         [
-                            torch.multinomial(probs[n], num_samples=1)
-                            for n in range(rel_start_fill_idx, 1024)
+                            torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
+                            for nnn in range(rel_start_fill_idx, 1024)
                         ]
                     )
                 in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
@@ -745,6 +796,8 @@ def generate_fine(
             del in_buffer
     gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
     del in_arr
+    if OFFLOAD_CPU:
+        model.to("cpu")
     gen_fine_arr = gen_fine_arr[:, n_history:]
     if n_remove_from_end > 0:
         gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
@@ -753,11 +806,17 @@ def generate_fine(
     return gen_fine_arr
 
 
-def codec_decode(fine_tokens, model=None, use_gpu=True):
+def codec_decode(fine_tokens):
     """Turn quantized audio codes into audio array using encodec."""
-    if model is None:
-        model = load_codec_model(use_gpu=use_gpu)
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if they don't yet exist
+    global models
+    global models_devices
+    if "codec" not in models:
+        preload_models()
+    model = models["codec"]
+    if OFFLOAD_CPU:
+        model.to(models_devices["codec"])
+    device = next(model.parameters()).device
     arr = torch.from_numpy(fine_tokens)[None]
     arr = arr.to(device)
     arr = arr.transpose(0, 1)
@@ -765,4 +824,6 @@ def codec_decode(fine_tokens, model=None, use_gpu=True):
     out = model.decoder(emb)
     audio_arr = out.detach().cpu().numpy().squeeze()
     del arr, emb, out
+    if OFFLOAD_CPU:
+        model.to("cpu")
     return audio_arr
diff --git a/bark/model.py b/bark/model.py
index bb99932..457b49e 100644
--- a/bark/model.py
+++ b/bark/model.py
@@ -200,7 +200,6 @@ class GPT(nn.Module):
 
         pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
 
-
         x = self.transformer.drop(tok_emb + pos_emb)
 
         new_kv = () if use_cache else None
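
Usage note (not part of the patch): below is a minimal sketch of how the switches introduced above are meant to be driven, assuming the bark package layout implied by the file paths; the prompt text and flag values are illustrative only. Because GLOBAL_ENABLE_MPS and OFFLOAD_CPU are read once at module import via os.environ.get, the variables must be set before bark.generation is imported, and any non-empty string (even "0") counts as enabled.

# Sketch only: exercises the new SUNO_ENABLE_MPS / SUNO_OFFLOAD_CPU switches.
# The flags are read at import time, so set them before importing bark.generation.
import os

os.environ["SUNO_ENABLE_MPS"] = "1"   # allow the "mps" device on Apple Silicon (the codec still falls back to CPU there)
os.environ["SUNO_OFFLOAD_CPU"] = "1"  # park each model on the CPU while it is not generating

from bark.generation import (
    preload_models,
    generate_text_semantic,
    generate_coarse,
    generate_fine,
    codec_decode,
)

preload_models()  # checkpoints are now fetched from the suno/bark Hugging Face repo via hf_hub_download

# Full pipeline; with offloading enabled each stage moves its model to
# models_devices[...] on entry and back to "cpu" on exit.
semantic_tokens = generate_text_semantic("Hello world.")  # illustrative prompt
coarse_tokens = generate_coarse(semantic_tokens)
fine_tokens = generate_fine(coarse_tokens)
audio_array = codec_decode(fine_tokens)                   # raw audio as a numpy array

With offloading enabled, every generate_* call pays a host-to-device transfer for its model, trading throughput for lower peak GPU memory; leaving SUNO_OFFLOAD_CPU unset keeps the previous behaviour of pinning each loaded model to the selected device.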