Merge remote-tracking branch 'upstream/main'

Author: Francis LaBounty
Date: 2023-04-22 16:28:02 -06:00
7 changed files with 223 additions and 73 deletions

.gitignore

@@ -1,2 +1 @@
__pycache__/

README.md

@@ -32,14 +32,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a
## 🤖 Usage

```python
-from bark import SAMPLE_RATE, generate_audio
+from bark import SAMPLE_RATE, generate_audio, preload_models
from IPython.display import Audio

+# download and load all models
+preload_models()
+
+# generate audio from text
text_prompt = """
Hello, my name is Suno. And, uh — and I like pizza. [laughs]
But I also have other interests such as playing tic tac toe.
"""
audio_array = generate_audio(text_prompt)

+# play text in notebook
Audio(audio_array, rate=SAMPLE_RATE)
```
@@ -83,7 +89,7 @@ audio_array = generate_audio(text_prompt)
### 🎤 Voice Presets and Voice/Audio Cloning

-Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{number}`.
+Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{0-9}`.

```python
text_prompt = """
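
A quick illustration of the preset pattern described above. This is a minimal sketch, not part of the diff: it assumes the `en_speaker_0` preset following the `{lang_code}_speaker_{0-9}` pattern is available, and uses `scipy` only to write the result to disk.

```python
from bark import SAMPLE_RATE, generate_audio, preload_models
from scipy.io.wavfile import write as write_wav

preload_models()

# voice preset following the `{lang_code}_speaker_{0-9}` pattern above
audio_array = generate_audio(
    "Hello, my name is Suno.",
    history_prompt="en_speaker_0",
)

# save the 24 kHz output to disk
write_wav("bark_en_speaker_0.wav", SAMPLE_RATE, audio_array)
```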

bark/__init__.py

@@ -1,2 +1,2 @@
-from .api import generate_audio, text_to_semantic, semantic_to_waveform
+from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
from .generation import SAMPLE_RATE, preload_models

bark/api.py

@@ -9,6 +9,7 @@ def text_to_semantic(
    text: str,
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
+    silent: bool = False,
):
    """Generate semantic array from text.
@@ -16,6 +17,7 @@ def text_to_semantic(
        text: text to be turned into audio
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
+        silent: disable progress bar

    Returns:
        numpy semantic array to be fed into `semantic_to_waveform`
@@ -24,6 +26,8 @@ def text_to_semantic(
        text,
        history_prompt=history_prompt,
        temp=temp,
+        silent=silent,
+        use_kv_caching=True
    )
    return x_semantic
@@ -32,6 +36,8 @@ def semantic_to_waveform(
    semantic_tokens: np.ndarray,
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
+    silent: bool = False,
+    output_full: bool = False,
):
    """Generate audio array from semantic input.
@@ -39,29 +45,51 @@ def semantic_to_waveform(
        semantic_tokens: semantic token output from `text_to_semantic`
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
+        silent: disable progress bar
+        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
-    x_coarse_gen = generate_coarse(
+    coarse_tokens = generate_coarse(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=temp,
+        silent=silent,
+        use_kv_caching=True
    )
-    x_fine_gen = generate_fine(
-        x_coarse_gen,
+    fine_tokens = generate_fine(
+        coarse_tokens,
        history_prompt=history_prompt,
        temp=0.5,
    )
-    audio_arr = codec_decode(x_fine_gen)
+    audio_arr = codec_decode(fine_tokens)
+    if output_full:
+        full_generation = {
+            "semantic_prompt": semantic_tokens,
+            "coarse_prompt": coarse_tokens,
+            "fine_prompt": fine_tokens,
+        }
+        return full_generation, audio_arr
    return audio_arr


+def save_as_prompt(filepath, full_generation):
+    assert(filepath.endswith(".npz"))
+    assert(isinstance(full_generation, dict))
+    assert("semantic_prompt" in full_generation)
+    assert("coarse_prompt" in full_generation)
+    assert("fine_prompt" in full_generation)
+    np.savez(filepath, **full_generation)


def generate_audio(
    text: str,
    history_prompt: Optional[str] = None,
    text_temp: float = 0.7,
    waveform_temp: float = 0.7,
+    silent: bool = False,
+    output_full: bool = False,
):
    """Generate audio array from input text.
@@ -70,10 +98,28 @@ def generate_audio(
        history_prompt: history choice for audio cloning
        text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
+        silent: disable progress bar
+        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
-    x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp)
-    audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp)
+    semantic_tokens = text_to_semantic(
+        text,
+        history_prompt=history_prompt,
+        temp=text_temp,
+        silent=silent,
+    )
+    out = semantic_to_waveform(
+        semantic_tokens,
+        history_prompt=history_prompt,
+        temp=waveform_temp,
+        silent=silent,
+        output_full=output_full,
+    )
+    if output_full:
+        full_generation, audio_arr = out
+        return full_generation, audio_arr
+    else:
+        audio_arr = out
    return audio_arr
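
Taken together, the new `output_full` flag and `save_as_prompt` suggest a prompt-reuse workflow along these lines. This is a sketch against the API above, not part of the diff; the `my_prompt.npz` filename is illustrative and models are assumed to be preloaded.

```python
from bark.api import generate_audio, save_as_prompt

# keep the full generation so it can serve as a history prompt later
full_generation, audio_array = generate_audio(
    "Hello, my name is Suno.",
    output_full=True,
)

# persist the semantic/coarse/fine prompt arrays as an .npz file
save_as_prompt("my_prompt.npz", full_generation)

# the history loaders accept a path ending in ".npz", so the saved
# prompt can be passed back in directly
audio_array_2 = generate_audio(
    "And I like pizza.",
    history_prompt="my_prompt.npz",
)
```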

bark/generation.py

@@ -1,4 +1,5 @@
import contextlib
+import gc
import hashlib
import os
import re
@@ -21,6 +22,7 @@ if (
    torch.cuda.is_available() and
    hasattr(torch.cuda, "amp") and
    hasattr(torch.cuda.amp, "autocast") and
+    hasattr(torch.cuda, "is_bf16_supported") and
    torch.cuda.is_bf16_supported()
):
    autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
@@ -58,20 +60,33 @@ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")

+USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)

REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"

REMOTE_MODEL_PATHS = {
+    "text_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
+        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
+    },
+    "coarse_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
+        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
+    },
+    "fine_small": {
+        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
+        "checksum": "5428d1befe05be2ba32195496e58dc90",
+    },
    "text": {
-        "path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
+        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
    },
    "coarse": {
-        "path": os.environ.get(
-            "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
-        ),
+        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
    },
    "fine": {
-        "path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
+        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
        "checksum": "59d184ed44e3650774a2f0503a48a97b",
    },
}
@@ -98,8 +113,9 @@ def _md5(fname):
    return hash_md5.hexdigest()


-def _get_ckpt_path(model_type):
-    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
+def _get_ckpt_path(model_type, use_small=False):
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
    return os.path.join(CACHE_DIR, f"{model_name}.pt")
@@ -115,9 +131,9 @@ def _parse_s3_filepath(s3_filepath):
def _download(from_s3_path, to_local_path):
    os.makedirs(CACHE_DIR, exist_ok=True)
    response = requests.get(from_s3_path, stream=True)
-    total_size_in_bytes = int(response.headers.get('content-length', 0))
-    block_size = 1024  # 1 Kibibyte
-    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
+    total_size_in_bytes = int(response.headers.get("content-length", 0))
+    block_size = 1024
+    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(to_local_path, "wb") as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
@@ -165,11 +181,12 @@ def clean_models(model_key=None):
        if k in models:
            del models[k]
    _clear_cuda_cache()
+    gc.collect()


-def _load_model(ckpt_path, device, model_type="text"):
+def _load_model(ckpt_path, device, use_small=False, model_type="text"):
    if "cuda" not in device:
-        logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
+        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
    if model_type == "text":
        ConfigClass = GPTConfig
        ModelClass = GPT
@@ -181,15 +198,17 @@ def _load_model(ckpt_path, device, model_type="text"):
        ModelClass = FineGPT
    else:
        raise NotImplementedError()
+    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
+    model_info = REMOTE_MODEL_PATHS[model_key]
    if (
        os.path.exists(ckpt_path) and
-        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
+        _md5(ckpt_path) != model_info["checksum"]
    ):
-        logger.warning(f"found outdated {model_type} model, removing...")
+        logger.warning(f"found outdated {model_type} model, removing.")
        os.remove(ckpt_path)
    if not os.path.exists(ckpt_path):
-        logger.info(f"{model_type} model not found, downloading...")
-        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
+        logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
+        _download(model_info["path"], ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    # this is a hack
    model_args = checkpoint["model_args"]
@@ -239,8 +258,8 @@ def _load_codec_model(device):
    return model


-def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
-    _load_model_f = funcy.partial(_load_model, model_type=model_type)
+def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
+    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()
    global models
@@ -250,8 +269,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
        device = "cuda"
    model_key = str(device) + f"__{model_type}"
    if model_key not in models or force_reload:
-        if ckpt_path is None:
-            ckpt_path = _get_ckpt_path(model_type)
+        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
        clean_models(model_key=model_key)
        model = _load_model_f(ckpt_path, device)
        models[model_key] = model
@@ -272,17 +290,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
    return models[model_key]


-def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
+def preload_models(
+    text_use_gpu=True,
+    text_use_small=False,
+    coarse_use_gpu=True,
+    coarse_use_small=False,
+    fine_use_gpu=True,
+    fine_use_small=False,
+    codec_use_gpu=True,
+    force_reload=False,
+):
    _ = load_model(
-        ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
+        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
    )
    _ = load_model(
-        ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
+        model_type="coarse",
+        use_gpu=coarse_use_gpu,
+        use_small=coarse_use_small,
+        force_reload=force_reload,
    )
    _ = load_model(
-        ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
+        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
    )
-    _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
+    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
####
@@ -320,12 +350,16 @@ def generate_text_semantic(
    max_gen_duration_s=None,
    allow_early_stop=True,
    model=None,
+    use_kv_caching=False
):
    """Generate semantic tokens from text."""
    assert isinstance(text, str)
    text = _normalize_whitespace(text)
    assert len(text.strip()) > 0
    if history_prompt is not None:
-        semantic_history = np.load(
-            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-        )["semantic_prompt"]
+        if history_prompt.endswith(".npz"):
+            semantic_history = np.load(history_prompt)["semantic_prompt"]
+        else:
+            semantic_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )["semantic_prompt"]
@@ -377,8 +411,14 @@ def generate_text_semantic(
        pbar = tqdm.tqdm(disable=silent, total=100)
        pbar_state = 0
        tot_generated_duration_s = 0
+        kv_cache = None
        for n in range(n_tot_steps):
-            logits = model(x, merge_context=True)
+            if use_kv_caching and kv_cache is not None:
+                x_input = x[:, [-1]]
+            else:
+                x_input = x
+
+            logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
            relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
            if allow_early_stop:
                relevant_logits = torch.hstack(
@@ -455,6 +495,7 @@ def generate_coarse(
    max_coarse_history=630,  # min 60 (faster), max 630 (more context)
    sliding_window_len=60,
    model=None,
+    use_kv_caching=False
):
    """Generate coarse audio codes from semantic tokens."""
    assert (
@@ -469,6 +510,9 @@ def generate_coarse(
    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
    if history_prompt is not None:
-        x_history = np.load(
-            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-        )
+        if history_prompt.endswith(".npz"):
+            x_history = np.load(history_prompt)
+        else:
+            x_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )
@@ -545,11 +589,18 @@ def generate_coarse(
                    x_coarse_in[:, -max_coarse_history:],
                ]
            )
+            kv_cache = None
            for _ in range(sliding_window_len):
                if n_step >= n_steps:
                    continue
                is_major_step = n_step % N_COARSE_CODEBOOKS == 0
-                logits = model(x_in)
+
+                if use_kv_caching and kv_cache is not None:
+                    x_input = x_in[:, [-1]]
+                else:
+                    x_input = x_in
+
+                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                logit_start_idx = (
                    SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                )
@@ -611,6 +662,9 @@ def generate_fine(
        and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
    )
    if history_prompt is not None:
-        x_fine_history = np.load(
-            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
-        )["fine_prompt"]
+        if history_prompt.endswith(".npz"):
+            x_fine_history = np.load(history_prompt)["fine_prompt"]
+        else:
+            x_fine_history = np.load(
+                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
+            )["fine_prompt"]
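
The `use_small` plumbing added above can be driven per model or globally. A minimal sketch of the per-model route, assuming checkpoints download into the default `CACHE_DIR`:

```python
from bark.generation import preload_models

# per-model opt-in: small text and coarse checkpoints, full-size fine model,
# everything on GPU (the *_use_gpu flags default to True)
preload_models(
    text_use_small=True,
    coarse_use_small=True,
    fine_use_small=False,
)
```

Alternatively, setting the `SUNO_USE_SMALL_MODELS` environment variable before importing `bark.generation` switches every model to its small checkpoint, since the loaders check `use_small or USE_SMALL_MODELS`.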

bark/model.py

@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

-    def forward(self, x):
+    def forward(self, x, past_kv=None, use_cache=False):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,14 +52,36 @@ class CausalSelfAttention(nn.Module):
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

+        if past_kv is not None:
+            past_key = past_kv[0]
+            past_value = past_kv[1]
+            k = torch.cat((past_key, k), dim=-2)
+            v = torch.cat((past_value, v), dim=-2)
+
+        FULL_T = k.shape[-2]
+
+        if use_cache is True:
+            present = (k, v)
+        else:
+            present = None

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
-            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
+            if past_kv is not None:
+                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
+                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
+                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
+                # to work around this we set is_causal=False.
+                is_causal = False
+            else:
+                is_causal = True
+
+            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
-            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
+            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
@@ -67,7 +89,7 @@ class CausalSelfAttention(nn.Module):
        # output projection
        y = self.resid_dropout(self.c_proj(y))
-        return y
+        return (y, present)

class MLP(nn.Module):
@@ -95,10 +117,11 @@ class Block(nn.Module):
        self.mlp = MLP(config)
        self.layer_idx = layer_idx

-    def forward(self, x):
-        x = x + self.attn(self.ln_1(x))
+    def forward(self, x, past_kv=None, use_cache=False):
+        attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
+        x = x + attn_output
        x = x + self.mlp(self.ln_2(x))
-        return x
+        return (x, prev_kvs)

@dataclass
class GPTConfig:
@@ -142,9 +165,13 @@ class GPT(nn.Module):
        n_params -= self.transformer.wpe.weight.numel()
        return n_params

-    def forward(self, idx, merge_context=False):
+    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
        device = idx.device
        b, t = idx.size()
-        if merge_context:
-            assert(idx.shape[1] >= 256+256+1)
-            t = idx.shape[1] - 256
+        if past_kv is not None:
+            assert t == 1
+            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
+        else:
+            if merge_context:
+                assert(idx.shape[1] >= 256+256+1)
+                t = idx.shape[1] - 256
@@ -160,15 +187,33 @@ class GPT(nn.Module):
-        else:
-            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
-        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
-        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
+            else:
+                tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

+        if past_kv is None:
+            past_length = 0
+            past_kv = tuple([None] * len(self.transformer.h))
+        else:
+            past_length = past_kv[0][0].size(-2)

+        if position_ids is None:
+            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
+            position_ids = position_ids.unsqueeze(0) # shape (1, t)
+        assert position_ids.shape == (1, t)

+        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

        x = self.transformer.drop(tok_emb + pos_emb)

-        for block in self.transformer.h:
-            x = block(x)
+        new_kv = () if use_cache else None
+
+        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
+            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
+
+            if use_cache:
+                new_kv = new_kv + (kv,)

        x = self.transformer.ln_f(x)

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

-        return logits
+        return (logits, new_kv)
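
The cache contract introduced above is: every forward now returns `(logits, new_kv)`, and passing the cache back in via `past_kv` lets the model consume only the newest token. The toy greedy-decoding loop below illustrates that interface only; it is not Bark's actual sampler, which applies temperature and other sampling logic in `generation.py`.

```python
import torch

@torch.no_grad()
def greedy_decode(model, idx, n_new_tokens):
    """idx: (1, T) prompt token ids for the GPT defined above."""
    kv_cache = None
    for _ in range(n_new_tokens):
        # the first step feeds the whole prompt; afterwards only the newest
        # token is fed and the per-layer (k, v) cache carries the context
        x_input = idx[:, [-1]] if kv_cache is not None else idx
        logits, kv_cache = model(x_input, past_kv=kv_cache, use_cache=True)
        next_token = logits[0, 0].argmax().view(1, 1)  # logits: (1, 1, vocab)
        idx = torch.cat((idx, next_token), dim=1)
    return idx
```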

model-card.md

@@ -8,7 +8,7 @@ The following is additional information about the models released here.
Bark is a series of three transformer models that turn text into audio.

### Text to semantic tokens
-- Input: text, tokenized with [BERT tokenizer from huggingface](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
+- Input: text, tokenized with [BERT tokenizer from Hugging Face](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
- Output: semantic tokens that encode the audio to be generated

### Semantic to coarse tokens
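
The staged pipeline described in this model card is chained together by `bark/api.py`; a minimal end-to-end sketch using the public API, assuming models have been downloaded:

```python
from bark import preload_models, text_to_semantic, semantic_to_waveform

preload_models()

# stage 1: text -> semantic tokens
semantic_tokens = text_to_semantic("Hello, my name is Suno.")

# stages 2 and 3: semantic -> coarse -> fine tokens, decoded to a 24 kHz waveform
audio_array = semantic_to_waveform(semantic_tokens)
```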