From 6c26fb7b3463334ae9fb4d63dae52f3c29506db0 Mon Sep 17 00:00:00 2001
From: Georg Kucsko
Date: Tue, 25 Apr 2023 17:49:35 -0400
Subject: [PATCH] simplify device placement

---
 bark/generation.py | 123 ++++++++++++++++++++++++++++++---------
 bark/model.py      |   1 -
 2 files changed, 82 insertions(+), 42 deletions(-)

diff --git a/bark/generation.py b/bark/generation.py
index 4aa805e..28d963c 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -83,6 +83,7 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 
 
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
+GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
 
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
 
@@ -114,10 +115,10 @@ REMOTE_MODEL_PATHS = {
 }
 
 
-if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
     logger.warning(
-        "torch version does not support flash attention. You will get significantly faster" +
-        " inference speed by upgrade torch to newest version / nightly."
+        "torch version does not support flash attention. You will get faster" +
+        " inference speed by upgrading torch to the newest nightly version."
     )
 
 
@@ -141,6 +142,16 @@ def _get_ckpt_path(model_type, use_small=False):
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
+def _grab_best_device(use_gpu=True):
+    if torch.cuda.device_count() > 0 and use_gpu:
+        device = "cuda"
+    elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
+        device = "mps"
+    else:
+        device = "cpu"
+    return device
+
+
 S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
 
 
@@ -207,8 +218,6 @@ def clean_models(model_key=None):
 
 
 def _load_model(ckpt_path, device, use_small=False, model_type="text"):
-    if "cuda" not in device:
-        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
         ConfigClass = GPTConfig
         ModelClass = GPT
@@ -285,30 +294,32 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
-        device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__{model_type}"
+    device = _grab_best_device(use_gpu=use_gpu)
+    model_key = f"{model_type}"
     if model_key not in models or force_reload:
         ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
+    if model_type == "text":
+        models[model_key]["model"].to(device)
+    else:
+        models[model_key].to(device)
     return models[model_key]
 
 
 def load_codec_model(use_gpu=True, force_reload=False):
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
+    device = _grab_best_device(use_gpu=use_gpu)
+    if device == "mps":
+        # encodec doesn't support mps
         device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__codec"
+    model_key = "codec"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
         model = _load_codec_model(device)
         models[model_key] = model
+    models[model_key].to(device)
     return models[model_key]
 
 
@@ -322,6 +333,11 @@ def preload_models(
     codec_use_gpu=True,
     force_reload=False,
 ):
+    """Load all the necessary models for the pipeline."""
+    if _grab_best_device() == "cpu" and (
+        text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
+    ):
+        logger.warning("No GPU being used. Careful, inference might be very slow!")
Careful, inference might be very slow!") _ = load_model( model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload ) @@ -366,13 +382,11 @@ def generate_text_semantic( temp=0.7, top_k=None, top_p=None, - use_gpu=True, silent=False, min_eos_p=0.2, max_gen_duration_s=None, allow_early_stop=True, - model=None, - use_kv_caching=False + use_kv_caching=False, ): """Generate semantic tokens from text.""" assert isinstance(text, str) @@ -395,12 +409,15 @@ def generate_text_semantic( ) else: semantic_history = None - model_container = load_model(use_gpu=use_gpu, model_type="text") - if model is None: - model = model_container["model"] + # load models if not yet exist + global models + if "text" not in models: + preload_models() + model_container = models["text"] + model = model_container["model"] tokenizer = model_container["tokenizer"] encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET - device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu" + device = next(model.parameters()).device if len(encoded_text) > 256: p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1) logger.warning(f"warning, text too long, lopping of last {p}%") @@ -424,7 +441,9 @@ def generate_text_semantic( else: semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256) x = torch.from_numpy( - np.hstack([encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])]).astype(np.int64) + np.hstack([ + encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN]) + ]).astype(np.int64) )[None] assert x.shape[1] == 256 + 256 + 1 with _inference_mode(): @@ -440,8 +459,9 @@ def generate_text_semantic( x_input = x[:, [-1]] else: x_input = x - - logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache) + logits, kv_cache = model( + x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache + ) relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] if allow_early_stop: relevant_logits = torch.hstack( @@ -465,7 +485,13 @@ def generate_text_semantic( v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) relevant_logits[relevant_logits < v[-1]] = -float("Inf") probs = F.softmax(relevant_logits / temp, dim=-1) + # multinomial bugged on mps: shuttle to cpu if necessary + inf_device = probs.device + if probs.device.type == "mps": + probs = probs.to("cpu") item_next = torch.multinomial(probs, num_samples=1) + probs = probs.to(inf_device) + item_next = item_next.to(inf_device) if allow_early_stop and ( item_next == SEMANTIC_VOCAB_SIZE or (min_eos_p is not None and probs[-1] >= min_eos_p) @@ -513,12 +539,10 @@ def generate_coarse( temp=0.7, top_k=None, top_p=None, - use_gpu=True, silent=False, max_coarse_history=630, # min 60 (faster), max 630 (more context) sliding_window_len=60, - model=None, - use_kv_caching=False + use_kv_caching=False, ): """Generate coarse audio codes from semantic tokens.""" assert ( @@ -576,9 +600,12 @@ def generate_coarse( else: x_semantic_history = np.array([], dtype=np.int32) x_coarse_history = np.array([], dtype=np.int32) - if model is None: - model = load_model(use_gpu=use_gpu, model_type="coarse") - device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu" + # load models if not yet exist + global models + if "coarse" not in models: + preload_models() + model = models["coarse"] + device = next(model.parameters()).device # start loop n_steps = int( round( @@ -650,7 +677,13 @@ def generate_coarse( v, _ = torch.topk(relevant_logits, min(top_k, 
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
+                # multinomial bugged on mps: shuttle to cpu if necessary
+                inf_device = probs.device
+                if probs.device.type == "mps":
+                    probs = probs.to("cpu")
                 item_next = torch.multinomial(probs, num_samples=1)
+                probs = probs.to(inf_device)
+                item_next = item_next.to(inf_device)
                 item_next += logit_start_idx
                 x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                 x_in = torch.cat((x_in, item_next[None]), dim=1)
@@ -672,9 +705,7 @@ def generate_fine(
     x_coarse_gen,
     history_prompt=None,
     temp=0.5,
-    use_gpu=True,
     silent=True,
-    model=None,
 ):
     """Generate full audio codes from coarse audio codes."""
     assert (
@@ -704,9 +735,12 @@ def generate_fine(
     else:
         x_fine_history = None
     n_coarse = x_coarse_gen.shape[0]
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="fine")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    if "fine" not in models:
+        preload_models()
+    model = models["fine"]
+    device = next(model.parameters()).device
     # make input arr
     in_arr = np.vstack(
         [
@@ -754,10 +788,14 @@ def generate_fine(
                 else:
                     relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
                     probs = F.softmax(relevant_logits, dim=-1)
+                    # multinomial bugged on mps: shuttle to cpu if necessary
+                    inf_device = probs.device
+                    if probs.device.type == "mps":
+                        probs = probs.to("cpu")
                     codebook_preds = torch.hstack(
                         [
-                            torch.multinomial(probs[n], num_samples=1)
-                            for n in range(rel_start_fill_idx, 1024)
+                            torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
+                            for nnn in range(rel_start_fill_idx, 1024)
                         ]
                     )
                     in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
@@ -778,11 +816,14 @@ def generate_fine(
     return gen_fine_arr
 
 
-def codec_decode(fine_tokens, model=None, use_gpu=True):
+def codec_decode(fine_tokens):
     """Turn quantized audio codes into audio array using encodec."""
-    if model is None:
-        model = load_codec_model(use_gpu=use_gpu)
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    if "codec" not in models:
+        preload_models()
+    model = models["codec"]
+    device = next(model.parameters()).device
     arr = torch.from_numpy(fine_tokens)[None]
     arr = arr.to(device)
     arr = arr.transpose(0, 1)
diff --git a/bark/model.py b/bark/model.py
index bb99932..457b49e 100644
--- a/bark/model.py
+++ b/bark/model.py
@@ -200,7 +200,6 @@ class GPT(nn.Module):
 
         pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
 
-
         x = self.transformer.drop(tok_emb + pos_emb)
 
         new_kv = () if use_cache else None
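
Note (usage sketch, not part of the patch): after this change, device placement is handled inside the library: _grab_best_device prefers "cuda", then "mps" (only when the SUNO_ENABLE_MPS environment variable is set), then "cpu", and the generate_* functions lazily load models from the global cache, so the removed use_gpu/model arguments simply disappear from call sites. The prompt text below is an illustrative assumption; the stage order follows the functions' own docstrings (text -> semantic tokens -> coarse codes -> fine codes -> audio).

import os

# Optional: opt in to Apple Silicon GPU support before importing bark.
os.environ["SUNO_ENABLE_MPS"] = "1"

from bark.generation import (
    preload_models,
    generate_text_semantic,
    generate_coarse,
    generate_fine,
    codec_decode,
)

# Optional: load everything up front; the generate_* functions also lazy-load.
preload_models()

text = "Hello, my name is Suno."  # example prompt (assumption)
semantic_tokens = generate_text_semantic(text, temp=0.7, use_kv_caching=True)
coarse_tokens = generate_coarse(semantic_tokens, temp=0.7, use_kv_caching=True)
fine_tokens = generate_fine(coarse_tokens, temp=0.5)
audio_array = codec_decode(fine_tokens)  # decoded waveform as a numpy array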
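
Note (workaround pattern): the repeated "multinomial bugged on mps: shuttle to cpu if necessary" blocks all implement the same idea. A stand-alone sketch of that pattern follows; the helper name sample_multinomial is hypothetical, since the patch inlines the logic at each sampling site rather than factoring it out.

import torch

def sample_multinomial(probs: torch.Tensor, num_samples: int = 1) -> torch.Tensor:
    # torch.multinomial was unreliable on the "mps" backend at the time,
    # so sample on the CPU and move the result back to the original device.
    inf_device = probs.device
    if inf_device.type == "mps":
        probs = probs.to("cpu")
    samples = torch.multinomial(probs, num_samples=num_samples)
    return samples.to(inf_device)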
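
Note (forcing CPU): with the per-call use_gpu flags gone, callers who want CPU inference now express that through preload_models' per-model flags (parameter names taken from the preload_models hunk above), or on Apple Silicon simply by leaving SUNO_ENABLE_MPS unset. A minimal sketch, assuming the package layout shown in this diff:

from bark.generation import preload_models

# Load every model on the CPU; subsequent generate_* calls reuse this cache.
preload_models(
    text_use_gpu=False,
    coarse_use_gpu=False,
    fine_use_gpu=False,
    codec_use_gpu=False,
)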