mirror of https://github.com/serp-ai/bark-with-voice-clone.git

simplify device placement
@@ -83,6 +83,7 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
 USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
+GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
 REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
@@ -114,10 +115,10 @@ REMOTE_MODEL_PATHS = {
 }
 
-if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
+if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
     logger.warning(
-        "torch version does not support flash attention. You will get significantly faster" +
-        " inference speed by upgrade torch to newest version / nightly."
+        "torch version does not support flash attention. You will get faster" +
+        " inference speed by upgrade torch to newest nightly version."
     )
@@ -141,6 +142,16 @@ def _get_ckpt_path(model_type, use_small=False):
     return os.path.join(CACHE_DIR, f"{model_name}.pt")
 
 
+def _grab_best_device(use_gpu=True):
+    if torch.cuda.device_count() > 0 and use_gpu:
+        device = "cuda"
+    elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
+        device = "mps"
+    else:
+        device = "cpu"
+    return device
+
+
 S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
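For context (not part of the diff), a minimal usage sketch of the new device selection: GLOBAL_ENABLE_MPS is read at import time from the SUNO_ENABLE_MPS environment variable, so it has to be set before bark is imported. The `bark.generation` module path is assumed here.

    import os

    # Opt in to Apple-silicon ("mps") inference before importing bark
    # (assumption based on how GLOBAL_ENABLE_MPS is defined above).
    os.environ["SUNO_ENABLE_MPS"] = "True"

    from bark.generation import _grab_best_device  # module path assumed

    # Resolution order per _grab_best_device: "cuda" if a GPU is visible,
    # else "mps" when available and enabled, else "cpu".
    print(_grab_best_device(use_gpu=True))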
@@ -207,8 +218,6 @@ def clean_models(model_key=None):
 
 
 def _load_model(ckpt_path, device, use_small=False, model_type="text"):
-    if "cuda" not in device:
-        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
     if model_type == "text":
         ConfigClass = GPTConfig
         ModelClass = GPT
@@ -285,30 +294,32 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
     if model_type not in ("text", "coarse", "fine"):
         raise NotImplementedError()
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
-        device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__{model_type}"
+    device = _grab_best_device(use_gpu=use_gpu)
+    model_key = f"{model_type}"
     if model_key not in models or force_reload:
         ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
         clean_models(model_key=model_key)
         model = _load_model_f(ckpt_path, device)
         models[model_key] = model
+    if model_type == "text":
+        models[model_key]["model"].to(device)
+    else:
+        models[model_key].to(device)
     return models[model_key]
 
 
 def load_codec_model(use_gpu=True, force_reload=False):
     global models
-    if torch.cuda.device_count() == 0 or not use_gpu:
-        device = "cpu"
-    else:
-        device = "cuda"
-    model_key = str(device) + f"__codec"
+    device = _grab_best_device(use_gpu=use_gpu)
+    if device == "mps":
+        # encodec doesn't support mps
+        device = "cpu"
+    model_key = "codec"
     if model_key not in models or force_reload:
         clean_models(model_key=model_key)
         model = _load_codec_model(device)
         models[model_key] = model
+    models[model_key].to(device)
     return models[model_key]
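As a reading aid (not part of the commit), a hedged sketch of the caching behaviour after this change: cache entries are keyed by model type alone, and each entry is moved onto whatever _grab_best_device() picks, with encodec falling back to CPU when only "mps" is available. The `bark.generation` module path is assumed.

    from bark.generation import load_model, load_codec_model  # module path assumed

    # The text entry is a dict holding model and tokenizer; the codec is a
    # plain module (as suggested by the .to(device) branches above).
    text_container = load_model(model_type="text", use_gpu=True)  # cached under models["text"]
    codec = load_codec_model(use_gpu=True)                        # cached under models["codec"]

    print(next(text_container["model"].parameters()).device)
    print(next(codec.parameters()).device)  # "cpu" on Apple silicon, since encodec skips mps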
@@ -322,6 +333,11 @@ def preload_models(
     codec_use_gpu=True,
     force_reload=False,
 ):
+    """Load all the necessary models for the pipeline."""
+    if _grab_best_device() == "cpu" and (
+        text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
+    ):
+        logger.warning("No GPU being used. Careful, inference might be very slow!")
     _ = load_model(
         model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
     )
@@ -366,13 +382,11 @@ def generate_text_semantic(
     temp=0.7,
     top_k=None,
     top_p=None,
-    use_gpu=True,
     silent=False,
     min_eos_p=0.2,
     max_gen_duration_s=None,
     allow_early_stop=True,
-    model=None,
-    use_kv_caching=False
+    use_kv_caching=False,
 ):
     """Generate semantic tokens from text."""
     assert isinstance(text, str)
@@ -395,12 +409,15 @@
         )
     else:
         semantic_history = None
-    model_container = load_model(use_gpu=use_gpu, model_type="text")
-    if model is None:
-        model = model_container["model"]
+    # load models if not yet exist
+    global models
+    if "text" not in models:
+        preload_models()
+    model_container = models["text"]
+    model = model_container["model"]
     tokenizer = model_container["tokenizer"]
     encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    device = next(model.parameters()).device
     if len(encoded_text) > 256:
         p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
         logger.warning(f"warning, text too long, lopping of last {p}%")
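The `next(model.parameters()).device` pattern replaces the use_gpu flag throughout the generators. A tiny self-contained illustration (my own, using a stand-in nn.Linear rather than the Bark text model):

    import torch

    model = torch.nn.Linear(4, 4)             # stand-in for models["text"]["model"]
    device = next(model.parameters()).device  # the device the weights actually live on
    x = torch.zeros(1, 4).to(device)          # inputs simply follow the model
    print(device, model(x).shape)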
@@ -424,7 +441,9 @@
     else:
         semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
     x = torch.from_numpy(
-        np.hstack([encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])]).astype(np.int64)
+        np.hstack([
+            encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
+        ]).astype(np.int64)
     )[None]
     assert x.shape[1] == 256 + 256 + 1
     with _inference_mode():
@@ -440,8 +459,9 @@
                 x_input = x[:, [-1]]
             else:
                 x_input = x
 
-            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
+            logits, kv_cache = model(
+                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
+            )
             relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
             if allow_early_stop:
                 relevant_logits = torch.hstack(
@@ -465,7 +485,13 @@
                 v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                 relevant_logits[relevant_logits < v[-1]] = -float("Inf")
             probs = F.softmax(relevant_logits / temp, dim=-1)
+            # multinomial bugged on mps: shuttle to cpu if necessary
+            inf_device = probs.device
+            if probs.device.type == "mps":
+                probs = probs.to("cpu")
             item_next = torch.multinomial(probs, num_samples=1)
+            probs = probs.to(inf_device)
+            item_next = item_next.to(inf_device)
             if allow_early_stop and (
                 item_next == SEMANTIC_VOCAB_SIZE
                 or (min_eos_p is not None and probs[-1] >= min_eos_p)
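The same CPU shuttle around torch.multinomial is repeated in the coarse and fine generators below; a standalone sketch of the pattern (logits and temperature are placeholders, not values from the commit):

    import torch
    import torch.nn.functional as F

    relevant_logits = torch.randn(16)            # placeholder logits
    probs = F.softmax(relevant_logits / 0.7, dim=-1)

    # multinomial has been unreliable on the mps backend, so sample on cpu
    # and move the result back to the original device afterwards
    inf_device = probs.device
    if probs.device.type == "mps":
        probs = probs.to("cpu")
    item_next = torch.multinomial(probs, num_samples=1)
    probs = probs.to(inf_device)
    item_next = item_next.to(inf_device)
    print(item_next.item())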
@@ -513,12 +539,10 @@ def generate_coarse(
     temp=0.7,
     top_k=None,
     top_p=None,
-    use_gpu=True,
     silent=False,
     max_coarse_history=630,  # min 60 (faster), max 630 (more context)
     sliding_window_len=60,
-    model=None,
-    use_kv_caching=False
+    use_kv_caching=False,
 ):
     """Generate coarse audio codes from semantic tokens."""
     assert (
@@ -576,9 +600,12 @@
     else:
         x_semantic_history = np.array([], dtype=np.int32)
         x_coarse_history = np.array([], dtype=np.int32)
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="coarse")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    if "coarse" not in models:
+        preload_models()
+    model = models["coarse"]
+    device = next(model.parameters()).device
     # start loop
     n_steps = int(
         round(
@@ -650,7 +677,13 @@
                     v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                     relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                 probs = F.softmax(relevant_logits / temp, dim=-1)
+                # multinomial bugged on mps: shuttle to cpu if necessary
+                inf_device = probs.device
+                if probs.device.type == "mps":
+                    probs = probs.to("cpu")
                 item_next = torch.multinomial(probs, num_samples=1)
+                probs = probs.to(inf_device)
+                item_next = item_next.to(inf_device)
                 item_next += logit_start_idx
                 x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                 x_in = torch.cat((x_in, item_next[None]), dim=1)
@@ -672,9 +705,7 @@ def generate_fine(
     x_coarse_gen,
     history_prompt=None,
     temp=0.5,
-    use_gpu=True,
     silent=True,
-    model=None,
 ):
     """Generate full audio codes from coarse audio codes."""
     assert (
@@ -704,9 +735,12 @@
     else:
         x_fine_history = None
     n_coarse = x_coarse_gen.shape[0]
-    if model is None:
-        model = load_model(use_gpu=use_gpu, model_type="fine")
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    if "fine" not in models:
+        preload_models()
+    model = models["fine"]
+    device = next(model.parameters()).device
     # make input arr
     in_arr = np.vstack(
         [
@@ -754,10 +788,14 @@
                 else:
                     relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
                     probs = F.softmax(relevant_logits, dim=-1)
+                    # multinomial bugged on mps: shuttle to cpu if necessary
+                    inf_device = probs.device
+                    if probs.device.type == "mps":
+                        probs = probs.to("cpu")
                     codebook_preds = torch.hstack(
                         [
-                            torch.multinomial(probs[n], num_samples=1)
-                            for n in range(rel_start_fill_idx, 1024)
+                            torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
+                            for nnn in range(rel_start_fill_idx, 1024)
                         ]
                     )
                 in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
@@ -778,11 +816,14 @@ def generate_fine(
     return gen_fine_arr
 
 
-def codec_decode(fine_tokens, model=None, use_gpu=True):
+def codec_decode(fine_tokens):
     """Turn quantized audio codes into audio array using encodec."""
-    if model is None:
-        model = load_codec_model(use_gpu=use_gpu)
-    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
+    # load models if not yet exist
+    global models
+    if "codec" not in models:
+        preload_models()
+    model = models["codec"]
+    device = next(model.parameters()).device
     arr = torch.from_numpy(fine_tokens)[None]
     arr = arr.to(device)
     arr = arr.transpose(0, 1)
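Callers of codec_decode no longer pass a model or a use_gpu flag; a hedged end-to-end usage sketch (module path assumed, token values and shape are placeholders, and the Bark checkpoints must be downloadable for this to run):

    import numpy as np
    from bark.generation import preload_models, codec_decode  # module path assumed

    preload_models()                                   # optional: codec_decode preloads lazily if needed
    fine_tokens = np.zeros((8, 1024), dtype=np.int64)  # placeholder (n_codebooks, n_frames) array
    audio = codec_decode(fine_tokens)                  # device is resolved from the cached codec model
    print(audio.shape)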
@@ -200,7 +200,6 @@ class GPT(nn.Module):
 
         pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
 
         x = self.transformer.drop(tok_emb + pos_emb)
 
         new_kv = () if use_cache else None