From 89b362b4a6c92364eec47f22642fc39b7d9ede1c Mon Sep 17 00:00:00 2001 From: Raf Gemmail Date: Thu, 4 May 2023 11:39:31 +1200 Subject: [PATCH] Fix to address pytorch Tensor.type() having missing support for MPS. See https://github.com/pytorch/pytorch/issues/78929 --- bark/generation.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/bark/generation.py b/bark/generation.py index b7fbb9e..f93e897 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -663,7 +663,11 @@ def generate_coarse( sorted_indices_to_remove[0] = False relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf relevant_logits = torch.from_numpy(relevant_logits) - relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if GLOBAL_ENABLE_MPS: + relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float) + else: + relevant_logits = relevant_logits.to(logits_device).type(logits_dtype) + if top_k is not None: v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1))) relevant_logits[relevant_logits < v[-1]] = -float("Inf")