Merge remote-tracking branch 'upstream/main'

This commit is contained in:
Francis LaBounty
2023-04-22 16:28:02 -06:00
7 changed files with 223 additions and 73 deletions

View File

@@ -1,4 +1,5 @@
import contextlib
import gc
import hashlib
import os
import re
@@ -21,6 +22,7 @@ if (
torch.cuda.is_available() and
hasattr(torch.cuda, "amp") and
hasattr(torch.cuda.amp, "autocast") and
hasattr(torch.cuda, "is_bf16_supported") and
torch.cuda.is_bf16_supported()
):
autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
@@ -58,20 +60,33 @@ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
REMOTE_MODEL_PATHS = {
"text_small": {
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
},
"coarse_small": {
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
},
"fine_small": {
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
"checksum": "5428d1befe05be2ba32195496e58dc90",
},
"text": {
"path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
},
"coarse": {
"path": os.environ.get(
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
),
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
},
"fine": {
"path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
"checksum": "59d184ed44e3650774a2f0503a48a97b",
},
}
@@ -98,8 +113,9 @@ def _md5(fname):
return hash_md5.hexdigest()
def _get_ckpt_path(model_type):
model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
def _get_ckpt_path(model_type, use_small=False):
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
return os.path.join(CACHE_DIR, f"{model_name}.pt")
@@ -115,9 +131,9 @@ def _parse_s3_filepath(s3_filepath):
def _download(from_s3_path, to_local_path):
os.makedirs(CACHE_DIR, exist_ok=True)
response = requests.get(from_s3_path, stream=True)
total_size_in_bytes = int(response.headers.get('content-length', 0))
block_size = 1024 # 1 Kibibyte
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
total_size_in_bytes = int(response.headers.get("content-length", 0))
block_size = 1024
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
with open(to_local_path, "wb") as file:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
@@ -165,11 +181,12 @@ def clean_models(model_key=None):
if k in models:
del models[k]
_clear_cuda_cache()
gc.collect()
def _load_model(ckpt_path, device, model_type="text"):
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
if "cuda" not in device:
logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
logger.warning("No GPU being used. Careful, inference might be extremely slow!")
if model_type == "text":
ConfigClass = GPTConfig
ModelClass = GPT
@@ -181,15 +198,17 @@ def _load_model(ckpt_path, device, model_type="text"):
ModelClass = FineGPT
else:
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if (
os.path.exists(ckpt_path) and
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
_md5(ckpt_path) != model_info["checksum"]
):
logger.warning(f"found outdated {model_type} model, removing...")
logger.warning(f"found outdated {model_type} model, removing.")
os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading...")
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["path"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]
@@ -239,8 +258,8 @@ def _load_codec_model(device):
return model
def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
_load_model_f = funcy.partial(_load_model, model_type=model_type)
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
_load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
if model_type not in ("text", "coarse", "fine"):
raise NotImplementedError()
global models
@@ -250,8 +269,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
device = "cuda"
model_key = str(device) + f"__{model_type}"
if model_key not in models or force_reload:
if ckpt_path is None:
ckpt_path = _get_ckpt_path(model_type)
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
clean_models(model_key=model_key)
model = _load_model_f(ckpt_path, device)
models[model_key] = model
@@ -272,17 +290,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
return models[model_key]
def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
def preload_models(
text_use_gpu=True,
text_use_small=False,
coarse_use_gpu=True,
coarse_use_small=False,
fine_use_gpu=True,
fine_use_small=False,
codec_use_gpu=True,
force_reload=False,
):
_ = load_model(
ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
)
_ = load_model(
ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
model_type="coarse",
use_gpu=coarse_use_gpu,
use_small=coarse_use_small,
force_reload=force_reload,
)
_ = load_model(
ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
)
_ = load_codec_model(use_gpu=use_gpu, force_reload=True)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
####
@@ -320,15 +350,19 @@ def generate_text_semantic(
max_gen_duration_s=None,
allow_early_stop=True,
model=None,
use_kv_caching=False
):
"""Generate semantic tokens from text."""
assert isinstance(text, str)
text = _normalize_whitespace(text)
assert len(text.strip()) > 0
if history_prompt is not None:
semantic_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["semantic_prompt"]
if history_prompt.endswith(".npz"):
semantic_history = np.load(history_prompt)["semantic_prompt"]
else:
semantic_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["semantic_prompt"]
assert (
isinstance(semantic_history, np.ndarray)
and len(semantic_history.shape) == 1
@@ -377,8 +411,14 @@ def generate_text_semantic(
pbar = tqdm.tqdm(disable=silent, total=100)
pbar_state = 0
tot_generated_duration_s = 0
kv_cache = None
for n in range(n_tot_steps):
logits = model(x, merge_context=True)
if use_kv_caching and kv_cache is not None:
x_input = x[:, [-1]]
else:
x_input = x
logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
if allow_early_stop:
relevant_logits = torch.hstack(
@@ -455,6 +495,7 @@ def generate_coarse(
max_coarse_history=630, # min 60 (faster), max 630 (more context)
sliding_window_len=60,
model=None,
use_kv_caching=False
):
"""Generate coarse audio codes from semantic tokens."""
assert (
@@ -469,9 +510,12 @@ def generate_coarse(
semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
if history_prompt is not None:
x_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)
if history_prompt.endswith(".npz"):
x_history = np.load(history_prompt)
else:
x_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)
x_semantic_history = x_history["semantic_prompt"]
x_coarse_history = x_history["coarse_prompt"]
assert (
@@ -545,11 +589,18 @@ def generate_coarse(
x_coarse_in[:, -max_coarse_history:],
]
)
kv_cache = None
for _ in range(sliding_window_len):
if n_step >= n_steps:
continue
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
logits = model(x_in)
if use_kv_caching and kv_cache is not None:
x_input = x_in[:, [-1]]
else:
x_input = x_in
logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
logit_start_idx = (
SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
)
@@ -611,9 +662,12 @@ def generate_fine(
and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
)
if history_prompt is not None:
x_fine_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["fine_prompt"]
if history_prompt.endswith(".npz"):
x_fine_history = np.load(history_prompt)["fine_prompt"]
else:
x_fine_history = np.load(
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
)["fine_prompt"]
assert (
isinstance(x_fine_history, np.ndarray)
and len(x_fine_history.shape) == 2