From c84f71e187f2563096741bdea1a3b29f22e7cba0 Mon Sep 17 00:00:00 2001
From: Raf Gemmail
Date: Sat, 13 May 2023 14:33:43 +1200
Subject: [PATCH] preserve device type in a temporary variable

---
 bark/generation.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/bark/generation.py b/bark/generation.py
index 0f25ad4..c9d4317 100644
--- a/bark/generation.py
+++ b/bark/generation.py
@@ -451,8 +451,7 @@ def generate_text_semantic(
                 )
             if top_p is not None:
                 # faster to convert to numpy
-                logits_device = relevant_logits.device
-                logits_dtype = relevant_logits.type()
+                original_device = relevant_logits.device
                 relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                 sorted_indices = np.argsort(relevant_logits)[::-1]
                 sorted_logits = relevant_logits[sorted_indices]
@@ -462,7 +461,7 @@
                 sorted_indices_to_remove[0] = False
                 relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                 relevant_logits = torch.from_numpy(relevant_logits)
-                relevant_logits = relevant_logits.to(logits_device)
+                relevant_logits = relevant_logits.to(original_device)
             if top_k is not None:
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
@@ -647,8 +646,7 @@ def generate_coarse(
                 relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                 if top_p is not None:
                     # faster to convert to numpy
-                    logits_device = relevant_logits.device
-                    logits_dtype = relevant_logits.type()
+                    original_device = relevant_logits.device
                     relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                     sorted_indices = np.argsort(relevant_logits)[::-1]
                     sorted_logits = relevant_logits[sorted_indices]
@@ -658,7 +656,7 @@
                     sorted_indices_to_remove[0] = False
                     relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                     relevant_logits = torch.from_numpy(relevant_logits)
-                    relevant_logits = relevant_logits.to(logits_device)
+                    relevant_logits = relevant_logits.to(original_device)
                     if top_k is not None:
                         v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                         relevant_logits[relevant_logits < v[-1]] = -float("Inf")
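
For context, the hunks above implement nucleus (top-p) filtering by round-tripping
the logits through numpy: the tensor is moved to the CPU as float32, the tail of
the sorted distribution is masked to -inf, and the result is moved back to
whichever device it started on. The removed logits_dtype was assigned but never
read (the code never casts back, so the filtered logits come back as float32
either way), which is why the patch can drop it. Below is a minimal standalone
sketch of the same pattern; the helper name top_p_filter is illustrative, and an
inline numpy softmax stands in for the softmax call generation.py imports.

    import numpy as np
    import torch

    def top_p_filter(relevant_logits: torch.Tensor, top_p: float) -> torch.Tensor:
        # Remember where the tensor lives before the numpy round trip,
        # exactly as the patched code does with original_device.
        original_device = relevant_logits.device
        logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
        # Sort descending, then softmax the sorted logits to get probabilities.
        sorted_indices = np.argsort(logits)[::-1]
        sorted_logits = logits[sorted_indices]
        probs = np.exp(sorted_logits - sorted_logits.max())
        probs /= probs.sum()
        # Mark everything past the top_p cumulative-probability cutoff.
        sorted_indices_to_remove = np.cumsum(probs) > top_p
        # Shift right so the token that crosses the threshold is kept,
        # and never remove the most likely token.
        sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
        sorted_indices_to_remove[0] = False
        logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
        # Return to the original device; the dtype stays float32, matching
        # the behaviour of the patched code.
        return torch.from_numpy(logits).to(original_device)

Because the returned tensor is back on its original device, the top_k branch
shown in the hunks above can run torch.topk on it directly, with no extra
transfer.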