Fix to address pytorch Tensor.type() having missing support for MPS.

See https://github.com/pytorch/pytorch/issues/78929
This commit is contained in:
Raf Gemmail
2023-05-04 11:39:31 +12:00
parent b21d8c6c2c
commit 89b362b4a6

View File

@@ -663,7 +663,11 @@ def generate_coarse(
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits)
relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
if GLOBAL_ENABLE_MPS:
relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
else:
relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
if top_k is not None:
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")