diff --git a/bark/generation.py b/bark/generation.py
index b273e17..0f25ad4 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -335,16 +335,6 @@ def preload_models(
     _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)


-####
-# Handle MPS immaturity in Pytorch
-####
-def _logits_to_device_float(logits):
-    if GLOBAL_ENABLE_MPS:
-        logits = logits.clone().detach().to("mps").to(torch.float)
-    else:
-        logits = logits.to(logits_device).type(logits_dtype)
-    return logits
-
 ####
 # Generation Functionality
 ####
@@ -472,8 +462,7 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = _logits_to_device_float(relevant_logits)
-
+                relevant_logits = relevant_logits.to(logits_device)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -669,8 +658,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = _logits_to_device_float(relevant_logits)
-
+                    relevant_logits = relevant_logits.to(logits_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
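
For context, here is a minimal standalone sketch of the top-p (nucleus) filtering round-trip this diff touches. The function name `top_p_filter` is hypothetical (in `generation.py` the same logic runs inline), and it assumes `logits_device` is captured from `relevant_logits.device` before the numpy detour; the point of the change is that a plain `.to(logits_device)` restores the original device (mps, cuda, or cpu) without the dedicated MPS helper and its extra float cast.

```python
# Minimal sketch, assuming a 1-D tensor of logits; not the exact bark code.
import numpy as np
import torch
from scipy.special import softmax


def top_p_filter(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
    logits_device = relevant_logits.device  # remember where the logits live
    # numpy round-trip is faster here; detach and move to CPU float32 first
    logits_np = relevant_logits.detach().cpu().type(torch.float32).numpy()
    sorted_indices = np.argsort(logits_np)[::-1]  # descending by logit
    sorted_logits = logits_np[sorted_indices]
    cumulative_probs = np.cumsum(softmax(sorted_logits))
    sorted_indices_to_remove = cumulative_probs > top_p
    # shift right so the first token crossing the threshold is kept
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
    sorted_indices_to_remove[0] = False
    logits_np[sorted_indices[sorted_indices_to_remove]] = -np.inf
    # hop back to torch and restore the original device; .to(device)
    # preserves dtype, so no separate float cast is needed
    return torch.from_numpy(logits_np).to(logits_device)
```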