From 9751cfbfc49a91fb9481f553cfe8d56f0c9cdbbb Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Fri, 21 Apr 2023 15:13:16 -0400 Subject: [PATCH 1/9] small updates --- README.md | 2 +- bark/api.py | 16 ++++++++++++++-- bark/generation.py | 12 ++++++------ model-card.md | 2 +- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 7b31cfd..ece2a20 100644 --- a/README.md +++ b/README.md @@ -72,7 +72,7 @@ audio_array = generate_audio(text_prompt) ### 🎤 Voice Presets and Voice/Audio Cloning -Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{number}`. +Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{0-9}`. ```python text_prompt = """ diff --git a/bark/api.py b/bark/api.py index b0ebc6a..300459c 100644 --- a/bark/api.py +++ b/bark/api.py @@ -9,6 +9,7 @@ def text_to_semantic( text: str, history_prompt: Optional[str] = None, temp: float = 0.7, + silent: bool = False, ): """Generate semantic array from text. @@ -16,6 +17,7 @@ def text_to_semantic( text: text to be turned into audio history_prompt: history choice for audio cloning temp: generation temperature (1.0 more diverse, 0.0 more conservative) + silent: disable progress bar Returns: numpy semantic array to be fed into `semantic_to_waveform` @@ -24,6 +26,7 @@ def text_to_semantic( text, history_prompt=history_prompt, temp=temp, + silent=silent, ) return x_semantic @@ -32,6 +35,7 @@ def semantic_to_waveform( semantic_tokens: np.ndarray, history_prompt: Optional[str] = None, temp: float = 0.7, + silent: bool = False, ): """Generate audio array from semantic input. @@ -39,6 +43,7 @@ def semantic_to_waveform( semantic_tokens: semantic token output from `text_to_semantic` history_prompt: history choice for audio cloning temp: generation temperature (1.0 more diverse, 0.0 more conservative) + silent: disable progress bar Returns: numpy audio array at sample frequency 24khz @@ -47,6 +52,7 @@ def semantic_to_waveform( semantic_tokens, history_prompt=history_prompt, temp=temp, + silent=silent, ) x_fine_gen = generate_fine( x_coarse_gen, @@ -62,6 +68,7 @@ def generate_audio( history_prompt: Optional[str] = None, text_temp: float = 0.7, waveform_temp: float = 0.7, + silent: bool = False, ): """Generate audio array from input text. @@ -70,10 +77,15 @@ def generate_audio( history_prompt: history choice for audio cloning text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) + silent: disable progress bar Returns: numpy audio array at sample frequency 24khz """ - x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp) - audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp) + x_semantic = text_to_semantic( + text, history_prompt=history_prompt, temp=text_temp, silent=silent, + ) + audio_arr = semantic_to_waveform( + x_semantic, history_prompt=history_prompt, temp=waveform_temp, silent=silent, + ) return audio_arr diff --git a/bark/generation.py b/bark/generation.py index 2b4fabe..c97dead 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -137,9 +137,9 @@ def _parse_s3_filepath(s3_filepath): def _download(from_s3_path, to_local_path): os.makedirs(CACHE_DIR, exist_ok=True) response = requests.get(from_s3_path, stream=True) - total_size_in_bytes = int(response.headers.get('content-length', 0)) - block_size = 1024 # 1 Kibibyte - progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) + total_size_in_bytes = int(response.headers.get("content-length", 0)) + block_size = 1024 + progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True) with open(to_local_path, "wb") as file: for data in response.iter_content(block_size): progress_bar.update(len(data)) @@ -191,7 +191,7 @@ def clean_models(model_key=None): def _load_model(ckpt_path, device, model_type="text"): if "cuda" not in device: - logger.warning("No GPU being used. Careful, Inference might be extremely slow!") + logger.warning("No GPU being used. Careful, inference might be extremely slow!") if model_type == "text": ConfigClass = GPTConfig ModelClass = GPT @@ -207,10 +207,10 @@ def _load_model(ckpt_path, device, model_type="text"): os.path.exists(ckpt_path) and _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"] ): - logger.warning(f"found outdated {model_type} model, removing...") + logger.warning(f"found outdated {model_type} model, removing.") os.remove(ckpt_path) if not os.path.exists(ckpt_path): - logger.info(f"{model_type} model not found, downloading...") + logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path) checkpoint = torch.load(ckpt_path, map_location=device) # this is a hack diff --git a/model-card.md b/model-card.md index 2625250..5ead3a9 100644 --- a/model-card.md +++ b/model-card.md @@ -8,7 +8,7 @@ The following is additional information about the models released here. Bark is a series of three transformer models that turn text into audio. ### Text to semantic tokens - - Input: text, tokenized with [BERT tokenizer from huggingface](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) + - Input: text, tokenized with [BERT tokenizer from Hugging Face](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer) - Output: semantic tokens that encode the audio to be generated ### Semantic to coarse tokens From c3724301125c3ba4401c89958ca8612201282ec0 Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Fri, 21 Apr 2023 15:31:36 -0400 Subject: [PATCH 2/9] add option for smaller models --- bark/generation.py | 49 +++++++++++++++++++++++++++++++--------------- 1 file changed, 33 insertions(+), 16 deletions(-) diff --git a/bark/generation.py b/bark/generation.py index c97dead..fa54388 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -21,6 +21,7 @@ if ( torch.cuda.is_available() and hasattr(torch.cuda, "amp") and hasattr(torch.cuda.amp, "autocast") and + hasattr(torch.cuda, "is_bf16_supported") and torch.cuda.is_bf16_supported() ): autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16) @@ -80,23 +81,39 @@ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0") +USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) + REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" -REMOTE_MODEL_PATHS = { - "text": { - "path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")), - "checksum": "54afa89d65e318d4f5f80e8e8799026a", - }, - "coarse": { - "path": os.environ.get( - "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt") - ), - "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", - }, - "fine": { - "path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")), - "checksum": "59d184ed44e3650774a2f0503a48a97b", - }, -} +if USE_SMALL_MODELS: + REMOTE_MODEL_PATHS = { + "text": { + "path": os.path.join(REMOTE_BASE_URL, "text.pt"), + "checksum": "b3e42bcbab23b688355cd44128c4cdd3", + }, + "coarse": { + "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), + "checksum": "5fe964825e3b0321f9d5f3857b89194d", + }, + "fine": { + "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), + "checksum": "5428d1befe05be2ba32195496e58dc90", + }, + } +else: + REMOTE_MODEL_PATHS = { + "text": { + "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), + "checksum": "54afa89d65e318d4f5f80e8e8799026a", + }, + "coarse": { + "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), + "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", + }, + "fine": { + "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), + "checksum": "59d184ed44e3650774a2f0503a48a97b", + }, + } if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'): From 7d39f48c7a08c45d41f9a2f6c663ca35f9a2e944 Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Fri, 21 Apr 2023 16:14:10 -0400 Subject: [PATCH 3/9] allow using unconditional as prompts --- bark/api.py | 43 ++++++++++++++++++++++++++++++++++++------- bark/generation.py | 33 +++++++++++++++++++++------------ 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/bark/api.py b/bark/api.py index 300459c..8033dc6 100644 --- a/bark/api.py +++ b/bark/api.py @@ -36,6 +36,7 @@ def semantic_to_waveform( history_prompt: Optional[str] = None, temp: float = 0.7, silent: bool = False, + output_full: bool = False, ): """Generate audio array from semantic input. @@ -44,31 +45,49 @@ def semantic_to_waveform( history_prompt: history choice for audio cloning temp: generation temperature (1.0 more diverse, 0.0 more conservative) silent: disable progress bar + output_full: return full generation to be used as a history prompt Returns: numpy audio array at sample frequency 24khz """ - x_coarse_gen = generate_coarse( + coarse_tokens = generate_coarse( semantic_tokens, history_prompt=history_prompt, temp=temp, silent=silent, ) - x_fine_gen = generate_fine( - x_coarse_gen, + fine_tokens = generate_fine( + coarse_tokens, history_prompt=history_prompt, temp=0.5, ) - audio_arr = codec_decode(x_fine_gen) + audio_arr = codec_decode(fine_tokens) + if output_full: + full_generation = { + "semantic_prompt": semantic_tokens, + "coarse_prompt": coarse_tokens, + "fine_prompt": fine_tokens, + } + return full_generation, audio_arr return audio_arr +def save_as_prompt(filepath, full_generation): + assert(filepath.endswith(".npz")) + assert(isinstance(full_generation, dict)) + assert("semantic_prompt" in full_generation) + assert("coarse_prompt" in full_generation) + assert("fine_prompt" in full_generation) + np.savez(filepath, **full_generation) + + def generate_audio( text: str, history_prompt: Optional[str] = None, text_temp: float = 0.7, waveform_temp: float = 0.7, silent: bool = False, + output_full: bool = False, ): """Generate audio array from input text. @@ -78,14 +97,24 @@ def generate_audio( text_temp: generation temperature (1.0 more diverse, 0.0 more conservative) waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative) silent: disable progress bar + output_full: return full generation to be used as a history prompt Returns: numpy audio array at sample frequency 24khz """ - x_semantic = text_to_semantic( + semantic_tokens = text_to_semantic( text, history_prompt=history_prompt, temp=text_temp, silent=silent, ) - audio_arr = semantic_to_waveform( - x_semantic, history_prompt=history_prompt, temp=waveform_temp, silent=silent, + out = semantic_to_waveform( + semantic_tokens, + history_prompt=history_prompt, + temp=waveform_temp, + silent=silent, + output_full=output_full, ) + if output_full: + full_generation, audio_arr = out + return full_generation, audio_arr + else: + audio_arr = out return audio_arr diff --git a/bark/generation.py b/bark/generation.py index fa54388..b5476bc 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -365,10 +365,13 @@ def generate_text_semantic( text = _normalize_whitespace(text) assert len(text.strip()) > 0 if history_prompt is not None: - assert (history_prompt in ALLOWED_PROMPTS) - semantic_history = np.load( - os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") - )["semantic_prompt"] + if history_prompt.endswith(".npz"): + semantic_history = np.load(history_prompt)["semantic_prompt"] + else: + assert (history_prompt in ALLOWED_PROMPTS) + semantic_history = np.load( + os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") + )["semantic_prompt"] assert ( isinstance(semantic_history, np.ndarray) and len(semantic_history.shape) == 1 @@ -509,10 +512,13 @@ def generate_coarse( semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio)) if history_prompt is not None: - assert (history_prompt in ALLOWED_PROMPTS) - x_history = np.load( - os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") - ) + if history_prompt.endswith(".npz"): + x_history = np.load(history_prompt) + else: + assert (history_prompt in ALLOWED_PROMPTS) + x_history = np.load( + os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") + ) x_semantic_history = x_history["semantic_prompt"] x_coarse_history = x_history["coarse_prompt"] assert ( @@ -652,10 +658,13 @@ def generate_fine( and x_coarse_gen.max() <= CODEBOOK_SIZE - 1 ) if history_prompt is not None: - assert (history_prompt in ALLOWED_PROMPTS) - x_fine_history = np.load( - os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") - )["fine_prompt"] + if history_prompt.endswith(".npz"): + x_fine_history = np.load(history_prompt)["fine_prompt"] + else: + assert (history_prompt in ALLOWED_PROMPTS) + x_fine_history = np.load( + os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz") + )["fine_prompt"] assert ( isinstance(x_fine_history, np.ndarray) and len(x_fine_history.shape) == 2 From 874af1bae9a74324b1fff5573963373c0016f0e0 Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Fri, 21 Apr 2023 16:26:07 -0400 Subject: [PATCH 4/9] convenience import --- bark/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bark/__init__.py b/bark/__init__.py index 4e31eb2..e0b17c8 100644 --- a/bark/__init__.py +++ b/bark/__init__.py @@ -1,2 +1,2 @@ -from .api import generate_audio, text_to_semantic, semantic_to_waveform +from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt from .generation import SAMPLE_RATE, preload_models From 15606ed12fd07bfea79bed6413615adf002da983 Mon Sep 17 00:00:00 2001 From: Zygimantas Straznickas Date: Thu, 20 Apr 2023 18:39:14 -0700 Subject: [PATCH 5/9] Add k/v caching for autoregressive generation --- .gitignore | 2 +- bark/api.py | 8 +++- bark/generation.py | 19 +++++++++- bark/model.py | 95 +++++++++++++++++++++++++++++++++------------- 4 files changed, 94 insertions(+), 30 deletions(-) diff --git a/.gitignore b/.gitignore index 372c13e..48e4ceb 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,2 @@ __pycache__/ - +.venv \ No newline at end of file diff --git a/bark/api.py b/bark/api.py index 8033dc6..8231616 100644 --- a/bark/api.py +++ b/bark/api.py @@ -10,6 +10,7 @@ def text_to_semantic( history_prompt: Optional[str] = None, temp: float = 0.7, silent: bool = False, + use_kv_caching = False, ): """Generate semantic array from text. @@ -27,6 +28,7 @@ def text_to_semantic( history_prompt=history_prompt, temp=temp, silent=silent, + use_kv_caching=use_kv_caching ) return x_semantic @@ -37,6 +39,7 @@ def semantic_to_waveform( temp: float = 0.7, silent: bool = False, output_full: bool = False, + use_kv_caching = False ): """Generate audio array from semantic input. @@ -55,6 +58,7 @@ def semantic_to_waveform( history_prompt=history_prompt, temp=temp, silent=silent, + use_kv_caching=use_kv_caching ) fine_tokens = generate_fine( coarse_tokens, @@ -88,6 +92,7 @@ def generate_audio( waveform_temp: float = 0.7, silent: bool = False, output_full: bool = False, + use_kv_caching = False ): """Generate audio array from input text. @@ -103,7 +108,7 @@ def generate_audio( numpy audio array at sample frequency 24khz """ semantic_tokens = text_to_semantic( - text, history_prompt=history_prompt, temp=text_temp, silent=silent, + text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching ) out = semantic_to_waveform( semantic_tokens, @@ -111,6 +116,7 @@ def generate_audio( temp=waveform_temp, silent=silent, output_full=output_full, + use_kv_caching=use_kv_caching ) if output_full: full_generation, audio_arr = out diff --git a/bark/generation.py b/bark/generation.py index b5476bc..5753125 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -359,6 +359,7 @@ def generate_text_semantic( max_gen_duration_s=None, allow_early_stop=True, model=None, + use_kv_caching=False ): """Generate semantic tokens from text.""" assert isinstance(text, str) @@ -420,8 +421,14 @@ def generate_text_semantic( pbar = tqdm.tqdm(disable=silent, total=100) pbar_state = 0 tot_generated_duration_s = 0 + kv_cache = None for n in range(n_tot_steps): - logits = model(x, merge_context=True) + if use_kv_caching and kv_cache is not None: + x_input = x[:, [-1]] + else: + x_input = x + + logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache) relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] if allow_early_stop: relevant_logits = torch.hstack( @@ -498,6 +505,7 @@ def generate_coarse( max_coarse_history=630, # min 60 (faster), max 630 (more context) sliding_window_len=60, model=None, + use_kv_caching=False ): """Generate coarse audio codes from semantic tokens.""" assert ( @@ -592,11 +600,18 @@ def generate_coarse( x_coarse_in[:, -max_coarse_history:], ] ) + kv_cache = None for _ in range(sliding_window_len): if n_step >= n_steps: continue is_major_step = n_step % N_COARSE_CODEBOOKS == 0 - logits = model(x_in) + + if use_kv_caching and kv_cache is not None: + x_input = x_in[:, [-1]] + else: + x_input = x_in + + logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache) logit_start_idx = ( SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE ) diff --git a/bark/model.py b/bark/model.py index bbf9b68..463557c 100644 --- a/bark/model.py +++ b/bark/model.py @@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module): self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x): + def forward(self, x, layer_past=None, use_cache=False): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -52,14 +52,34 @@ class CausalSelfAttention(nn.Module): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + k = torch.cat((past_key, k), dim=-2) + v = torch.cat((past_value, v), dim=-2) + + FULL_T = k.shape[-2] + + if use_cache is True: + present = (k, v) + else: + present = None + # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) + if layer_past is not None: + # in theory the attention is still causal but because we're computing it incrementally, + # the last query can attend on all previous keys/values, which which is equivalent to non-causal + is_causal = False + else: + is_causal = True + + y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal) else: # manual implementation of attention att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) - att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) + att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf')) att = F.softmax(att, dim=-1) att = self.attn_dropout(att) y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) @@ -67,7 +87,7 @@ class CausalSelfAttention(nn.Module): # output projection y = self.resid_dropout(self.c_proj(y)) - return y + return (y, present) class MLP(nn.Module): @@ -95,10 +115,11 @@ class Block(nn.Module): self.mlp = MLP(config) self.layer_idx = layer_idx - def forward(self, x): - x = x + self.attn(self.ln_1(x)) + def forward(self, x, layer_past=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache) + x = x + attn_output x = x + self.mlp(self.ln_2(x)) - return x + return (x, prev_kvs) @dataclass class GPTConfig: @@ -142,33 +163,55 @@ class GPT(nn.Module): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False): + def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False): device = idx.device b, t = idx.size() - if merge_context: - assert(idx.shape[1] >= 256+256+1) - t = idx.shape[1] - 256 - else: - assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" - - # forward the GPT model itself - if merge_context: - tok_emb = torch.cat([ - self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), - self.transformer.wte(idx[:,256+256:]) - ], dim=1) - else: + if past_key_values is not None: + assert t == 1 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + else: + if merge_context: + assert(idx.shape[1] >= 256+256+1) + t = idx.shape[1] - 256 + else: + assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" + + # forward the GPT model itself + if merge_context: + tok_emb = torch.cat([ + self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]), + self.transformer.wte(idx[:,256+256:]) + ], dim=1) + else: + tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.transformer.h)) + else: + past_length = past_key_values[0][0].size(-2) + + if position_ids is None: + position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) + position_ids = position_ids.unsqueeze(0) # shape (1, t) + assert position_ids.shape == (1, t) + + pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd) - pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) x = self.transformer.drop(tok_emb + pos_emb) - for block in self.transformer.h: - x = block(x) + + presents = () if use_cache else None + + for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)): + x, kv = block(x, layer_past=layer_past, use_cache=use_cache) + + if use_cache: + presents = presents + (kv,) + x = self.transformer.ln_f(x) # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - return logits + return (logits, presents) From bee9e030802e612e1941d68822680b5be6895d7e Mon Sep 17 00:00:00 2001 From: Zygimantas Straznickas Date: Sat, 22 Apr 2023 12:23:55 -0700 Subject: [PATCH 6/9] Rename variables and add comments --- .gitignore | 3 +-- bark/generation.py | 4 ++-- bark/model.py | 40 +++++++++++++++++++++------------------- 3 files changed, 24 insertions(+), 23 deletions(-) diff --git a/.gitignore b/.gitignore index 48e4ceb..ba0430d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1 @@ -__pycache__/ -.venv \ No newline at end of file +__pycache__/ \ No newline at end of file diff --git a/bark/generation.py b/bark/generation.py index 5753125..4e860d8 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -428,7 +428,7 @@ def generate_text_semantic( else: x_input = x - logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_key_values=kv_cache) + logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache) relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE] if allow_early_stop: relevant_logits = torch.hstack( @@ -611,7 +611,7 @@ def generate_coarse( else: x_input = x_in - logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_key_values=kv_cache) + logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache) logit_start_idx = ( SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE ) diff --git a/bark/model.py b/bark/model.py index 463557c..bb99932 100644 --- a/bark/model.py +++ b/bark/model.py @@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module): self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) .view(1, 1, config.block_size, config.block_size)) - def forward(self, x, layer_past=None, use_cache=False): + def forward(self, x, past_kv=None, use_cache=False): B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) # calculate query, key, values for all heads in batch and move head forward to be the batch dim @@ -52,9 +52,9 @@ class CausalSelfAttention(nn.Module): q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) - if layer_past is not None: - past_key = layer_past[0] - past_value = layer_past[1] + if past_kv is not None: + past_key = past_kv[0] + past_value = past_kv[1] k = torch.cat((past_key, k), dim=-2) v = torch.cat((past_value, v), dim=-2) @@ -68,9 +68,11 @@ class CausalSelfAttention(nn.Module): # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) if self.flash: # efficient attention using Flash Attention CUDA kernels - if layer_past is not None: - # in theory the attention is still causal but because we're computing it incrementally, - # the last query can attend on all previous keys/values, which which is equivalent to non-causal + if past_kv is not None: + # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains + # the query for the last token. scaled_dot_product_attention interprets this as the first token in the + # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so + # to work around this we set is_causal=False. is_causal = False else: is_causal = True @@ -115,8 +117,8 @@ class Block(nn.Module): self.mlp = MLP(config) self.layer_idx = layer_idx - def forward(self, x, layer_past=None, use_cache=False): - attn_output, prev_kvs = self.attn(self.ln_1(x), layer_past=layer_past, use_cache=use_cache) + def forward(self, x, past_kv=None, use_cache=False): + attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache) x = x + attn_output x = x + self.mlp(self.ln_2(x)) return (x, prev_kvs) @@ -163,10 +165,10 @@ class GPT(nn.Module): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False, past_key_values=None, position_ids=None, use_cache=False): + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): device = idx.device b, t = idx.size() - if past_key_values is not None: + if past_kv is not None: assert t == 1 tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) else: @@ -185,11 +187,11 @@ class GPT(nn.Module): else: tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) - if past_key_values is None: + if past_kv is None: past_length = 0 - past_key_values = tuple([None] * len(self.transformer.h)) + past_kv = tuple([None] * len(self.transformer.h)) else: - past_length = past_key_values[0][0].size(-2) + past_length = past_kv[0][0].size(-2) if position_ids is None: position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device) @@ -201,17 +203,17 @@ class GPT(nn.Module): x = self.transformer.drop(tok_emb + pos_emb) - presents = () if use_cache else None + new_kv = () if use_cache else None - for i, (block, layer_past) in enumerate(zip(self.transformer.h, past_key_values)): - x, kv = block(x, layer_past=layer_past, use_cache=use_cache) + for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)): + x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache) if use_cache: - presents = presents + (kv,) + new_kv = new_kv + (kv,) x = self.transformer.ln_f(x) # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim - return (logits, presents) + return (logits, new_kv) From acfd65b1a96e6cb7893c0059984520693d1381b3 Mon Sep 17 00:00:00 2001 From: Zygimantas Straznickas Date: Sat, 22 Apr 2023 12:27:47 -0700 Subject: [PATCH 7/9] Add newline to .gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index ba0430d..c18dd8d 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1 @@ -__pycache__/ \ No newline at end of file +__pycache__/ From 009ff7cb6217d2c420765c5e98bc8345bbd66849 Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Sat, 22 Apr 2023 15:42:30 -0400 Subject: [PATCH 8/9] make kv caching default in inference --- bark/api.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/bark/api.py b/bark/api.py index 8231616..e1c7556 100644 --- a/bark/api.py +++ b/bark/api.py @@ -10,7 +10,6 @@ def text_to_semantic( history_prompt: Optional[str] = None, temp: float = 0.7, silent: bool = False, - use_kv_caching = False, ): """Generate semantic array from text. @@ -28,7 +27,7 @@ def text_to_semantic( history_prompt=history_prompt, temp=temp, silent=silent, - use_kv_caching=use_kv_caching + use_kv_caching=True ) return x_semantic @@ -39,7 +38,6 @@ def semantic_to_waveform( temp: float = 0.7, silent: bool = False, output_full: bool = False, - use_kv_caching = False ): """Generate audio array from semantic input. @@ -58,7 +56,7 @@ def semantic_to_waveform( history_prompt=history_prompt, temp=temp, silent=silent, - use_kv_caching=use_kv_caching + use_kv_caching=True ) fine_tokens = generate_fine( coarse_tokens, @@ -92,7 +90,6 @@ def generate_audio( waveform_temp: float = 0.7, silent: bool = False, output_full: bool = False, - use_kv_caching = False ): """Generate audio array from input text. @@ -108,7 +105,10 @@ def generate_audio( numpy audio array at sample frequency 24khz """ semantic_tokens = text_to_semantic( - text, history_prompt=history_prompt, temp=text_temp, silent=silent, use_kv_caching=use_kv_caching + text, + history_prompt=history_prompt, + temp=text_temp, + silent=silent, ) out = semantic_to_waveform( semantic_tokens, @@ -116,7 +116,6 @@ def generate_audio( temp=waveform_temp, silent=silent, output_full=output_full, - use_kv_caching=use_kv_caching ) if output_full: full_generation, audio_arr = out From 8313b570f43319713ca7600602c5df823e4719f3 Mon Sep 17 00:00:00 2001 From: Georg Kucsko Date: Sat, 22 Apr 2023 17:09:20 -0400 Subject: [PATCH 9/9] simplify --- README.md | 8 +++- bark/generation.py | 101 +++++++++++++++++++++++++-------------------- 2 files changed, 64 insertions(+), 45 deletions(-) diff --git a/README.md b/README.md index ece2a20..a44887b 100644 --- a/README.md +++ b/README.md @@ -21,14 +21,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a ## 🤖 Usage ```python -from bark import SAMPLE_RATE, generate_audio +from bark import SAMPLE_RATE, generate_audio, preload_models from IPython.display import Audio +# download and load all models +preload_models() + +# generate audio from text text_prompt = """ Hello, my name is Suno. And, uh — and I like pizza. [laughs] But I also have other interests such as playing tic tac toe. """ audio_array = generate_audio(text_prompt) + +# play text in notebook Audio(audio_array, rate=SAMPLE_RATE) ``` diff --git a/bark/generation.py b/bark/generation.py index 4e860d8..4aa805e 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -1,4 +1,5 @@ import contextlib +import gc import hashlib import os import re @@ -84,36 +85,33 @@ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/" -if USE_SMALL_MODELS: - REMOTE_MODEL_PATHS = { - "text": { - "path": os.path.join(REMOTE_BASE_URL, "text.pt"), - "checksum": "b3e42bcbab23b688355cd44128c4cdd3", - }, - "coarse": { - "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), - "checksum": "5fe964825e3b0321f9d5f3857b89194d", - }, - "fine": { - "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), - "checksum": "5428d1befe05be2ba32195496e58dc90", - }, - } -else: - REMOTE_MODEL_PATHS = { - "text": { - "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), - "checksum": "54afa89d65e318d4f5f80e8e8799026a", - }, - "coarse": { - "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), - "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", - }, - "fine": { - "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), - "checksum": "59d184ed44e3650774a2f0503a48a97b", - }, - } + +REMOTE_MODEL_PATHS = { + "text_small": { + "path": os.path.join(REMOTE_BASE_URL, "text.pt"), + "checksum": "b3e42bcbab23b688355cd44128c4cdd3", + }, + "coarse_small": { + "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"), + "checksum": "5fe964825e3b0321f9d5f3857b89194d", + }, + "fine_small": { + "path": os.path.join(REMOTE_BASE_URL, "fine.pt"), + "checksum": "5428d1befe05be2ba32195496e58dc90", + }, + "text": { + "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"), + "checksum": "54afa89d65e318d4f5f80e8e8799026a", + }, + "coarse": { + "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"), + "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28", + }, + "fine": { + "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"), + "checksum": "59d184ed44e3650774a2f0503a48a97b", + }, +} if not hasattr(torch.nn.functional, 'scaled_dot_product_attention'): @@ -137,8 +135,9 @@ def _md5(fname): return hash_md5.hexdigest() -def _get_ckpt_path(model_type): - model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"]) +def _get_ckpt_path(model_type, use_small=False): + model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type + model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"]) return os.path.join(CACHE_DIR, f"{model_name}.pt") @@ -204,9 +203,10 @@ def clean_models(model_key=None): if k in models: del models[k] _clear_cuda_cache() + gc.collect() -def _load_model(ckpt_path, device, model_type="text"): +def _load_model(ckpt_path, device, use_small=False, model_type="text"): if "cuda" not in device: logger.warning("No GPU being used. Careful, inference might be extremely slow!") if model_type == "text": @@ -220,15 +220,17 @@ def _load_model(ckpt_path, device, model_type="text"): ModelClass = FineGPT else: raise NotImplementedError() + model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type + model_info = REMOTE_MODEL_PATHS[model_key] if ( os.path.exists(ckpt_path) and - _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"] + _md5(ckpt_path) != model_info["checksum"] ): logger.warning(f"found outdated {model_type} model, removing.") os.remove(ckpt_path) if not os.path.exists(ckpt_path): logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") - _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path) + _download(model_info["path"], ckpt_path) checkpoint = torch.load(ckpt_path, map_location=device) # this is a hack model_args = checkpoint["model_args"] @@ -278,8 +280,8 @@ def _load_codec_model(device): return model -def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"): - _load_model_f = funcy.partial(_load_model, model_type=model_type) +def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"): + _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small) if model_type not in ("text", "coarse", "fine"): raise NotImplementedError() global models @@ -289,8 +291,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex device = "cuda" model_key = str(device) + f"__{model_type}" if model_key not in models or force_reload: - if ckpt_path is None: - ckpt_path = _get_ckpt_path(model_type) + ckpt_path = _get_ckpt_path(model_type, use_small=use_small) clean_models(model_key=model_key) model = _load_model_f(ckpt_path, device) models[model_key] = model @@ -311,17 +312,29 @@ def load_codec_model(use_gpu=True, force_reload=False): return models[model_key] -def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True): +def preload_models( + text_use_gpu=True, + text_use_small=False, + coarse_use_gpu=True, + coarse_use_small=False, + fine_use_gpu=True, + fine_use_small=False, + codec_use_gpu=True, + force_reload=False, +): _ = load_model( - ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True + model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload ) _ = load_model( - ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True + model_type="coarse", + use_gpu=coarse_use_gpu, + use_small=coarse_use_small, + force_reload=force_reload, ) _ = load_model( - ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True + model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload ) - _ = load_codec_model(use_gpu=use_gpu, force_reload=True) + _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload) ####