Mirror of https://github.com/serp-ai/bark-with-voice-clone.git (synced 2025-12-15 03:07:58 +01:00)
Merge remote-tracking branch 'upstream/main'
.gitignore (vendored, 1 line changed)
@@ -1,2 +1 @@
__pycache__/
README.md (10 lines changed)
@@ -32,14 +32,20 @@ Bark is a transformer-based text-to-audio model created by [Suno](https://suno.a
## 🤖 Usage

```python
from bark import SAMPLE_RATE, generate_audio
from bark import SAMPLE_RATE, generate_audio, preload_models
from IPython.display import Audio

# download and load all models
preload_models()

# generate audio from text
text_prompt = """
Hello, my name is Suno. And, uh — and I like pizza. [laughs]
But I also have other interests such as playing tic tac toe.
"""
audio_array = generate_audio(text_prompt)

# play text in notebook
Audio(audio_array, rate=SAMPLE_RATE)
```

@@ -83,7 +89,7 @@ audio_array = generate_audio(text_prompt)

### 🎤 Voice Presets and Voice/Audio Cloning

Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{number}`.
Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{0-9}`.

```python
text_prompt = """
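Reviewer note (not part of the diff): the preset-name form described in the README paragraph above maps directly onto `generate_audio`'s `history_prompt` argument. A minimal sketch, where "en_speaker_3" is just one arbitrary name following the `{lang_code}_speaker_{0-9}` pattern:

```python
from bark import generate_audio, preload_models

preload_models()

# "en_speaker_3" follows the {lang_code}_speaker_{0-9} pattern from the README text above.
audio_array = generate_audio(
    "This sentence uses one of the built-in synthetic voice presets.",
    history_prompt="en_speaker_3",
)
```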
bark/__init__.py
@@ -1,2 +1,2 @@
from .api import generate_audio, text_to_semantic, semantic_to_waveform
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
from .generation import SAMPLE_RATE, preload_models
bark/api.py (58 lines changed)
@@ -9,6 +9,7 @@ def text_to_semantic(
    text: str,
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
    silent: bool = False,
):
    """Generate semantic array from text.

@@ -16,6 +17,7 @@ def text_to_semantic(
        text: text to be turned into audio
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar

    Returns:
        numpy semantic array to be fed into `semantic_to_waveform`
@@ -24,6 +26,8 @@ def text_to_semantic(
        text,
        history_prompt=history_prompt,
        temp=temp,
        silent=silent,
        use_kv_caching=True
    )
    return x_semantic

@@ -32,6 +36,8 @@ def semantic_to_waveform(
    semantic_tokens: np.ndarray,
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
    silent: bool = False,
    output_full: bool = False,
):
    """Generate audio array from semantic input.

@@ -39,29 +45,51 @@ def semantic_to_waveform(
        semantic_tokens: semantic token output from `text_to_semantic`
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar
        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
    x_coarse_gen = generate_coarse(
    coarse_tokens = generate_coarse(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=temp,
        silent=silent,
        use_kv_caching=True
    )
    x_fine_gen = generate_fine(
        x_coarse_gen,
    fine_tokens = generate_fine(
        coarse_tokens,
        history_prompt=history_prompt,
        temp=0.5,
    )
    audio_arr = codec_decode(x_fine_gen)
    audio_arr = codec_decode(fine_tokens)
    if output_full:
        full_generation = {
            "semantic_prompt": semantic_tokens,
            "coarse_prompt": coarse_tokens,
            "fine_prompt": fine_tokens,
        }
        return full_generation, audio_arr
    return audio_arr


def save_as_prompt(filepath, full_generation):
    assert(filepath.endswith(".npz"))
    assert(isinstance(full_generation, dict))
    assert("semantic_prompt" in full_generation)
    assert("coarse_prompt" in full_generation)
    assert("fine_prompt" in full_generation)
    np.savez(filepath, **full_generation)


def generate_audio(
    text: str,
    history_prompt: Optional[str] = None,
    text_temp: float = 0.7,
    waveform_temp: float = 0.7,
    silent: bool = False,
    output_full: bool = False,
):
    """Generate audio array from input text.

@@ -70,10 +98,28 @@ def generate_audio(
        history_prompt: history choice for audio cloning
        text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar
        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
    x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp)
    audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp)
    semantic_tokens = text_to_semantic(
        text,
        history_prompt=history_prompt,
        temp=text_temp,
        silent=silent,
    )
    out = semantic_to_waveform(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=waveform_temp,
        silent=silent,
        output_full=output_full,
    )
    if output_full:
        full_generation, audio_arr = out
        return full_generation, audio_arr
    else:
        audio_arr = out
    return audio_arr
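Reviewer note (not part of the diff): the new `output_full` flag and `save_as_prompt` are the two halves of the cloning workflow this fork adds. A hedged sketch of the round trip, assuming models are already downloaded; the file name `my_voice_prompt.npz` is arbitrary, and the `.npz` form of `history_prompt` relies on the `bark/generation.py` changes further down:

```python
from bark import generate_audio, preload_models
from bark.api import save_as_prompt

preload_models()

# 1) ask generate_audio for the full generation, not just the waveform
full_generation, audio_array = generate_audio(
    "Hello, this is the voice I want to reuse.",
    output_full=True,
)

# 2) persist the semantic/coarse/fine prompts to an .npz file
save_as_prompt("my_voice_prompt.npz", full_generation)

# 3) feed that file back in as the history prompt for a new sentence in the same voice
audio_array_2 = generate_audio(
    "Same voice, different sentence.",
    history_prompt="my_voice_prompt.npz",
)
```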
bark/generation.py
@@ -1,4 +1,5 @@
import contextlib
import gc
import hashlib
import os
import re
@@ -21,6 +22,7 @@ if (
    torch.cuda.is_available() and
    hasattr(torch.cuda, "amp") and
    hasattr(torch.cuda.amp, "autocast") and
    hasattr(torch.cuda, "is_bf16_supported") and
    torch.cuda.is_bf16_supported()
):
    autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
@@ -58,20 +60,33 @@ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")


USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)

REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"

REMOTE_MODEL_PATHS = {
    "text_small": {
        "path": os.path.join(REMOTE_BASE_URL, "text.pt"),
        "checksum": "b3e42bcbab23b688355cd44128c4cdd3",
    },
    "coarse_small": {
        "path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
        "checksum": "5fe964825e3b0321f9d5f3857b89194d",
    },
    "fine_small": {
        "path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
        "checksum": "5428d1befe05be2ba32195496e58dc90",
    },
    "text": {
        "path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
        "path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
        "checksum": "54afa89d65e318d4f5f80e8e8799026a",
    },
    "coarse": {
        "path": os.environ.get(
            "SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
        ),
        "path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
        "checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
    },
    "fine": {
        "path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
        "path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
        "checksum": "59d184ed44e3650774a2f0503a48a97b",
    },
}
@@ -98,8 +113,9 @@ def _md5(fname):
    return hash_md5.hexdigest()


def _get_ckpt_path(model_type):
    model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
def _get_ckpt_path(model_type, use_small=False):
    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
    model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
    return os.path.join(CACHE_DIR, f"{model_name}.pt")


@@ -115,9 +131,9 @@ def _parse_s3_filepath(s3_filepath):
def _download(from_s3_path, to_local_path):
    os.makedirs(CACHE_DIR, exist_ok=True)
    response = requests.get(from_s3_path, stream=True)
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024 # 1 Kibibyte
    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    block_size = 1024
    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(to_local_path, "wb") as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
@@ -165,11 +181,12 @@ def clean_models(model_key=None):
        if k in models:
            del models[k]
    _clear_cuda_cache()
    gc.collect()


def _load_model(ckpt_path, device, model_type="text"):
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
    if "cuda" not in device:
        logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
        logger.warning("No GPU being used. Careful, inference might be extremely slow!")
    if model_type == "text":
        ConfigClass = GPTConfig
        ModelClass = GPT
@@ -181,15 +198,17 @@ def _load_model(ckpt_path, device, model_type="text"):
        ModelClass = FineGPT
    else:
        raise NotImplementedError()
    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
    model_info = REMOTE_MODEL_PATHS[model_key]
    if (
        os.path.exists(ckpt_path) and
        _md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
        _md5(ckpt_path) != model_info["checksum"]
    ):
        logger.warning(f"found outdated {model_type} model, removing...")
        logger.warning(f"found outdated {model_type} model, removing.")
        os.remove(ckpt_path)
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading...")
        _download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
        logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
        _download(model_info["path"], ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    # this is a hack
    model_args = checkpoint["model_args"]
@@ -239,8 +258,8 @@ def _load_codec_model(device):
    return model


def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
    _load_model_f = funcy.partial(_load_model, model_type=model_type)
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()
    global models
@@ -250,8 +269,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
        device = "cuda"
    model_key = str(device) + f"__{model_type}"
    if model_key not in models or force_reload:
        if ckpt_path is None:
            ckpt_path = _get_ckpt_path(model_type)
        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
        clean_models(model_key=model_key)
        model = _load_model_f(ckpt_path, device)
        models[model_key] = model
@@ -272,17 +290,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
    return models[model_key]


def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
def preload_models(
    text_use_gpu=True,
    text_use_small=False,
    coarse_use_gpu=True,
    coarse_use_small=False,
    fine_use_gpu=True,
    fine_use_small=False,
    codec_use_gpu=True,
    force_reload=False,
):
    _ = load_model(
        ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
        model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
    )
    _ = load_model(
        ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
        model_type="coarse",
        use_gpu=coarse_use_gpu,
        use_small=coarse_use_small,
        force_reload=force_reload,
    )
    _ = load_model(
        ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
        model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
    )
    _ = load_codec_model(use_gpu=use_gpu, force_reload=True)
    _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)


####
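Reviewer note (not part of the diff): `preload_models` no longer takes checkpoint paths; each model instead gets its own `*_use_gpu` / `*_use_small` switch, and `force_reload` is passed through rather than hard-coded to `True`. Setting the `SUNO_USE_SMALL_MODELS` environment variable to any non-empty value selects the small checkpoints globally. A hedged sketch (the particular mix of small and large models below is illustrative, not a recommended default):

```python
from bark import preload_models

# Trade quality for memory: small text and coarse models, full-size fine model.
preload_models(
    text_use_gpu=True,
    text_use_small=True,
    coarse_use_gpu=True,
    coarse_use_small=True,
    fine_use_gpu=True,
    fine_use_small=False,
    codec_use_gpu=True,
    force_reload=False,
)
```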
@@ -320,15 +350,19 @@ def generate_text_semantic(
    max_gen_duration_s=None,
    allow_early_stop=True,
    model=None,
    use_kv_caching=False
):
    """Generate semantic tokens from text."""
    assert isinstance(text, str)
    text = _normalize_whitespace(text)
    assert len(text.strip()) > 0
    if history_prompt is not None:
        semantic_history = np.load(
            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
        )["semantic_prompt"]
        if history_prompt.endswith(".npz"):
            semantic_history = np.load(history_prompt)["semantic_prompt"]
        else:
            semantic_history = np.load(
                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
            )["semantic_prompt"]
        assert (
            isinstance(semantic_history, np.ndarray)
            and len(semantic_history.shape) == 1
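Reviewer note (not part of the diff): the same two-branch lookup is repeated in `generate_text_semantic`, `generate_coarse`, and `generate_fine`. Distilled into a hypothetical helper for clarity (the name `_resolve_history_prompt` does not exist in the diff, which inlines this logic in each generator):

```python
import os

CUR_PATH = os.path.dirname(os.path.abspath(__file__))

def _resolve_history_prompt(history_prompt: str) -> str:
    """Map a history_prompt string to the .npz file that will be np.load()-ed."""
    if history_prompt.endswith(".npz"):
        # literal path, e.g. a file produced by save_as_prompt() in bark/api.py
        return history_prompt
    # otherwise treat it as the name of a bundled preset under bark/assets/prompts/
    return os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
```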
@@ -377,8 +411,14 @@ def generate_text_semantic(
    pbar = tqdm.tqdm(disable=silent, total=100)
    pbar_state = 0
    tot_generated_duration_s = 0
    kv_cache = None
    for n in range(n_tot_steps):
        logits = model(x, merge_context=True)
        if use_kv_caching and kv_cache is not None:
            x_input = x[:, [-1]]
        else:
            x_input = x

        logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
        relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
        if allow_early_stop:
            relevant_logits = torch.hstack(
@@ -455,6 +495,7 @@ def generate_coarse(
    max_coarse_history=630,  # min 60 (faster), max 630 (more context)
    sliding_window_len=60,
    model=None,
    use_kv_caching=False
):
    """Generate coarse audio codes from semantic tokens."""
    assert (
@@ -469,9 +510,12 @@ def generate_coarse(
    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
    if history_prompt is not None:
        x_history = np.load(
            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
        )
        if history_prompt.endswith(".npz"):
            x_history = np.load(history_prompt)
        else:
            x_history = np.load(
                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
            )
        x_semantic_history = x_history["semantic_prompt"]
        x_coarse_history = x_history["coarse_prompt"]
        assert (
@@ -545,11 +589,18 @@ def generate_coarse(
                x_coarse_in[:, -max_coarse_history:],
            ]
        )
        kv_cache = None
        for _ in range(sliding_window_len):
            if n_step >= n_steps:
                continue
            is_major_step = n_step % N_COARSE_CODEBOOKS == 0
            logits = model(x_in)

            if use_kv_caching and kv_cache is not None:
                x_input = x_in[:, [-1]]
            else:
                x_input = x_in

            logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
            logit_start_idx = (
                SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
            )
@@ -611,9 +662,12 @@ def generate_fine(
        and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
    )
    if history_prompt is not None:
        x_fine_history = np.load(
            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
        )["fine_prompt"]
        if history_prompt.endswith(".npz"):
            x_fine_history = np.load(history_prompt)["fine_prompt"]
        else:
            x_fine_history = np.load(
                os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
            )["fine_prompt"]
        assert (
            isinstance(x_fine_history, np.ndarray)
            and len(x_fine_history.shape) == 2
bark/model.py
@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
    def forward(self, x, past_kv=None, use_cache=False):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
@@ -52,14 +52,36 @@ class CausalSelfAttention(nn.Module):
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        if past_kv is not None:
            past_key = past_kv[0]
            past_value = past_kv[1]
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        FULL_T = k.shape[-2]

        if use_cache is True:
            present = (k, v)
        else:
            present = None

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
            if past_kv is not None:
                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
                # to work around this we set is_causal=False.
                is_causal = False
            else:
                is_causal = True

            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
@@ -67,7 +89,7 @@ class CausalSelfAttention(nn.Module):

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
        return (y, present)

class MLP(nn.Module):
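Reviewer note (not part of the diff): `past_kv` is a per-layer `(key, value)` pair laid out as `(B, n_head, T, head_dim)`, and each incremental decoding step appends one position along the time axis via `torch.cat(..., dim=-2)`. A toy, self-contained illustration of that growth (shapes chosen arbitrarily):

```python
import torch

B, n_head, head_dim = 1, 2, 4
k_cache = torch.zeros(B, n_head, 0, head_dim)  # empty cache before the first step
v_cache = torch.zeros(B, n_head, 0, head_dim)

for step in range(3):
    # keys/values computed for the single new token at this step
    k_new = torch.randn(B, n_head, 1, head_dim)
    v_new = torch.randn(B, n_head, 1, head_dim)
    k_cache = torch.cat((k_cache, k_new), dim=-2)
    v_cache = torch.cat((v_cache, v_new), dim=-2)

print(k_cache.shape)  # torch.Size([1, 2, 3, 4]): the time axis grew by one per step
```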
@@ -95,10 +117,11 @@ class Block(nn.Module):
        self.mlp = MLP(config)
        self.layer_idx = layer_idx

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
    def forward(self, x, past_kv=None, use_cache=False):
        attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
        x = x + attn_output
        x = x + self.mlp(self.ln_2(x))
        return x
        return (x, prev_kvs)

@dataclass
class GPTConfig:
@@ -142,33 +165,55 @@ class GPT(nn.Module):
        n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, idx, merge_context=False):
    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
        device = idx.device
        b, t = idx.size()
        if merge_context:
            assert(idx.shape[1] >= 256+256+1)
            t = idx.shape[1] - 256
        else:
            assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # forward the GPT model itself
        if merge_context:
            tok_emb = torch.cat([
                self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
                self.transformer.wte(idx[:,256+256:])
            ], dim=1)
        else:
        if past_kv is not None:
            assert t == 1
            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        else:
            if merge_context:
                assert(idx.shape[1] >= 256+256+1)
                t = idx.shape[1] - 256
            else:
                assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the GPT model itself
            if merge_context:
                tok_emb = torch.cat([
                    self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
                    self.transformer.wte(idx[:,256+256:])
                ], dim=1)
            else:
                tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

        if past_kv is None:
            past_length = 0
            past_kv = tuple([None] * len(self.transformer.h))
        else:
            past_length = past_kv[0][0].size(-2)

        if position_ids is None:
            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0) # shape (1, t)
        assert position_ids.shape == (1, t)

        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)

        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)

        new_kv = () if use_cache else None

        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)

            if use_cache:
                new_kv = new_kv + (kv,)

        x = self.transformer.ln_f(x)

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

        return logits
        return (logits, new_kv)
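Reviewer note (not part of the diff): with `use_cache=True` the model now returns `(logits, new_kv)`, and callers feed `new_kv` back as `past_kv` while passing only the newest token. A hedged sketch of that loop against a mock model; the mock only mimics the return contract, and `greedy_decode` / `_MockModel` are illustrative names, not code from this repository:

```python
import torch


class _MockModel:
    """Stand-in with the (logits, new_kv) contract added in this diff; 5-token vocab, one layer."""

    def __call__(self, idx, past_kv=None, use_cache=False):
        b, t = idx.shape
        past_len = 0 if past_kv is None else past_kv[0][0].size(-2)
        logits = torch.randn(b, 1, 5)  # logits for the last position only
        new_kv = ((torch.zeros(b, 1, past_len + t, 1), torch.zeros(b, 1, past_len + t, 1)),)
        return logits, new_kv


def greedy_decode(model, prompt_ids, n_new_tokens):
    # First pass: run the whole prompt once and keep the per-layer (k, v) cache.
    logits, kv_cache = model(prompt_ids, use_cache=True)
    ids = prompt_ids
    for _ in range(n_new_tokens):
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)
        ids = torch.cat((ids, next_id), dim=1)
        # Later passes: feed only the newest token plus the cache, as in GPT.forward above.
        logits, kv_cache = model(next_id, past_kv=kv_cache, use_cache=True)
    return ids


print(greedy_decode(_MockModel(), torch.zeros(1, 4, dtype=torch.long), 3).shape)  # torch.Size([1, 7])
```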
model-card.md
@@ -8,7 +8,7 @@ The following is additional information about the models released here.

Bark is a series of three transformer models that turn text into audio.

### Text to semantic tokens
- Input: text, tokenized with [BERT tokenizer from huggingface](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
- Input: text, tokenized with [BERT tokenizer from Hugging Face](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
- Output: semantic tokens that encode the audio to be generated

### Semantic to coarse tokens