mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-15 03:07:58 +01:00
simplified fix for pytorch MPS .type() issue
This commit is contained in:
@@ -335,16 +335,6 @@ def preload_models(
|
||||
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
|
||||
|
||||
|
||||
####
|
||||
# Handle MPS immaturity in Pytorch
|
||||
####
|
||||
def _logits_to_device_float(logits):
|
||||
if GLOBAL_ENABLE_MPS:
|
||||
logits = logits.clone().detach().to("mps").to(torch.float)
|
||||
else:
|
||||
logits = logits.to(logits_device).type(logits_dtype)
|
||||
return logits
|
||||
|
||||
####
|
||||
# Generation Functionality
|
||||
####
|
||||
@@ -472,8 +462,7 @@ def generate_text_semantic(
|
||||
sorted_indices_to_remove[0] = False
|
||||
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
|
||||
relevant_logits = torch.from_numpy(relevant_logits)
|
||||
relevant_logits = _logits_to_device_float(relevant_logits)
|
||||
|
||||
relevant_logits = relevant_logits.to(logits_device)
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
||||
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
||||
@@ -669,8 +658,7 @@ def generate_coarse(
|
||||
sorted_indices_to_remove[0] = False
|
||||
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
|
||||
relevant_logits = torch.from_numpy(relevant_logits)
|
||||
relevant_logits = _logits_to_device_float(relevant_logits)
|
||||
|
||||
relevant_logits = relevant_logits.to(logits_device)
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
||||
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
||||
|
||||
Reference in New Issue
Block a user