diff --git a/.gitignore b/.gitignore index 372c13e..c18dd8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ __pycache__/ - diff --git a/README.md b/README.md index 2ddc0ca..8a3816e 100644 --- a/README.md +++ b/README.md @@ -32,14 +32,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a ## 🤖 Usage ```python -from bark import SAMPLE_RATE, generate_audio +from bark import SAMPLE_RATE, generate_audio, preload_models from IPython.display import Audio +# download and load all models +preload_models() + +# generate audio from text text_prompt = """ Hello, my name is Suno. And, uh — and I like pizza. [laughs] But I also have other interests such as playing tic tac toe. """ audio_array = generate_audio(text_prompt) + +# play text in notebook Audio(audio_array, rate=SAMPLE_RATE) ``` @@ -83,7 +89,7 @@ audio_array = generate_audio(text_prompt) ### 🎤 Voice Presets and Voice/Audio Cloning -Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{number}`. +Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{0-9}`. ```python text_prompt = """ diff --git a/bark/__init__.py b/bark/__init__.py index 4e31eb2..e0b17c8 100644 --- a/bark/__init__.py +++ b/bark/__init__.py @@ -1,2 +1,2 @@ -from .api import generate_audio, text_to_semantic, semantic_to_waveform +from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt from .generation import SAMPLE_RATE, preload_models diff --git a/bark/api.py b/bark/api.py index b0ebc6a..e1c7556 100644 --- a/bark/api.py +++ b/bark/api.py @@ -9,6 +9,7 @@ def text_to_semantic( text: str, history_prompt: Optional[str] = None, temp: float = 0.7, + silent: bool = False, ): """Generate semantic array from text. @@ -16,6 +17,7 @@ def text_to_semantic( text: text to be turned into audio history_prompt: history choice for audio cloning temp: generation temperature (1.0 more diverse, 0.0 more conservative) + silent: disable progress bar Returns: numpy semantic array to be fed into `semantic_to_waveform` @@ -24,6 +26,8 @@ def text_to_semantic( text, history_prompt=history_prompt, temp=temp, + silent=silent, + use_kv_caching=True ) return x_semantic @@ -32,6 +36,8 @@ def semantic_to_waveform( semantic_tokens: np.ndarray, history_prompt: Optional[str] = None, temp: float = 0.7, + silent: bool = False, + output_full: bool = False, ): """Generate audio array from semantic input. @@ -39,29 +45,51 @@ def semantic_to_waveform( semantic_tokens: semantic token output from `text_to_semantic` history_prompt: history choice for audio cloning temp: generation temperature (1.0 more diverse, 0.0 more conservative) + silent: disable progress bar + output_full: return full generation to be used as a history prompt Returns: numpy audio array at sample frequency 24khz """ - x_coarse_gen = generate_coarse( + coarse_tokens = generate_coarse( semantic_tokens, history_prompt=history_prompt, temp=temp, + silent=silent, + use_kv_caching=True ) - x_fine_gen = generate_fine( - x_coarse_gen, + fine_tokens = generate_fine( + coarse_tokens, history_prompt=history_prompt, temp=0.5, ) - audio_arr = codec_decode(x_fine_gen) + audio_arr = codec_decode(fine_tokens) + if output_full: + full_generation = { + "semantic_prompt": semantic_tokens, + "coarse_prompt": coarse_tokens, + "fine_prompt": fine_tokens, + } + return full_generation, audio_arr return audio_arr +def save_as_prompt(filepath, full_generation): + assert(filepath.endswith(".npz")) + assert(isinstance(full_generation, dict)) + assert("semantic_prompt" in full_generation) + assert("coarse_prompt" in full_generation) + assert("fine_prompt" in full_generation) + np.savez(filepath, **full_generation) + + def generate_audio( text: str, history_prompt: Optional[str] = None, text_temp: float = 0.7, waveform_temp: float = 0.7, + silent: bool = False, + output_full: bool = False, ): """Generate audio array from input text. @@ -70,10 +98,28 @@ def generate_audio( history_prompt: history choice for audio cloning text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + silent: disable progress bar + output_full: return full generation to be used as a history prompt Returns: numpy audio array at sample frequency 24khz """ - x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp) - audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp) + semantic_tokens = text_to_semantic( + text, + history_prompt=history_prompt, + temp=text_temp, + silent=silent, + ) + out = semantic_to_waveform( + semantic_tokens, + history_prompt=history_prompt, + temp=waveform_temp, + silent=silent, + output_full=output_full, + ) + if output_full: + full_generation, audio_arr = out + return full_generation, audio_arr + else: + audio_arr = out return audio_arr diff --git a/bark/generation.py b/bark/generation.py index 7b5ff6e..67e1932 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -1,4 +1,5 @@ import contextlib +import gc import hashlib import os import re @@ -21,6 +22,7 @@ if ( torch.cuda.is_available() and hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") and + hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() ): autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16) @@ -58,20 +60,33 @@ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0") +USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) + REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" + REMOTE_MODEL_PATHS = { + "text_small": { + "path": os.path.join(REMOTE_BASE_URL, "text.pt"), + "checksum": "b3e42bcbab23b688355cd44128c4cdd3", + }, + "coarse_small": { + "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), + "checksum": "5fe964825e3b0321f9d5f3857b89194d", + }, + "fine_small": { + "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), + "checksum": "5428d1befe05be2ba32195496e58dc90", + }, "text": { - "path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")), + "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), "checksum": "54afa89d65e318d4f5f80e8e8799026a", }, "coarse": { - "path": os.environ.get( - "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt") - ), + "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", }, "fine": { - "path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")), + "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), "checksum": "59d184ed44e3650774a2f0503a48a97b", }, } @@ -98,8 +113,9 @@ def _md5(fname): return hash_md5.hexdigest() -def _get_ckpt_path(model_type): - model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"]) +def _get_ckpt_path(model_type, use_small=False): + model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type + model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"]) return os.path.join(CACHE_DIR, f"{model_name}.pt") @@ -115,9 +131,9 @@ def _parse_s3_filepath(s3_filepath): def _download(from_s3_path, to_local_path): os.makedirs(CACHE_DIR, exist_ok=True) response = requests.get(from_s3_path, stream=True) - total_size_in_bytes = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 Kibibyte - progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) with open(to_local_path, "wb") as file: for data in response.iter_content(block_size): progress_bar.update(len(data)) @@ -165,11 +181,12 @@ def clean_models(model_key=None): if k in models: del models[k] _clear_cuda_cache() + gc.collect() -def _load_model(ckpt_path, device, model_type="text"): +def _load_model(ckpt_path, device, use_small=False, model_type="text"): if "cuda" not in device: - logger.warning("No GPU being used. Careful, Inference might be extremely slow!") + logger.warning("No GPU being used. Careful, inference might be extremely slow!") if model_type == "text": ConfigClass = GPTConfig ModelClass = GPT @@ -181,15 +198,17 @@ def _load_model(ckpt_path, device, model_type="text"): ModelClass = FineGPT else: raise NotImplementedError() + model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type + model_info = REMOTE_MODEL_PATHS[model_key] if ( os.path.exists(ckpt_path) and - _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"] + _md5(ckpt_path) != model_info["checksum"] ): - logger.warning(f"found outdated {model_type} model, removing...") + logger.warning(f"found outdated {model_type} model, removing.") os.remove(ckpt_path) if not os.path.exists(ckpt_path): - logger.info(f"{model_type} model not found, downloading...") - _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path) + logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") + _download(model_info["path"], ckpt_path) checkpoint = torch.load(ckpt_path, map_location=device) # this is a hack model_args = checkpoint["model_args"] @@ -239,8 +258,8 @@ def _load_codec_model(device): return model -def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"): - _load_model_f = funcy.partial(_load_model, model_type=model_type) +def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"): + _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small) if model_type not in ("text", "coarse", "fine"): raise NotImplementedError() global models @@ -250,8 +269,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex device = "cuda" model_key = str(device) + f"__{model_type}" if model_key not in models or force_reload: - if ckpt_path is None: - ckpt_path = _get_ckpt_path(model_type) + ckpt_path = _get_ckpt_path(model_type, use_small=use_small) clean_models(model_key=model_key) model = _load_model_f(ckpt_path, device) models[model_key] = model @@ -272,17 +290,29 @@ def load_codec_model(use_gpu=True, force_reload=False): return models[model_key] -def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True): +def preload_models( + text_use_gpu=True, + text_use_small=False, + coarse_use_gpu=True, + coarse_use_small=False, + fine_use_gpu=True, + fine_use_small=False, + codec_use_gpu=True, + force_reload=False, +): _ = load_model( - ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True + model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload ) _ = load_model( - ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True + model_type="coarse", + use_gpu=coarse_use_gpu, + use_small=coarse_use_small, + force_reload=force_reload, ) _ = load_model( - ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True + model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload ) - _ = load_codec_model(use_gpu=use_gpu, force_reload=True) + _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload) #### @@ -320,15 +350,19 @@ def generate_text_semantic( max_gen_duration_s=None, allow_early_stop=True, model=None, + use_kv_caching=False ): """Generate semantic tokens from text.""" assert isinstance(text, str) text = _normalize_whitespace(text) assert len(text.strip()) > 0 if history_prompt is not None: - semantic_history = np.load( - os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") - )["semantic_prompt"] + if history_prompt.endswith(".npz"): + semantic_history = np.load(history_prompt)["semantic_prompt"] + else: + semantic_history = np.load( + os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") + )["semantic_prompt"] assert ( isinstance(semantic_history, np.ndarray) and len(semantic_history.shape) == 1 @@ -377,8 +411,14 @@ def generate_text_semantic( pbar = tqdm.tqdm(disable=silent, total=100) pbar_state = 0 tot_generated_duration_s = 0 + kv_cache = None for n in range(n_tot_steps): - logits = model(x, merge_context=True) + if use_kv_caching and kv_cache is not None: + x_input = x[:, [-1]] + else: + x_input = x + + logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache) relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] if allow_early_stop: relevant_logits = torch.hstack( @@ -455,6 +495,7 @@ def generate_coarse( max_coarse_history=630, # min 60 (faster), max 630 (more context) sliding_window_len=60, model=None, + use_kv_caching=False ): """Generate coarse audio codes from semantic tokens.""" assert ( @@ -469,9 +510,12 @@ def generate_coarse( semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) if history_prompt is not None: - x_history = np.load( - os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") - ) + if history_prompt.endswith(".npz"): + x_history = np.load(history_prompt) + else: + x_history = np.load( + os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") + ) x_semantic_history = x_history["semantic_prompt"] x_coarse_history = x_history["coarse_prompt"] assert ( @@ -545,11 +589,18 @@ def generate_coarse( x_coarse_in[:, -max_coarse_history:], ] ) + kv_cache = None for _ in range(sliding_window_len): if n_step >= n_steps: continue is_major_step = n_step % N_COARSE_CODEBOOKS == 0 - logits = model(x_in) + + if use_kv_caching and kv_cache is not None: + x_input = x_in[:, [-1]] + else: + x_input = x_in + + logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) logit_start_idx = ( SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE ) @@ -611,9 +662,12 @@ def generate_fine( and x_coarse_gen.max() <= CODEBOOK_SIZE - 1 ) if history_prompt is not None: - x_fine_history = np.load( - os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") - )["fine_prompt"] + if history_prompt.endswith(".npz"): + x_fine_history = np.load(history_prompt)["fine_prompt"] + else: + x_fine_history = np.load( + os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") + )["fine_prompt"] assert ( isinstance(x_fine_history, np.ndarray) and len(x_fine_history.shape) == 2 diff --git a/bark/model.py b/bark/model.py index bbf9b68..bb99932 100644 --- a/bark/model.py +++ b/bark/model.py @@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module): self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x): + def forward(self, x, past_kv=None, use_cache=False): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -52,14 +52,36 @@ class CausalSelfAttention(nn.Module): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + if past_kv is not None: + # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains + # the query for the last token. scaled_dot_product_attention interprets this as the first token in the + # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so + # to work around this we set is_causal=False. + is_causal = False + else: + is_causal = True + + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) @@ -67,7 +89,7 @@ class CausalSelfAttention(nn.Module): # output projection y = self.resid_dropout(self.c_proj(y)) - return y + return (y, present) class MLP(nn.Module): @@ -95,10 +117,11 @@ class Block(nn.Module): self.mlp = MLP(config) self.layer_idx = layer_idx - def forward(self, x): - x = x + self.attn(self.ln_1(x)) + def forward(self, x, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) + x = x + attn_output x = x + self.mlp(self.ln_2(x)) - return x + return (x, prev_kvs) @dataclass class GPTConfig: @@ -142,33 +165,55 @@ class GPT(nn.Module): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False): + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): device = idx.device b, t = idx.size() - if merge_context: - assert(idx.shape[1] >= 256+256+1) - t = idx.shape[1] - 256 - else: - assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - - # forward the GPT model itself - if merge_context: - tok_emb = torch.cat([ - self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), - self.transformer.wte(idx[:,256+256:]) - ], dim=1) - else: + if past_kv is not None: + assert t == 1 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + else: + if merge_context: + assert(idx.shape[1] >= 256+256+1) + t = idx.shape[1] - 256 + else: + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the GPT model itself + if merge_context: + tok_emb = torch.cat([ + self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), + self.transformer.wte(idx[:,256+256:]) + ], dim=1) + else: + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if past_kv is None: + past_length = 0 + past_kv = tuple([None] * len(self.transformer.h)) + else: + past_length = past_kv[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, t) + assert position_ids.shape == (1, t) + + pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) - for block in self.transformer.h: - x = block(x) + + new_kv = () if use_cache else None + + for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) + + if use_cache: + new_kv = new_kv + (kv,) + x = self.transformer.ln_f(x) # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - return logits + return (logits, new_kv) diff --git a/model-card.md b/model-card.md index 2625250..5ead3a9 100644 --- a/model-card.md +++ b/model-card.md @@ -8,7 +8,7 @@ The following is additional information about the models released here. Bark is a series of three transformer models that turn text into audio. ### Text to semantic tokens - - Input: text, tokenized with [BERT tokenizer from huggingface](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) + - Input: text, tokenized with [BERT tokenizer from Hugging Face](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) - Output: semantic tokens that encode the audio to be generated ### Semantic to coarse tokens