mirror of https://github.com/serp-ai/bark-with-voice-clone.git
DRY: refactored MPS-supporting dtype handling into a shared helper.
@@ -335,11 +335,20 @@ def preload_models(
     _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
 
 
+####
+# Handle MPS immaturity in Pytorch
+####
+def _logits_to_device_float(logits):
+    if GLOBAL_ENABLE_MPS:
+        logits = logits.clone().detach().to("mps").to(torch.float)
+    else:
+        logits = logits.to(logits_device).type(logits_dtype)
+    return logits
 ####
 # Generation Functionality
 ####
 
 
 def _tokenize(tokenizer, text):
     return tokenizer.encode(text, add_special_tokens=False)
 
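Besides deduplicating the two call sites in the hunks below, the helper also changes the MPS path from torch.tensor(relevant_logits, device="mps") to logits.clone().detach().to("mps"). Copy-constructing from an existing tensor with torch.tensor(...) makes PyTorch emit a UserWarning recommending exactly the clone().detach() form. A minimal sketch of the two patterns, with "cpu" standing in for "mps" so it runs without Apple Silicon:

import torch

x = torch.arange(4, dtype=torch.float64)

# Old pattern: copy-constructing from an existing tensor. PyTorch emits a
# UserWarning here, recommending sourceTensor.clone().detach() instead of
# torch.tensor(sourceTensor).
y = torch.tensor(x, device="cpu").to(torch.float)

# New pattern from the diff: an explicit copy, detached from any autograd
# graph, then moved to the target device and down-cast to float32.
z = x.clone().detach().to("cpu").to(torch.float)

assert torch.equal(y, z) and z.dtype == torch.float32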
@@ -351,7 +360,6 @@ def _detokenize(tokenizer, enc_text):
 def _normalize_whitespace(text):
     return re.sub(r"\s+", " ", text).strip()
 
-
 TEXT_ENCODING_OFFSET = 10_048
 SEMANTIC_PAD_TOKEN = 10_000
 TEXT_PAD_TOKEN = 129_595
@@ -464,10 +472,8 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                if GLOBAL_ENABLE_MPS:
-                    relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
-                else:
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = _logits_to_device_float(relevant_logits)
+
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -663,10 +669,7 @@ def generate_coarse(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                if GLOBAL_ENABLE_MPS:
-                    relevant_logits = torch.tensor(relevant_logits, device="mps").to(torch.float)
-                else:
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = _logits_to_device_float(relevant_logits)
 
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
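After the refactor, generate_text_semantic and generate_coarse share the single call above. A standalone sketch of the resulting sampling path; the module globals (GLOBAL_ENABLE_MPS, logits_device, logits_dtype) are stubbed here because the diff does not show their definitions, and the toy logits stand in for the model output:

import numpy as np
import torch
import torch.nn.functional as F

# Stubs for the module-level globals the helper reads; the real generation.py
# defines these elsewhere (not shown in the diff).
GLOBAL_ENABLE_MPS = False
logits_device = "cpu"
logits_dtype = torch.float32


def _logits_to_device_float(logits):
    # Same body as the helper introduced in the first hunk.
    if GLOBAL_ENABLE_MPS:
        logits = logits.clone().detach().to("mps").to(torch.float)
    else:
        logits = logits.to(logits_device).type(logits_dtype)
    return logits


# Toy logits; in the real code these are sliced out of the model's last-step
# logits and post-processed with top-p in numpy first.
relevant_logits = torch.from_numpy(np.random.randn(32).astype(np.float32))
relevant_logits = _logits_to_device_float(relevant_logits)

# Top-k filtering exactly as in the surrounding context lines, then sampling.
top_k = 8
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
probs = F.softmax(relevant_logits, dim=-1)
item_next = torch.multinomial(probs, num_samples=1)
print(item_next.item())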