diff --git a/bark/generation.py b/bark/generation.py
index d1c7a6b..c9d4317 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -339,7 +339,6 @@ def preload_models(
 # Generation Functionality
 ####
 
-
 def _tokenize(tokenizer, text):
     return tokenizer.encode(text, add_special_tokens=False)
 
@@ -351,7 +350,6 @@ def _detokenize(tokenizer, enc_text):
 def _normalize_whitespace(text):
     return re.sub(r"\s+", " ", text).strip()
 
-
 TEXT_ENCODING_OFFSET = 10_048
 SEMANTIC_PAD_TOKEN = 10_000
 TEXT_PAD_TOKEN = 129_595
@@ -453,8 +451,7 @@ def generate_text_semantic(
                 )
             if top_p is not None:
                 # faster to convert to numpy
-                logits_device = relevant_logits.device
-                logits_dtype = relevant_logits.type()
+                original_device = relevant_logits.device
                 relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                 sorted_indices = np.argsort(relevant_logits)[::-1]
                 sorted_logits = relevant_logits[sorted_indices]
@@ -464,7 +461,7 @@ def generate_text_semantic(
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                relevant_logits = relevant_logits.to(original_device)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -649,8 +646,7 @@ def generate_coarse(
                 relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                 if top_p is not None:
                     # faster to convert to numpy
-                    logits_device = relevant_logits.device
-                    logits_dtype = relevant_logits.type()
+                    original_device = relevant_logits.device
                     relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                     sorted_indices = np.argsort(relevant_logits)[::-1]
                     sorted_logits = relevant_logits[sorted_indices]
@@ -660,7 +656,7 @@ def generate_coarse(
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
+                    relevant_logits = relevant_logits.to(original_device)
                 if top_k is not None:
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
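Context for the change: `Tensor.type()` called with no argument returns a type string such as `'torch.cuda.FloatTensor'`, which encodes the device as well as the dtype, so the old `.to(logits_device).type(logits_dtype)` chain was partly redundant; and since the NumPy round-trip already promotes the logits to float32, which the downstream softmax/sampling consumes without issue, only the device needs to be restored. Below is a minimal standalone sketch of the resulting top-p (nucleus) filtering pattern; the helper name `top_p_filter` is hypothetical and not part of bark, but the body mirrors the post-patch logic.

```python
import numpy as np
import torch
from scipy.special import softmax  # bark imports softmax from scipy.special


def top_p_filter(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Sketch of nucleus (top-p) filtering via a NumPy round-trip.

    Hypothetical helper mirroring the patched code path, not bark's API.
    """
    # remember only the device; the float32 dtype from the round-trip is kept
    original_device = relevant_logits.device
    logits = relevant_logits.detach().cpu().type(torch.float32).numpy()

    # sort logits descending and find where cumulative probability exceeds top_p
    sorted_indices = np.argsort(logits)[::-1]
    cumulative_probs = np.cumsum(softmax(logits[sorted_indices]))
    sorted_indices_to_remove = cumulative_probs > top_p
    # shift right so the first token crossing the threshold is still kept
    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
    sorted_indices_to_remove[0] = False

    # mask out the tail of the distribution and move back to the original device
    logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
    return torch.from_numpy(logits).to(original_device)
```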