From b21d8c6c2ccc7680cc07eb241dec4df84128c43c Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Thu, 4 May 2023 10:05:21 +1200
Subject: [PATCH 1/5] Force type as pytorch Tensor.type() has missing support
 for MPS. See https://github.com/pytorch/pytorch/issues/78929

---
 bark/generation.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/bark/generation.py b/bark/generation.py
index d1c7a6b..b7fbb9e 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -464,7 +464,10 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                if GLOBAL_ENABLE_MPS:
+                    relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
+                else:
+                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")

From 89b362b4a6c92364eec47f22642fc39b7d9ede1c Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Thu, 4 May 2023 11:39:31 +1200
Subject: [PATCH 2/5] Fix to address pytorch Tensor.type() having missing
 support for MPS. See https://github.com/pytorch/pytorch/issues/78929

---
 bark/generation.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/bark/generation.py b/bark/generation.py
index b7fbb9e..f93e897 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -663,7 +663,11 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    if GLOBAL_ENABLE_MPS:
+                        relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
+                    else:
+                        relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")

From c95350f13f10b32cd7969cfa6c6f8c5bfc6ec69a Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Thu, 4 May 2023 12:57:48 +1200
Subject: [PATCH 3/5] DRY: refactored mps supporting dtype.

---
 bark/generation.py | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/bark/generation.py b/bark/generation.py
index f93e897..b273e17 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -335,11 +335,20 @@ def preload_models(
     _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
+####
+# Handle MPS immaturity in Pytorch
+####
+def _logits_to_device_float(logits):
+    if GLOBAL_ENABLE_MPS:
+        logits = logits.clone().detach().to("mps").to(torch.float)
+    else:
+        logits = logits.to(logits_device).type(logits_dtype)
+    return logits
+
 ####
 # Generation Functionality
 ####
-
 def _tokenize(tokenizer, text):
     return tokenizer.encode(text, add_special_tokens=False)
 
@@ -351,7 +360,6 @@ def _detokenize(tokenizer, enc_text):
 def _normalize_whitespace(text):
     return re.sub(r"\s+", " ", text).strip()
 
-
 TEXT_ENCODING_OFFSET = 10_048
 SEMANTIC_PAD_TOKEN = 10_000
 TEXT_PAD_TOKEN = 129_595
@@ -464,10 +472,8 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                if GLOBAL_ENABLE_MPS:
-                    relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
-                else:
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = _logits_to_device_float(relevant_logits)
+
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -663,10 +669,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    if GLOBAL_ENABLE_MPS:
-                        relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
-                    else:
-                        relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    relevant_logits = _logits_to_device_float(relevant_logits)
 
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))

From 8660b7d7bbc2b3e63014ee62cca8773b301a7411 Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Sun, 7 May 2023 16:10:33 +1200
Subject: [PATCH 4/5] simplified fix for pytorch MPS .type() issue

---
 bark/generation.py | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/bark/generation.py b/bark/generation.py
index b273e17..0f25ad4 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -335,16 +335,6 @@ def preload_models(
     _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
-####
-# Handle MPS immaturity in Pytorch
-####
-def _logits_to_device_float(logits):
-    if GLOBAL_ENABLE_MPS:
-        logits = logits.clone().detach().to("mps").to(torch.float)
-    else:
-        logits = logits.to(logits_device).type(logits_dtype)
-    return logits
-
 ####
 # Generation Functionality
 ####
@@ -472,8 +462,7 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = _logits_to_device_float(relevant_logits)
-
+                relevant_logits = relevant_logits.to(logits_device)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -669,8 +658,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = _logits_to_device_float(relevant_logits)
-
+                    relevant_logits = relevant_logits.to(logits_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")

From c84f71e187f2563096741bdea1a3b29f22e7cba0 Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Sat, 13 May 2023 14:33:43 +1200
Subject: [PATCH 5/5] preserve device type in a temporary variable

---
 bark/generation.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/bark/generation.py b/bark/generation.py
index 0f25ad4..c9d4317 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -451,8 +451,7 @@ def generate_text_semantic(
                 )
             if top_p is not None:
                 # faster to convert to numpy
-                logits_device = relevant_logits.device
-                logits_dtype = relevant_logits.type()
+                original_device = relevant_logits.device
                 relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                 sorted_indices = np.argsort(relevant_logits)[::-1]
                 sorted_logits = relevant_logits[sorted_indices]
@@ -462,7 +461,7 @@
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device)
+                relevant_logits = relevant_logits.to(original_device)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -646,8 +646,7 @@ def generate_coarse(
                 relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                 if top_p is not None:
                     # faster to convert to numpy
-                    logits_device = relevant_logits.device
-                    logits_dtype = relevant_logits.type()
+                    original_device = relevant_logits.device
                     relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                     sorted_indices = np.argsort(relevant_logits)[::-1]
                     sorted_logits = relevant_logits[sorted_indices]
@@ -658,7 +656,7 @@
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device)
+                    relevant_logits = relevant_logits.to(original_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
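
Note on the end state of the series: after PATCH 5/5, generate_text_semantic
and generate_coarse both follow the same pattern for top-p filtering: record
the tensor's original device, do the filtering in numpy (the in-tree comment
notes this is faster), and move the result back with .to(original_device)
instead of Tensor.type(), which
https://github.com/pytorch/pytorch/issues/78929 reports as unsupported on
MPS. The sketch below is an illustrative, self-contained rendering of that
pattern, not code from the patches themselves: the helper name
top_p_filter_logits is hypothetical, and the right-shift of the removal mask
follows bark's upstream implementation, which the hunk context above
truncates.

    import numpy as np
    import torch
    from scipy.special import softmax


    def top_p_filter_logits(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
        # Remember where the logits live (cpu, cuda, or mps) so the result
        # can be moved back with .to(), avoiding the unsupported .type() call.
        original_device = relevant_logits.device
        # Faster to convert to numpy for the argsort/cumsum work.
        logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
        sorted_indices = np.argsort(logits)[::-1]
        sorted_logits = logits[sorted_indices]
        cumulative_probs = np.cumsum(softmax(sorted_logits))
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the token that crosses the threshold is kept.
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
        sorted_indices_to_remove[0] = False
        logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
        # A plain device move suffices; no dtype cast is needed on the way back.
        return torch.from_numpy(logits).to(original_device)

Called as relevant_logits = top_p_filter_logits(relevant_logits, top_p), this
matches the behaviour the final commit converges on for both sampling loops.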