import contextlib
import gc
import hashlib
import os
import re
import json
import logging

from encodec import EncodecModel
import funcy
import numpy as np
from scipy.special import softmax
import torch
import torch.nn.functional as F
import tqdm
from transformers import BertTokenizer
from huggingface_hub import hf_hub_download

from .model import GPTConfig, GPT
from .model_fine import FineGPT, FineGPTConfig

if (
torch.cuda.is_available() and
hasattr(torch.cuda, "amp") and
hasattr(torch.cuda.amp, "autocast") and
    hasattr(torch.cuda, "is_bf16_supported") and
torch.cuda.is_bf16_supported()
):
autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
else:
@contextlib.contextmanager
def autocast():
yield
# hold models in global scope to lazy load
global models
models = {}
global models_devices
models_devices = {}

CONTEXT_WINDOW_SIZE = 1024
SEMANTIC_RATE_HZ = 49.9
SEMANTIC_VOCAB_SIZE = 10_000
CODEBOOK_SIZE = 1024
N_COARSE_CODEBOOKS = 2
N_FINE_CODEBOOKS = 8
COARSE_RATE_HZ = 75
SAMPLE_RATE = 24_000
logger = logging.getLogger(__name__)
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "serp", "bark_v0")
USE_SMALL_MODELS = os.environ.get("SERP_USE_SMALL_MODELS", False)
GLOBAL_ENABLE_MPS = os.environ.get("SERP_ENABLE_MPS", False)
OFFLOAD_CPU = os.environ.get("SERP_OFFLOAD_CPU", False)
REMOTE_MODEL_PATHS = {
    "text_small": {
        "repo_id": "suno/bark",
        "file_name": "text.pt",
        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
    },
    "coarse_small": {
        "repo_id": "suno/bark",
        "file_name": "coarse.pt",
        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
    },
    "fine_small": {
        "repo_id": "suno/bark",
        "file_name": "fine.pt",
        "checksum": "5428d1befe05be2ba32195496e58dc90",
    },
    "text": {
        "repo_id": "suno/bark",
        "file_name": "text_2.pt",
        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
    },
    "coarse": {
        "repo_id": "suno/bark",
        "file_name": "coarse_2.pt",
        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
    },
    "fine": {
        "repo_id": "suno/bark",
        "file_name": "fine_2.pt",
        "checksum": "59d184ed44e3650774a2f0503a48a97b",
    },
}

if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
    logger.warning(
        "torch version does not support flash attention. You will get faster"
        " inference speed by upgrading torch to the newest nightly version."
    )


def _string_md5(s):
m = hashlib.md5()
m.update(s.encode("utf-8"))
return m.hexdigest()
def _md5(fname):
hash_md5 = hashlib.md5()
with open(fname, "rb") as f:
for chunk in iter(lambda: f.read(4096), b""):
hash_md5.update(chunk)
return hash_md5.hexdigest()


def _get_ckpt_path(model_type, use_small=False, path=None):
    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
    model_name = REMOTE_MODEL_PATHS[model_key]["file_name"]
    if path is None:
        path = CACHE_DIR
    return os.path.join(path, model_name)


def _grab_best_device(use_gpu=True):
if torch.cuda.device_count() > 0 and use_gpu:
device = "cuda"
elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
device = "mps"
else:
device = "cpu"
return device


def _download(from_hf_path, file_name, to_local_path):
to_local_path = to_local_path.replace("\\", "/")
path = '/'.join(to_local_path.split("/")[:-1])
os.makedirs(path, exist_ok=True)
hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)
os.replace(os.path.join(path, file_name), to_local_path)
class InferenceContext:
def __init__(self, benchmark=False):
# we can't expect inputs to be the same length, so disable benchmarking by default
self._chosen_cudnn_benchmark = benchmark
self._cudnn_benchmark = None
def __enter__(self):
self._cudnn_benchmark = torch.backends.cudnn.benchmark
torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark
def __exit__(self, exc_type, exc_value, exc_traceback):
torch.backends.cudnn.benchmark = self._cudnn_benchmark
if torch.cuda.is_available():
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
@contextlib.contextmanager
def _inference_mode():
with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
yield
def _clear_cuda_cache():
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.synchronize()
def clean_models(model_key=None):
global models
model_keys = [model_key] if model_key is not None else models.keys()
for k in model_keys:
if k in models:
del models[k]
_clear_cuda_cache()
gc.collect()


def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if model_type == "text":
ConfigClass = GPTConfig
ModelClass = GPT
elif model_type == "coarse":
ConfigClass = GPTConfig
ModelClass = GPT
elif model_type == "fine":
ConfigClass = FineGPTConfig
ModelClass = FineGPT
else:
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
# if (
# os.path.exists(ckpt_path) and
# _md5(ckpt_path) != model_info["checksum"]
# ):
# logger.warning(f"found outdated {model_type} model, removing.")
# os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
        _download(model_info["repo_id"], model_info["file_name"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
# check if config.json is in the same directory as the checkpoint
# if so, load it
# otherwise, assume it's in the checkpoint
config_path = os.path.join(os.path.dirname(ckpt_path), "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
model_args = json.load(f)
else:
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
    gptconf = ConfigClass(**model_args)
model = ModelClass(gptconf)
if checkpoint.get("model", None) is not None:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
# fixup checkpoint
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
unwanted_suffixes = [
"lora_right_weight",
"lora_left_weight",
"lora_right_bias",
"lora_left_bias",
]
for k, v in list(state_dict.items()):
for suffix in unwanted_suffixes:
if k.endswith(suffix):
state_dict.pop(k)
    # remap heads that were saved wrapped in an extra submodule (keys like
    # "lm_head.0.weight") back to the plain "lm_head.weight" layout
    if state_dict.get('lm_head.0.weight', None) is not None:
        state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight')
    for i in range(7):
        nested_key = f'lm_heads.{i}.0.weight'
        if state_dict.get(nested_key, None) is not None:
            state_dict[f'lm_heads.{i}.weight'] = state_dict.pop(nested_key)
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
if len(extra_keys) != 0:
print(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
if checkpoint.get("best_val_loss", None) is not None:
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
_clear_cuda_cache()
if model_type == "text":
tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
return {
"model": model,
"tokenizer": tokenizer,
}
return model
def _load_codec_model(device):
model = EncodecModel.encodec_model_24khz()
model.set_target_bandwidth(6.0)
model.eval()
model.to(device)
_clear_cuda_cache()
return model
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text", path=None):
    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()
    global models
    global models_devices
    device = _grab_best_device(use_gpu=use_gpu)
    model_key = f"{model_type}"
    if OFFLOAD_CPU:
        models_devices[model_key] = device
        device = "cpu"
    if model_key not in models or force_reload:
        if path is not None and path.endswith((".ckpt", ".pt", ".bin")):
            ckpt_path = path
        else:
            ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
        # clean_models(model_key=model_key)
        model = _load_model_f(ckpt_path, device)
        models[model_key] = model
    if model_type == "text":
        models[model_key]["model"].to(device)
    else:
        models[model_key].to(device)
    return models[model_key]


def load_codec_model(use_gpu=True, force_reload=False):
    global models
    global models_devices
    device = _grab_best_device(use_gpu=use_gpu)
    if device == "mps":
        # encodec doesn't support mps
        device = "cpu"
    model_key = "codec"
    if OFFLOAD_CPU:
        models_devices[model_key] = device
        device = "cpu"
    if model_key not in models or force_reload:
        clean_models(model_key=model_key)
        model = _load_codec_model(device)
        models[model_key] = model
    models[model_key].to(device)
    return models[model_key]
def preload_models(
text_use_gpu=True,
text_use_small=False,
    text_model_path=None,
    coarse_use_gpu=True,
    coarse_use_small=False,
    coarse_model_path=None,
    fine_use_gpu=True,
    fine_use_small=False,
    fine_model_path=None,
    codec_use_gpu=True,
    force_reload=False,
    path=None,
):
"""Load all the necessary models for the pipeline."""
if _grab_best_device() == "cpu" and (
text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
):
logger.warning("No GPU being used. Careful, inference might be very slow!")
    _ = load_model(
        model_type="text",
        use_gpu=text_use_gpu,
        use_small=text_use_small,
        force_reload=force_reload,
        path=path if text_model_path is None else text_model_path,
    )
    _ = load_model(
        model_type="coarse",
        use_gpu=coarse_use_gpu,
        use_small=coarse_use_small,
        force_reload=force_reload,
        path=path if coarse_model_path is None else coarse_model_path,
    )
    _ = load_model(
        model_type="fine",
        use_gpu=fine_use_gpu,
        use_small=fine_use_small,
        force_reload=force_reload,
        path=path if fine_model_path is None else fine_model_path,
    )
    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)

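# Example (sketch, not executed on import): preloading the small model variants
# before generation. All keyword names below exist in `preload_models` above;
# the chosen values are illustrative only.
#
#   preload_models(
#       text_use_small=True,
#       coarse_use_small=True,
#       fine_use_small=True,
#       force_reload=False,
#   )
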
####
# Generation Functionality
####
def _tokenize(tokenizer, text):
return tokenizer.encode(text, add_special_tokens=False)
def _detokenize(tokenizer, enc_text):
return tokenizer.decode(enc_text)
def _normalize_whitespace(text):
return re.sub(r"\s+", " ", text).strip()
TEXT_ENCODING_OFFSET = 10_048
SEMANTIC_PAD_TOKEN = 10_000
TEXT_PAD_TOKEN = 129_595
SEMANTIC_INFER_TOKEN = 129_599
def generate_text_semantic(
text,
history_prompt=None,
temp=0.7,
top_k=None,
top_p=None,
silent=False,
min_eos_p=0.2,
max_gen_duration_s=None,
allow_early_stop=True,
    use_kv_caching=False,
):
"""Generate semantic tokens from text."""
assert isinstance(text, str)
text = _normalize_whitespace(text)
assert len(text.strip()) > 0
if history_prompt is not None:
if history_prompt.endswith(".npz"):
        try:
            semantic_history = np.load(history_prompt)["semantic_prompt"]
        except KeyError:
            semantic_history = np.load(history_prompt)["semantic"]
else:
semantic_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["semantic_prompt"]
2023-04-09 13:21:02 -04:00
assert (
isinstance(semantic_history, np.ndarray)
and len(semantic_history.shape) == 1
and len(semantic_history) > 0
and semantic_history.min() >= 0
and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
)
else:
semantic_history = None
    # load models if they are not loaded yet
global models
    global models_devices
if "text" not in models:
preload_models()
model_container = models["text"]
model = model_container["model"]
    tokenizer = model_container["tokenizer"]
    encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
    if OFFLOAD_CPU:
        model.to(models_devices["text"])
    device = next(model.parameters()).device
if len(encoded_text) > 256:
p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
        logger.warning(f"warning, text too long, lopping off last {p}%")
encoded_text = encoded_text[:256]
encoded_text = np.pad(
encoded_text,
(0, 256 - len(encoded_text)),
constant_values=TEXT_PAD_TOKEN,
mode="constant",
)
if semantic_history is not None:
semantic_history = semantic_history.astype(np.int64)
# lop off if history is too long, pad if needed
semantic_history = semantic_history[-256:]
semantic_history = np.pad(
semantic_history,
(0, 256 - len(semantic_history)),
constant_values=SEMANTIC_PAD_TOKEN,
mode="constant",
)
else:
semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
x = torch.from_numpy(
        np.hstack([
            encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
        ]).astype(np.int64)
)[None]
assert x.shape[1] == 256 + 256 + 1
with _inference_mode():
x = x.to(device)
n_tot_steps = 768
# custom tqdm updates since we don't know when eos will occur
pbar = tqdm.tqdm(disable=silent, total=100)
pbar_state = 0
tot_generated_duration_s = 0
kv_cache = None
        for n in range(n_tot_steps):
            if use_kv_caching and kv_cache is not None:
                x_input = x[:, [-1]]
            else:
                x_input = x
            logits, kv_cache = model(
                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
            )
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
if allow_early_stop:
relevant_logits = torch.hstack(
(relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]]) # eos
)
if top_p is not None:
# faster to convert to numpy
original_device = relevant_logits.device
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
sorted_indices = np.argsort(relevant_logits)[::-1]
sorted_logits = relevant_logits[sorted_indices]
cumulative_probs = np.cumsum(softmax(sorted_logits))
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits)
relevant_logits = relevant_logits.to(original_device)
if top_k is not None:
v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
relevant_logits[relevant_logits < v[-1]] = -float("Inf")
probs = F.softmax(relevant_logits / temp, dim=-1)
            # multinomial bugged on mps: shuttle to cpu if necessary
            inf_device = probs.device
            if probs.device.type == "mps":
                probs = probs.to("cpu")
            item_next = torch.multinomial(probs, num_samples=1)
            probs = probs.to(inf_device)
            item_next = item_next.to(inf_device)
if allow_early_stop and (
item_next == SEMANTIC_VOCAB_SIZE
or (min_eos_p is not None and probs[-1] >= min_eos_p)
):
# eos found, so break
pbar.update(100 - pbar_state)
break
x = torch.cat((x, item_next[None]), dim=1)
tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
pbar.update(100 - pbar_state)
break
if n == n_tot_steps - 1:
pbar.update(100 - pbar_state)
break
del logits, relevant_logits, probs, item_next
req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))])
if req_pbar_state > pbar_state:
pbar.update(req_pbar_state - pbar_state)
pbar_state = req_pbar_state
pbar.close()
out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
    if OFFLOAD_CPU:
        model.to("cpu")
assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
_clear_cuda_cache()
return out
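

# Example (sketch): turning a short prompt into semantic tokens. Every argument
# shown exists in `generate_text_semantic` above; the text itself is illustrative.
#
#   semantic_tokens = generate_text_semantic(
#       "Hello, my name is Suno.",
#       temp=0.7,
#       use_kv_caching=True,
#   )
#   # `semantic_tokens` is a 1-D int array with values in [0, SEMANTIC_VOCAB_SIZE).
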
def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
assert len(arr.shape) == 2
arr = arr.copy()
if offset_size is not None:
for n in range(1, arr.shape[0]):
arr[n, :] += offset_size * n
flat_arr = arr.ravel("F")
return flat_arr
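

# Worked example (sketch) for `_flatten_codebooks`: with the default offset of
# CODEBOOK_SIZE (1024), every later codebook row is shifted into its own token
# range and the rows are interleaved column by column (Fortran order):
#
#   _flatten_codebooks(np.array([[1, 2, 3], [4, 5, 6]]))
#   # -> array([1, 1028, 2, 1029, 3, 1030])
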
COARSE_SEMANTIC_PAD_TOKEN = 12_048
COARSE_INFER_TOKEN = 12_050
def generate_coarse(
x_semantic,
history_prompt=None,
temp=0.7,
top_k=None,
top_p=None,
silent=False,
max_coarse_history=630, # min 60 (faster), max 630 (more context)
sliding_window_len=60,
    use_kv_caching=False,
):
"""Generate coarse audio codes from semantic tokens."""
assert (
isinstance(x_semantic, np.ndarray)
and len(x_semantic.shape) == 1
and len(x_semantic) > 0
and x_semantic.min() >= 0
and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
)
assert 60 <= max_coarse_history <= 630
assert max_coarse_history + sliding_window_len <= 1024 - 256
semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
if history_prompt is not None:
if history_prompt.endswith(".npz"):
x_history = np.load(history_prompt)
else:
x_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)
        try:
            x_semantic_history = x_history["semantic_prompt"]
            x_coarse_history = x_history["coarse_prompt"]
        except KeyError:
            x_semantic_history = x_history["semantic"]
            x_coarse_history = x_history["coarse"]
assert (
isinstance(x_semantic_history, np.ndarray)
and len(x_semantic_history.shape) == 1
and len(x_semantic_history) > 0
and x_semantic_history.min() >= 0
and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
and isinstance(x_coarse_history, np.ndarray)
and len(x_coarse_history.shape) == 2
and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS
and x_coarse_history.shape[-1] >= 0
and x_coarse_history.min() >= 0
and x_coarse_history.max() <= CODEBOOK_SIZE - 1
and (
round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
== round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
)
)
x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
# trim histories correctly
n_semantic_hist_provided = np.min(
[
max_semantic_history,
len(x_semantic_history) - len(x_semantic_history) % 2,
int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
]
)
n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
# TODO: bit of a hack for time alignment (sounds better)
x_coarse_history = x_coarse_history[:-2]
else:
x_semantic_history = np.array([], dtype=np.int32)
x_coarse_history = np.array([], dtype=np.int32)
    # load models if they are not loaded yet
    global models
    global models_devices
if "coarse" not in models:
preload_models()
model = models["coarse"]
    if OFFLOAD_CPU:
        model.to(models_devices["coarse"])
    device = next(model.parameters()).device
# start loop
n_steps = int(
round(
np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
* N_COARSE_CODEBOOKS
)
)
assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
x_coarse = x_coarse_history.astype(np.int32)
base_semantic_idx = len(x_semantic_history)
with _inference_mode():
x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
n_window_steps = int(np.ceil(n_steps / sliding_window_len))
n_step = 0
for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
# pad from right side
x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
x_in = x_in[:, :256]
x_in = F.pad(
x_in,
(0, 256 - x_in.shape[-1]),
"constant",
COARSE_SEMANTIC_PAD_TOKEN,
)
x_in = torch.hstack(
[
x_in,
torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
x_coarse_in[:, -max_coarse_history:],
]
)
kv_cache = None
for _ in range(sliding_window_len):
if n_step >= n_steps:
continue
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
if use_kv_caching and kv_cache is not None:
x_input = x_in[:, [-1]]
else:
x_input = x_in
                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
logit_start_idx = (
SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
)
logit_end_idx = (
SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
)
relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
if top_p is not None:
# faster to convert to numpy
original_device = relevant_logits.device
relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
sorted_indices = np.argsort(relevant_logits)[::-1]
sorted_logits = relevant_logits[sorted_indices]
cumulative_probs = np.cumsum(softmax(sorted_logits))
sorted_indices_to_remove = cumulative_probs > top_p
sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
sorted_indices_to_remove[0] = False
relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
relevant_logits = torch.from_numpy(relevant_logits)
relevant_logits = relevant_logits.to(original_device)
                if top_k is not None:
                    v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                    relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                probs = F.softmax(relevant_logits / temp, dim=-1)
                # multinomial bugged on mps: shuttle to cpu if necessary
                inf_device = probs.device
                if probs.device.type == "mps":
                    probs = probs.to("cpu")
                item_next = torch.multinomial(probs, num_samples=1)
                probs = probs.to(inf_device)
                item_next = item_next.to(inf_device)
item_next += logit_start_idx
x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
x_in = torch.cat((x_in, item_next[None]), dim=1)
del logits, relevant_logits, probs, item_next
n_step += 1
del x_in
del x_semantic_in
    if OFFLOAD_CPU:
        model.to("cpu")
gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
del x_coarse_in
assert len(gen_coarse_arr) == n_steps
gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
for n in range(1, N_COARSE_CODEBOOKS):
gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
_clear_cuda_cache()
return gen_coarse_audio_arr
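

# Example (sketch): expanding semantic tokens into the first two EnCodec codebooks,
# assuming `semantic_tokens` came from `generate_text_semantic` above.
#
#   coarse_tokens = generate_coarse(semantic_tokens, temp=0.7, use_kv_caching=True)
#   # shape (N_COARSE_CODEBOOKS, T), values in [0, CODEBOOK_SIZE), where T is roughly
#   # len(semantic_tokens) * COARSE_RATE_HZ / SEMANTIC_RATE_HZ.
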
def generate_fine(
x_coarse_gen,
history_prompt=None,
temp=0.5,
silent=True,
):
"""Generate full audio codes from coarse audio codes."""
assert (
isinstance(x_coarse_gen, np.ndarray)
and len(x_coarse_gen.shape) == 2
and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1
and x_coarse_gen.shape[1] > 0
and x_coarse_gen.min() >= 0
and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
)
if history_prompt is not None:
if history_prompt.endswith(".npz"):
        try:
            x_fine_history = np.load(history_prompt)["fine_prompt"]
        except KeyError:
            x_fine_history = np.load(history_prompt)["fine"]
else:
x_fine_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["fine_prompt"]
assert (
isinstance(x_fine_history, np.ndarray)
and len(x_fine_history.shape) == 2
and x_fine_history.shape[0] == N_FINE_CODEBOOKS
and x_fine_history.shape[1] >= 0
and x_fine_history.min() >= 0
and x_fine_history.max() <= CODEBOOK_SIZE - 1
)
else:
x_fine_history = None
n_coarse = x_coarse_gen.shape[0]
    # load models if they are not loaded yet
    global models
    global models_devices
if "fine" not in models:
preload_models()
model = models["fine"]
    if OFFLOAD_CPU:
        model.to(models_devices["fine"])
    device = next(model.parameters()).device
# make input arr
in_arr = np.vstack(
[
x_coarse_gen,
np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1]))
+ CODEBOOK_SIZE, # padding
]
).astype(np.int32)
# prepend history if available (max 512)
if x_fine_history is not None:
x_fine_history = x_fine_history.astype(np.int32)
in_arr = np.hstack(
[
x_fine_history[:, -512:].astype(np.int32),
in_arr,
]
)
n_history = x_fine_history[:, -512:].shape[1]
else:
n_history = 0
n_remove_from_end = 0
# need to pad if too short (since non-causal model)
if in_arr.shape[1] < 1024:
n_remove_from_end = 1024 - in_arr.shape[1]
in_arr = np.hstack(
[
in_arr,
np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE,
]
)
# we can be lazy about fractional loop and just keep overwriting codebooks
n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1
with _inference_mode():
in_arr = torch.tensor(in_arr.T).to(device)
for n in tqdm.tqdm(range(n_loops), disable=silent):
start_idx = np.min([n * 512, in_arr.shape[0] - 1024])
start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512])
rel_start_fill_idx = start_fill_idx - start_idx
in_buffer = in_arr[start_idx : start_idx + 1024, :][None]
for nn in range(n_coarse, N_FINE_CODEBOOKS):
logits = model(nn, in_buffer)
if temp is None:
relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE]
codebook_preds = torch.argmax(relevant_logits, -1)
else:
relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
probs = F.softmax(relevant_logits, dim=-1)
# multinomial bugged on mps: shuttle to cpu if necessary
inf_device = probs.device
if probs.device.type == "mps":
probs = probs.to("cpu")
codebook_preds = torch.hstack(
[
torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
for nnn in range(rel_start_fill_idx, 1024)
]
)
in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
del logits, codebook_preds
# transfer over info into model_in and convert to numpy
for nn in range(n_coarse, N_FINE_CODEBOOKS):
in_arr[
start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn
] = in_buffer[0, rel_start_fill_idx:, nn]
del in_buffer
gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
del in_arr
    if OFFLOAD_CPU:
        model.to("cpu")
gen_fine_arr = gen_fine_arr[:, n_history:]
if n_remove_from_end > 0:
gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1]
_clear_cuda_cache()
return gen_fine_arr


def codec_decode(fine_tokens):
    """Turn quantized audio codes into audio array using encodec."""
    # load models if they are not loaded yet
    global models
    global models_devices
    if "codec" not in models:
        preload_models()
    model = models["codec"]
    if OFFLOAD_CPU:
        model.to(models_devices["codec"])
    device = next(model.parameters()).device
    arr = torch.from_numpy(fine_tokens)[None]
    arr = arr.to(device)
    arr = arr.transpose(0, 1)
    emb = model.quantizer.decode(arr)
    out = model.decoder(emb)
    audio_arr = out.detach().cpu().numpy().squeeze()
    del arr, emb, out
    if OFFLOAD_CPU:
        model.to("cpu")
    return audio_arr
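

# End-to-end sketch using only the functions defined above (not executed on import):
#
#   preload_models()
#   semantic_tokens = generate_text_semantic("Hello world.", use_kv_caching=True)
#   coarse_tokens = generate_coarse(semantic_tokens, use_kv_caching=True)
#   fine_tokens = generate_fine(coarse_tokens, temp=0.5)
#   audio = codec_decode(fine_tokens)  # float waveform at SAMPLE_RATE (24 kHz)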