From c95350f13f10b32cd7969cfa6c6f8c5bfc6ec69a Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Thu, 4 May 2023 12:57:48 +1200
Subject: [PATCH] DRY: refactor MPS logits device/dtype handling into a helper

---
 bark/generation.py | 25 +++++++++++++++----------
 1 file changed, 15 insertions(+), 10 deletions(-)

diff --git a/bark/generation.py b/bark/generation.py
index f93e897..b273e17 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -335,11 +335,22 @@ def preload_models(
     _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
+####
+# Handle MPS immaturity in PyTorch
+####
+def _logits_to_device_float(logits, logits_device, logits_dtype):
+    # MPS half-precision support is still immature, so force float32 on "mps";
+    # otherwise restore the device/dtype the caller captured earlier.
+    if GLOBAL_ENABLE_MPS:
+        logits = logits.clone().detach().to("mps").to(torch.float)
+    else:
+        logits = logits.to(logits_device).type(logits_dtype)
+    return logits
+
 ####
 # Generation Functionality
 ####
 
-
 def _tokenize(tokenizer, text):
     return tokenizer.encode(text, add_special_tokens=False)
 
@@ -351,7 +362,6 @@ def _detokenize(tokenizer, enc_text):
 def _normalize_whitespace(text):
     return re.sub(r"\s+", " ", text).strip()
 
-
 TEXT_ENCODING_OFFSET = 10_048
 SEMANTIC_PAD_TOKEN = 10_000
 TEXT_PAD_TOKEN = 129_595
@@ -464,10 +474,8 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                if GLOBAL_ENABLE_MPS:
-                    relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
-                else:
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = _logits_to_device_float(relevant_logits, logits_device, logits_dtype)
+
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -663,10 +671,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    if GLOBAL_ENABLE_MPS:
-                        relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
-                    else:
-                        relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    relevant_logits = _logits_to_device_float(relevant_logits, logits_device, logits_dtype)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
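
Note for reviewers: a minimal standalone sketch of the pattern this patch
factors out. The helper takes logits_device/logits_dtype explicitly because
both call sites capture them as locals before the numpy round-trip; they are
not module globals. GLOBAL_ENABLE_MPS stands in for the module-level flag in
bark/generation.py (defined here from torch.backends.mps.is_available() only
to keep the sketch self-contained), and the demo tensor and the stand-in
masking step are purely illustrative, not the real top-p filter.

    import numpy as np
    import torch

    GLOBAL_ENABLE_MPS = torch.backends.mps.is_available()

    def _logits_to_device_float(logits, logits_device, logits_dtype):
        # MPS half-precision kernels are still immature: force float32 on "mps"
        if GLOBAL_ENABLE_MPS:
            return logits.clone().detach().to("mps").to(torch.float)
        # otherwise restore the device/dtype captured before the numpy round-trip
        return logits.to(logits_device).type(logits_dtype)

    # Callers record device/dtype, drop to numpy for filtering, then move the
    # filtered logits back in a single call:
    relevant_logits = torch.randn(10_048)
    logits_device = relevant_logits.device
    logits_dtype = relevant_logits.type()
    filtered = relevant_logits.detach().cpu().type(torch.float32).numpy()
    filtered[filtered < filtered.mean()] = -np.inf  # stand-in for top-p masking
    relevant_logits = _logits_to_device_float(
        torch.from_numpy(filtered), logits_device, logits_dtype
    )

Keeping the branch in one place also means the clone().detach() form (which
replaces the deprecated torch.tensor(tensor, device=...) construction on the
MPS path) only has to be maintained once.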