mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-16 11:48:09 +01:00
Merge branch 'suno-ai:main' into main
This commit is contained in:
@@ -14,6 +14,7 @@ import torch
|
|||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import tqdm
|
import tqdm
|
||||||
from transformers import BertTokenizer
|
from transformers import BertTokenizer
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
from .model import GPTConfig, GPT
|
from .model import GPTConfig, GPT
|
||||||
from .model_fine import FineGPT, FineGPTConfig
|
from .model_fine import FineGPT, FineGPTConfig
|
||||||
@@ -36,6 +37,9 @@ else:
|
|||||||
global models
|
global models
|
||||||
models = {}
|
models = {}
|
||||||
|
|
||||||
|
global models_devices
|
||||||
|
models_devices = {}
|
||||||
|
|
||||||
|
|
||||||
CONTEXT_WINDOW_SIZE = 1024
|
CONTEXT_WINDOW_SIZE = 1024
|
||||||
|
|
||||||
@@ -61,41 +65,48 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno",
|
|||||||
|
|
||||||
|
|
||||||
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
||||||
|
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
|
||||||
|
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
|
||||||
|
|
||||||
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
|
|
||||||
|
|
||||||
REMOTE_MODEL_PATHS = {
|
REMOTE_MODEL_PATHS = {
|
||||||
"text_small": {
|
"text_small": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
|
"repo_id": "suno/bark",
|
||||||
|
"file_name": "text.pt",
|
||||||
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
|
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
|
||||||
},
|
},
|
||||||
"coarse_small": {
|
"coarse_small": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
|
"repo_id": "suno/bark",
|
||||||
|
"file_name": "coarse.pt",
|
||||||
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
|
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
|
||||||
},
|
},
|
||||||
"fine_small": {
|
"fine_small": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
|
"repo_id": "suno/bark",
|
||||||
|
"file_name": "fine.pt",
|
||||||
"checksum": "5428d1befe05be2ba32195496e58dc90",
|
"checksum": "5428d1befe05be2ba32195496e58dc90",
|
||||||
},
|
},
|
||||||
"text": {
|
"text": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
|
"repo_id": "suno/bark",
|
||||||
|
"file_name": "text_2.pt",
|
||||||
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
|
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
|
||||||
},
|
},
|
||||||
"coarse": {
|
"coarse": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
|
"repo_id": "suno/bark",
|
||||||
|
"file_name": "coarse_2.pt",
|
||||||
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
|
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
|
||||||
},
|
},
|
||||||
"fine": {
|
"fine": {
|
||||||
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
|
"repo_id": "suno/bark",
|
||||||
|
"file_name": "fine_2.pt",
|
||||||
"checksum": "59d184ed44e3650774a2f0503a48a97b",
|
"checksum": "59d184ed44e3650774a2f0503a48a97b",
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'):
|
if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"torch version does not support flash attention. You will get significantly faster" +
|
"torch version does not support flash attention. You will get faster" +
|
||||||
" inference speed by upgrade torch to newest version / nightly."
|
" inference speed by upgrade torch to newest nightly version."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -115,33 +126,25 @@ def _md5(fname):
|
|||||||
|
|
||||||
def _get_ckpt_path(model_type, use_small=False):
|
def _get_ckpt_path(model_type, use_small=False):
|
||||||
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
|
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
|
||||||
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
|
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["file_name"])
|
||||||
return os.path.join(CACHE_DIR, f"{model_name}.pt")
|
return os.path.join(CACHE_DIR, f"{model_name}.pt")
|
||||||
|
|
||||||
|
|
||||||
S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
|
def _grab_best_device(use_gpu=True):
|
||||||
|
if torch.cuda.device_count() > 0 and use_gpu:
|
||||||
|
device = "cuda"
|
||||||
|
elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
|
||||||
|
device = "mps"
|
||||||
|
else:
|
||||||
|
device = "cpu"
|
||||||
|
return device
|
||||||
|
|
||||||
|
|
||||||
def _parse_s3_filepath(s3_filepath):
|
def _download(from_hf_path, file_name, to_local_path):
|
||||||
bucket_name = re.search(S3_BUCKET_PATH_RE, s3_filepath).group(1)
|
|
||||||
rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath)
|
|
||||||
return bucket_name, rel_s3_filepath
|
|
||||||
|
|
||||||
|
|
||||||
def _download(from_s3_path, to_local_path):
|
|
||||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||||
response = requests.get(from_s3_path, stream=True)
|
destination_file_name = to_local_path.split("/")[-1]
|
||||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
|
||||||
block_size = 1024
|
os.replace(os.path.join(CACHE_DIR, file_name), to_local_path)
|
||||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
|
||||||
with open(to_local_path, "wb") as file:
|
|
||||||
for data in response.iter_content(block_size):
|
|
||||||
progress_bar.update(len(data))
|
|
||||||
file.write(data)
|
|
||||||
progress_bar.close()
|
|
||||||
if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
|
|
||||||
raise ValueError("ERROR, something went wrong")
|
|
||||||
|
|
||||||
|
|
||||||
class InferenceContext:
|
class InferenceContext:
|
||||||
def __init__(self, benchmark=False):
|
def __init__(self, benchmark=False):
|
||||||
@@ -185,8 +188,6 @@ def clean_models(model_key=None):
|
|||||||
|
|
||||||
|
|
||||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||||
if "cuda" not in device:
|
|
||||||
logger.warning("No GPU being used. Careful, inference might be extremely slow!")
|
|
||||||
if model_type == "text":
|
if model_type == "text":
|
||||||
ConfigClass = GPTConfig
|
ConfigClass = GPTConfig
|
||||||
ModelClass = GPT
|
ModelClass = GPT
|
||||||
@@ -208,7 +209,7 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
|||||||
os.remove(ckpt_path)
|
os.remove(ckpt_path)
|
||||||
if not os.path.exists(ckpt_path):
|
if not os.path.exists(ckpt_path):
|
||||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||||
_download(model_info["path"], ckpt_path)
|
_download(model_info["repo_id"], model_info["file_name"], ckpt_path)
|
||||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||||
# this is a hack
|
# this is a hack
|
||||||
model_args = checkpoint["model_args"]
|
model_args = checkpoint["model_args"]
|
||||||
@@ -263,30 +264,40 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
|
|||||||
if model_type not in ("text", "coarse", "fine"):
|
if model_type not in ("text", "coarse", "fine"):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
global models
|
global models
|
||||||
if torch.cuda.device_count() == 0 or not use_gpu:
|
global models_devices
|
||||||
|
device = _grab_best_device(use_gpu=use_gpu)
|
||||||
|
model_key = f"{model_type}"
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
models_devices[model_key] = device
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
else:
|
|
||||||
device = "cuda"
|
|
||||||
model_key = str(device) + f"__{model_type}"
|
|
||||||
if model_key not in models or force_reload:
|
if model_key not in models or force_reload:
|
||||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||||
clean_models(model_key=model_key)
|
clean_models(model_key=model_key)
|
||||||
model = _load_model_f(ckpt_path, device)
|
model = _load_model_f(ckpt_path, device)
|
||||||
models[model_key] = model
|
models[model_key] = model
|
||||||
|
if model_type == "text":
|
||||||
|
models[model_key]["model"].to(device)
|
||||||
|
else:
|
||||||
|
models[model_key].to(device)
|
||||||
return models[model_key]
|
return models[model_key]
|
||||||
|
|
||||||
|
|
||||||
def load_codec_model(use_gpu=True, force_reload=False):
|
def load_codec_model(use_gpu=True, force_reload=False):
|
||||||
global models
|
global models
|
||||||
if torch.cuda.device_count() == 0 or not use_gpu:
|
global models_devices
|
||||||
|
device = _grab_best_device(use_gpu=use_gpu)
|
||||||
|
if device == "mps":
|
||||||
|
# encodec doesn't support mps
|
||||||
|
device = "cpu"
|
||||||
|
model_key = "codec"
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
models_devices[model_key] = device
|
||||||
device = "cpu"
|
device = "cpu"
|
||||||
else:
|
|
||||||
device = "cuda"
|
|
||||||
model_key = str(device) + f"__codec"
|
|
||||||
if model_key not in models or force_reload:
|
if model_key not in models or force_reload:
|
||||||
clean_models(model_key=model_key)
|
clean_models(model_key=model_key)
|
||||||
model = _load_codec_model(device)
|
model = _load_codec_model(device)
|
||||||
models[model_key] = model
|
models[model_key] = model
|
||||||
|
models[model_key].to(device)
|
||||||
return models[model_key]
|
return models[model_key]
|
||||||
|
|
||||||
|
|
||||||
@@ -300,6 +311,11 @@ def preload_models(
|
|||||||
codec_use_gpu=True,
|
codec_use_gpu=True,
|
||||||
force_reload=False,
|
force_reload=False,
|
||||||
):
|
):
|
||||||
|
"""Load all the necessary models for the pipeline."""
|
||||||
|
if _grab_best_device() == "cpu" and (
|
||||||
|
text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
|
||||||
|
):
|
||||||
|
logger.warning("No GPU being used. Careful, inference might be very slow!")
|
||||||
_ = load_model(
|
_ = load_model(
|
||||||
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
|
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
|
||||||
)
|
)
|
||||||
@@ -344,13 +360,11 @@ def generate_text_semantic(
|
|||||||
temp=0.7,
|
temp=0.7,
|
||||||
top_k=None,
|
top_k=None,
|
||||||
top_p=None,
|
top_p=None,
|
||||||
use_gpu=True,
|
|
||||||
silent=False,
|
silent=False,
|
||||||
min_eos_p=0.2,
|
min_eos_p=0.2,
|
||||||
max_gen_duration_s=None,
|
max_gen_duration_s=None,
|
||||||
allow_early_stop=True,
|
allow_early_stop=True,
|
||||||
model=None,
|
use_kv_caching=False,
|
||||||
use_kv_caching=False
|
|
||||||
):
|
):
|
||||||
"""Generate semantic tokens from text."""
|
"""Generate semantic tokens from text."""
|
||||||
assert isinstance(text, str)
|
assert isinstance(text, str)
|
||||||
@@ -372,12 +386,18 @@ def generate_text_semantic(
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
semantic_history = None
|
semantic_history = None
|
||||||
model_container = load_model(use_gpu=use_gpu, model_type="text")
|
# load models if not yet exist
|
||||||
if model is None:
|
global models
|
||||||
|
global models_devices
|
||||||
|
if "text" not in models:
|
||||||
|
preload_models()
|
||||||
|
model_container = models["text"]
|
||||||
model = model_container["model"]
|
model = model_container["model"]
|
||||||
tokenizer = model_container["tokenizer"]
|
tokenizer = model_container["tokenizer"]
|
||||||
encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
|
encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
|
||||||
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["text"])
|
||||||
|
device = next(model.parameters()).device
|
||||||
if len(encoded_text) > 256:
|
if len(encoded_text) > 256:
|
||||||
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
|
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
|
||||||
logger.warning(f"warning, text too long, lopping of last {p}%")
|
logger.warning(f"warning, text too long, lopping of last {p}%")
|
||||||
@@ -401,7 +421,9 @@ def generate_text_semantic(
|
|||||||
else:
|
else:
|
||||||
semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
|
semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
|
||||||
x = torch.from_numpy(
|
x = torch.from_numpy(
|
||||||
np.hstack([encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])]).astype(np.int64)
|
np.hstack([
|
||||||
|
encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
|
||||||
|
]).astype(np.int64)
|
||||||
)[None]
|
)[None]
|
||||||
assert x.shape[1] == 256 + 256 + 1
|
assert x.shape[1] == 256 + 256 + 1
|
||||||
with _inference_mode():
|
with _inference_mode():
|
||||||
@@ -417,8 +439,9 @@ def generate_text_semantic(
|
|||||||
x_input = x[:, [-1]]
|
x_input = x[:, [-1]]
|
||||||
else:
|
else:
|
||||||
x_input = x
|
x_input = x
|
||||||
|
logits, kv_cache = model(
|
||||||
logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
|
x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
|
||||||
|
)
|
||||||
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
|
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
|
||||||
if allow_early_stop:
|
if allow_early_stop:
|
||||||
relevant_logits = torch.hstack(
|
relevant_logits = torch.hstack(
|
||||||
@@ -442,7 +465,13 @@ def generate_text_semantic(
|
|||||||
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
||||||
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
||||||
probs = F.softmax(relevant_logits / temp, dim=-1)
|
probs = F.softmax(relevant_logits / temp, dim=-1)
|
||||||
|
# multinomial bugged on mps: shuttle to cpu if necessary
|
||||||
|
inf_device = probs.device
|
||||||
|
if probs.device.type == "mps":
|
||||||
|
probs = probs.to("cpu")
|
||||||
item_next = torch.multinomial(probs, num_samples=1)
|
item_next = torch.multinomial(probs, num_samples=1)
|
||||||
|
probs = probs.to(inf_device)
|
||||||
|
item_next = item_next.to(inf_device)
|
||||||
if allow_early_stop and (
|
if allow_early_stop and (
|
||||||
item_next == SEMANTIC_VOCAB_SIZE
|
item_next == SEMANTIC_VOCAB_SIZE
|
||||||
or (min_eos_p is not None and probs[-1] >= min_eos_p)
|
or (min_eos_p is not None and probs[-1] >= min_eos_p)
|
||||||
@@ -465,6 +494,8 @@ def generate_text_semantic(
|
|||||||
pbar_state = req_pbar_state
|
pbar_state = req_pbar_state
|
||||||
pbar.close()
|
pbar.close()
|
||||||
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
|
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
|
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
|
||||||
_clear_cuda_cache()
|
_clear_cuda_cache()
|
||||||
return out
|
return out
|
||||||
@@ -490,12 +521,10 @@ def generate_coarse(
|
|||||||
temp=0.7,
|
temp=0.7,
|
||||||
top_k=None,
|
top_k=None,
|
||||||
top_p=None,
|
top_p=None,
|
||||||
use_gpu=True,
|
|
||||||
silent=False,
|
silent=False,
|
||||||
max_coarse_history=630, # min 60 (faster), max 630 (more context)
|
max_coarse_history=630, # min 60 (faster), max 630 (more context)
|
||||||
sliding_window_len=60,
|
sliding_window_len=60,
|
||||||
model=None,
|
use_kv_caching=False,
|
||||||
use_kv_caching=False
|
|
||||||
):
|
):
|
||||||
"""Generate coarse audio codes from semantic tokens."""
|
"""Generate coarse audio codes from semantic tokens."""
|
||||||
assert (
|
assert (
|
||||||
@@ -552,9 +581,15 @@ def generate_coarse(
|
|||||||
else:
|
else:
|
||||||
x_semantic_history = np.array([], dtype=np.int32)
|
x_semantic_history = np.array([], dtype=np.int32)
|
||||||
x_coarse_history = np.array([], dtype=np.int32)
|
x_coarse_history = np.array([], dtype=np.int32)
|
||||||
if model is None:
|
# load models if not yet exist
|
||||||
model = load_model(use_gpu=use_gpu, model_type="coarse")
|
global models
|
||||||
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
|
global models_devices
|
||||||
|
if "coarse" not in models:
|
||||||
|
preload_models()
|
||||||
|
model = models["coarse"]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["coarse"])
|
||||||
|
device = next(model.parameters()).device
|
||||||
# start loop
|
# start loop
|
||||||
n_steps = int(
|
n_steps = int(
|
||||||
round(
|
round(
|
||||||
@@ -626,7 +661,13 @@ def generate_coarse(
|
|||||||
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
|
||||||
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
|
||||||
probs = F.softmax(relevant_logits / temp, dim=-1)
|
probs = F.softmax(relevant_logits / temp, dim=-1)
|
||||||
|
# multinomial bugged on mps: shuttle to cpu if necessary
|
||||||
|
inf_device = probs.device
|
||||||
|
if probs.device.type == "mps":
|
||||||
|
probs = probs.to("cpu")
|
||||||
item_next = torch.multinomial(probs, num_samples=1)
|
item_next = torch.multinomial(probs, num_samples=1)
|
||||||
|
probs = probs.to(inf_device)
|
||||||
|
item_next = item_next.to(inf_device)
|
||||||
item_next += logit_start_idx
|
item_next += logit_start_idx
|
||||||
x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
|
x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
|
||||||
x_in = torch.cat((x_in, item_next[None]), dim=1)
|
x_in = torch.cat((x_in, item_next[None]), dim=1)
|
||||||
@@ -634,6 +675,8 @@ def generate_coarse(
|
|||||||
n_step += 1
|
n_step += 1
|
||||||
del x_in
|
del x_in
|
||||||
del x_semantic_in
|
del x_semantic_in
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
|
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
|
||||||
del x_coarse_in
|
del x_coarse_in
|
||||||
assert len(gen_coarse_arr) == n_steps
|
assert len(gen_coarse_arr) == n_steps
|
||||||
@@ -648,9 +691,7 @@ def generate_fine(
|
|||||||
x_coarse_gen,
|
x_coarse_gen,
|
||||||
history_prompt=None,
|
history_prompt=None,
|
||||||
temp=0.5,
|
temp=0.5,
|
||||||
use_gpu=True,
|
|
||||||
silent=True,
|
silent=True,
|
||||||
model=None,
|
|
||||||
):
|
):
|
||||||
"""Generate full audio codes from coarse audio codes."""
|
"""Generate full audio codes from coarse audio codes."""
|
||||||
assert (
|
assert (
|
||||||
@@ -679,9 +720,15 @@ def generate_fine(
|
|||||||
else:
|
else:
|
||||||
x_fine_history = None
|
x_fine_history = None
|
||||||
n_coarse = x_coarse_gen.shape[0]
|
n_coarse = x_coarse_gen.shape[0]
|
||||||
if model is None:
|
# load models if not yet exist
|
||||||
model = load_model(use_gpu=use_gpu, model_type="fine")
|
global models
|
||||||
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
|
global models_devices
|
||||||
|
if "fine" not in models:
|
||||||
|
preload_models()
|
||||||
|
model = models["fine"]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["fine"])
|
||||||
|
device = next(model.parameters()).device
|
||||||
# make input arr
|
# make input arr
|
||||||
in_arr = np.vstack(
|
in_arr = np.vstack(
|
||||||
[
|
[
|
||||||
@@ -729,10 +776,14 @@ def generate_fine(
|
|||||||
else:
|
else:
|
||||||
relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
|
relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
|
||||||
probs = F.softmax(relevant_logits, dim=-1)
|
probs = F.softmax(relevant_logits, dim=-1)
|
||||||
|
# multinomial bugged on mps: shuttle to cpu if necessary
|
||||||
|
inf_device = probs.device
|
||||||
|
if probs.device.type == "mps":
|
||||||
|
probs = probs.to("cpu")
|
||||||
codebook_preds = torch.hstack(
|
codebook_preds = torch.hstack(
|
||||||
[
|
[
|
||||||
torch.multinomial(probs[n], num_samples=1)
|
torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
|
||||||
for n in range(rel_start_fill_idx, 1024)
|
for nnn in range(rel_start_fill_idx, 1024)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
|
in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
|
||||||
@@ -745,6 +796,8 @@ def generate_fine(
|
|||||||
del in_buffer
|
del in_buffer
|
||||||
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
|
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
|
||||||
del in_arr
|
del in_arr
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
gen_fine_arr = gen_fine_arr[:, n_history:]
|
gen_fine_arr = gen_fine_arr[:, n_history:]
|
||||||
if n_remove_from_end > 0:
|
if n_remove_from_end > 0:
|
||||||
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
|
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
|
||||||
@@ -753,11 +806,17 @@ def generate_fine(
|
|||||||
return gen_fine_arr
|
return gen_fine_arr
|
||||||
|
|
||||||
|
|
||||||
def codec_decode(fine_tokens, model=None, use_gpu=True):
|
def codec_decode(fine_tokens):
|
||||||
"""Turn quantized audio codes into audio array using encodec."""
|
"""Turn quantized audio codes into audio array using encodec."""
|
||||||
if model is None:
|
# load models if not yet exist
|
||||||
model = load_codec_model(use_gpu=use_gpu)
|
global models
|
||||||
device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
|
global models_devices
|
||||||
|
if "codec" not in models:
|
||||||
|
preload_models()
|
||||||
|
model = models["codec"]
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to(models_devices["codec"])
|
||||||
|
device = next(model.parameters()).device
|
||||||
arr = torch.from_numpy(fine_tokens)[None]
|
arr = torch.from_numpy(fine_tokens)[None]
|
||||||
arr = arr.to(device)
|
arr = arr.to(device)
|
||||||
arr = arr.transpose(0, 1)
|
arr = arr.transpose(0, 1)
|
||||||
@@ -765,4 +824,6 @@ def codec_decode(fine_tokens, model=None, use_gpu=True):
|
|||||||
out = model.decoder(emb)
|
out = model.decoder(emb)
|
||||||
audio_arr = out.detach().cpu().numpy().squeeze()
|
audio_arr = out.detach().cpu().numpy().squeeze()
|
||||||
del arr, emb, out
|
del arr, emb, out
|
||||||
|
if OFFLOAD_CPU:
|
||||||
|
model.to("cpu")
|
||||||
return audio_arr
|
return audio_arr
|
||||||
|
|||||||
@@ -200,7 +200,6 @@ class GPT(nn.Module):
|
|||||||
|
|
||||||
pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
|
pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
|
||||||
|
|
||||||
|
|
||||||
x = self.transformer.drop(tok_emb + pos_emb)
|
x = self.transformer.drop(tok_emb + pos_emb)
|
||||||
|
|
||||||
new_kv = () if use_cache else None
|
new_kv = () if use_cache else None
|
||||||
|
|||||||
Reference in New Issue
Block a user