mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-16 03:38:01 +01:00
preserve device type in a temporary variable
This commit is contained in:
@@ -451,8 +451,7 @@ def generate_text_semantic(
|
|||||||
)
|
)
|
||||||
if top_p is not None:
|
if top_p is not None:
|
||||||
# faster to convert to numpy
|
# faster to convert to numpy
|
||||||
logits_device = relevant_logits.device
|
original_device = relevant_logits.device
|
||||||
logits_dtype = relevant_logits.type()
|
|
||||||
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
|
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
|
||||||
sorted_indices = np.argsort(relevant_logits)[::-1]
|
sorted_indices = np.argsort(relevant_logits)[::-1]
|
||||||
sorted_logits = relevant_logits[sorted_indices]
|
sorted_logits = relevant_logits[sorted_indices]
|
||||||
@@ -462,7 +461,7 @@ def generate_text_semantic(
|
|||||||
sorted_indices_to_remove[0] = False
|
sorted_indices_to_remove[0] = False
|
||||||
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
|
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
|
||||||
relevant_logits = torch.from_numpy(relevant_logits)
|
relevant_logits = torch.from_numpy(relevant_logits)
|
||||||
relevant_logits = relevant_logits.to(logits_device)
|
relevant_logits = relevant_logits.to(original_device)
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
||||||
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
||||||
@@ -647,8 +646,7 @@ def generate_coarse(
|
|||||||
relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
|
relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
|
||||||
if top_p is not None:
|
if top_p is not None:
|
||||||
# faster to convert to numpy
|
# faster to convert to numpy
|
||||||
logits_device = relevant_logits.device
|
original_device = relevant_logits.device
|
||||||
logits_dtype = relevant_logits.type()
|
|
||||||
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
|
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
|
||||||
sorted_indices = np.argsort(relevant_logits)[::-1]
|
sorted_indices = np.argsort(relevant_logits)[::-1]
|
||||||
sorted_logits = relevant_logits[sorted_indices]
|
sorted_logits = relevant_logits[sorted_indices]
|
||||||
@@ -658,7 +656,7 @@ def generate_coarse(
|
|||||||
sorted_indices_to_remove[0] = False
|
sorted_indices_to_remove[0] = False
|
||||||
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
|
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
|
||||||
relevant_logits = torch.from_numpy(relevant_logits)
|
relevant_logits = torch.from_numpy(relevant_logits)
|
||||||
relevant_logits = relevant_logits.to(logits_device)
|
relevant_logits = relevant_logits.to(original_device)
|
||||||
if top_k is not None:
|
if top_k is not None:
|
||||||
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
||||||
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
||||||
|
|||||||
Reference in New Issue
Block a user