mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2026-04-03 09:46:24 +02:00
Merge remote-tracking branch 'upstream/main'
This commit is contained in:
@@ -1,2 +1,2 @@
|
||||
from .api import generate_audio, text_to_semantic, semantic_to_waveform
|
||||
from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
|
||||
from .generation import SAMPLE_RATE, preload_models
|
||||
|
||||
58
bark/api.py
58
bark/api.py
@@ -9,6 +9,7 @@ def text_to_semantic(
|
||||
text: str,
|
||||
history_prompt: Optional[str] = None,
|
||||
temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
):
|
||||
"""Generate semantic array from text.
|
||||
|
||||
@@ -16,6 +17,7 @@ def text_to_semantic(
|
||||
text: text to be turned into audio
|
||||
history_prompt: history choice for audio cloning
|
||||
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
silent: disable progress bar
|
||||
|
||||
Returns:
|
||||
numpy semantic array to be fed into `semantic_to_waveform`
|
||||
@@ -24,6 +26,8 @@ def text_to_semantic(
|
||||
text,
|
||||
history_prompt=history_prompt,
|
||||
temp=temp,
|
||||
silent=silent,
|
||||
use_kv_caching=True
|
||||
)
|
||||
return x_semantic
|
||||
|
||||
@@ -32,6 +36,8 @@ def semantic_to_waveform(
|
||||
semantic_tokens: np.ndarray,
|
||||
history_prompt: Optional[str] = None,
|
||||
temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
output_full: bool = False,
|
||||
):
|
||||
"""Generate audio array from semantic input.
|
||||
|
||||
@@ -39,29 +45,51 @@ def semantic_to_waveform(
|
||||
semantic_tokens: semantic token output from `text_to_semantic`
|
||||
history_prompt: history choice for audio cloning
|
||||
temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
silent: disable progress bar
|
||||
output_full: return full generation to be used as a history prompt
|
||||
|
||||
Returns:
|
||||
numpy audio array at sample frequency 24khz
|
||||
"""
|
||||
x_coarse_gen = generate_coarse(
|
||||
coarse_tokens = generate_coarse(
|
||||
semantic_tokens,
|
||||
history_prompt=history_prompt,
|
||||
temp=temp,
|
||||
silent=silent,
|
||||
use_kv_caching=True
|
||||
)
|
||||
x_fine_gen = generate_fine(
|
||||
x_coarse_gen,
|
||||
fine_tokens = generate_fine(
|
||||
coarse_tokens,
|
||||
history_prompt=history_prompt,
|
||||
temp=0.5,
|
||||
)
|
||||
audio_arr = codec_decode(x_fine_gen)
|
||||
audio_arr = codec_decode(fine_tokens)
|
||||
if output_full:
|
||||
full_generation = {
|
||||
"semantic_prompt": semantic_tokens,
|
||||
"coarse_prompt": coarse_tokens,
|
||||
"fine_prompt": fine_tokens,
|
||||
}
|
||||
return full_generation, audio_arr
|
||||
return audio_arr
|
||||
|
||||
|
||||
def save_as_prompt(filepath, full_generation):
|
||||
assert(filepath.endswith(".npz"))
|
||||
assert(isinstance(full_generation, dict))
|
||||
assert("semantic_prompt" in full_generation)
|
||||
assert("coarse_prompt" in full_generation)
|
||||
assert("fine_prompt" in full_generation)
|
||||
np.savez(filepath, **full_generation)
|
||||
|
||||
|
||||
def generate_audio(
|
||||
text: str,
|
||||
history_prompt: Optional[str] = None,
|
||||
text_temp: float = 0.7,
|
||||
waveform_temp: float = 0.7,
|
||||
silent: bool = False,
|
||||
output_full: bool = False,
|
||||
):
|
||||
"""Generate audio array from input text.
|
||||
|
||||
@@ -70,10 +98,28 @@ def generate_audio(
|
||||
history_prompt: history choice for audio cloning
|
||||
text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
|
||||
silent: disable progress bar
|
||||
output_full: return full generation to be used as a history prompt
|
||||
|
||||
Returns:
|
||||
numpy audio array at sample frequency 24khz
|
||||
"""
|
||||
x_semantic = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp)
|
||||
audio_arr = semantic_to_waveform(x_semantic, history_prompt=history_prompt, temp=waveform_temp)
|
||||
semantic_tokens = text_to_semantic(
|
||||
text,
|
||||
history_prompt=history_prompt,
|
||||
temp=text_temp,
|
||||
silent=silent,
|
||||
)
|
||||
out = semantic_to_waveform(
|
||||
semantic_tokens,
|
||||
history_prompt=history_prompt,
|
||||
temp=waveform_temp,
|
||||
silent=silent,
|
||||
output_full=output_full,
|
||||
)
|
||||
if output_full:
|
||||
full_generation, audio_arr = out
|
||||
return full_generation, audio_arr
|
||||
else:
|
||||
audio_arr = out
|
||||
return audio_arr
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import contextlib
|
||||
import gc
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
@@ -21,6 +22,7 @@ if (
|
||||
torch.cuda.is_available() and
|
||||
hasattr(torch.cuda, "amp") and
|
||||
hasattr(torch.cuda.amp, "autocast") and
|
||||
hasattr(torch.cuda, "is_bf16_supported") and
|
||||
torch.cuda.is_bf16_supported()
|
||||
):
|
||||
autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
|
||||
@@ -58,20 +60,33 @@ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
|
||||
|
||||
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
||||
|
||||
REMOTE_BASE_URL = "https://dl.suno-models.io/bark/models/v0/"
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text_small": {
|
||||
"path": os.path.join(REMOTE_BASE_URL, "text.pt"),
|
||||
"checksum": "b3e42bcbab23b688355cd44128c4cdd3",
|
||||
},
|
||||
"coarse_small": {
|
||||
"path": os.path.join(REMOTE_BASE_URL, "coarse.pt"),
|
||||
"checksum": "5fe964825e3b0321f9d5f3857b89194d",
|
||||
},
|
||||
"fine_small": {
|
||||
"path": os.path.join(REMOTE_BASE_URL, "fine.pt"),
|
||||
"checksum": "5428d1befe05be2ba32195496e58dc90",
|
||||
},
|
||||
"text": {
|
||||
"path": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text_2.pt")),
|
||||
"path": os.path.join(REMOTE_BASE_URL, "text_2.pt"),
|
||||
"checksum": "54afa89d65e318d4f5f80e8e8799026a",
|
||||
},
|
||||
"coarse": {
|
||||
"path": os.environ.get(
|
||||
"SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse_2.pt")
|
||||
),
|
||||
"path": os.path.join(REMOTE_BASE_URL, "coarse_2.pt"),
|
||||
"checksum": "8a98094e5e3a255a5c9c0ab7efe8fd28",
|
||||
},
|
||||
"fine": {
|
||||
"path": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine_2.pt")),
|
||||
"path": os.path.join(REMOTE_BASE_URL, "fine_2.pt"),
|
||||
"checksum": "59d184ed44e3650774a2f0503a48a97b",
|
||||
},
|
||||
}
|
||||
@@ -98,8 +113,9 @@ def _md5(fname):
|
||||
return hash_md5.hexdigest()
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type):
|
||||
model_name = _string_md5(REMOTE_MODEL_PATHS[model_type]["path"])
|
||||
def _get_ckpt_path(model_type, use_small=False):
|
||||
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
|
||||
model_name = _string_md5(REMOTE_MODEL_PATHS[model_key]["path"])
|
||||
return os.path.join(CACHE_DIR, f"{model_name}.pt")
|
||||
|
||||
|
||||
@@ -115,9 +131,9 @@ def _parse_s3_filepath(s3_filepath):
|
||||
def _download(from_s3_path, to_local_path):
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
response = requests.get(from_s3_path, stream=True)
|
||||
total_size_in_bytes = int(response.headers.get('content-length', 0))
|
||||
block_size = 1024 # 1 Kibibyte
|
||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
|
||||
total_size_in_bytes = int(response.headers.get("content-length", 0))
|
||||
block_size = 1024
|
||||
progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
|
||||
with open(to_local_path, "wb") as file:
|
||||
for data in response.iter_content(block_size):
|
||||
progress_bar.update(len(data))
|
||||
@@ -165,11 +181,12 @@ def clean_models(model_key=None):
|
||||
if k in models:
|
||||
del models[k]
|
||||
_clear_cuda_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, model_type="text"):
|
||||
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
|
||||
if "cuda" not in device:
|
||||
logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
|
||||
logger.warning("No GPU being used. Careful, inference might be extremely slow!")
|
||||
if model_type == "text":
|
||||
ConfigClass = GPTConfig
|
||||
ModelClass = GPT
|
||||
@@ -181,15 +198,17 @@ def _load_model(ckpt_path, device, model_type="text"):
|
||||
ModelClass = FineGPT
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
|
||||
model_info = REMOTE_MODEL_PATHS[model_key]
|
||||
if (
|
||||
os.path.exists(ckpt_path) and
|
||||
_md5(ckpt_path) != REMOTE_MODEL_PATHS[model_type]["checksum"]
|
||||
_md5(ckpt_path) != model_info["checksum"]
|
||||
):
|
||||
logger.warning(f"found outdated {model_type} model, removing...")
|
||||
logger.warning(f"found outdated {model_type} model, removing.")
|
||||
os.remove(ckpt_path)
|
||||
if not os.path.exists(ckpt_path):
|
||||
logger.info(f"{model_type} model not found, downloading...")
|
||||
_download(REMOTE_MODEL_PATHS[model_type]["path"], ckpt_path)
|
||||
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
|
||||
_download(model_info["path"], ckpt_path)
|
||||
checkpoint = torch.load(ckpt_path, map_location=device)
|
||||
# this is a hack
|
||||
model_args = checkpoint["model_args"]
|
||||
@@ -239,8 +258,8 @@ def _load_codec_model(device):
|
||||
return model
|
||||
|
||||
|
||||
def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
|
||||
_load_model_f = funcy.partial(_load_model, model_type=model_type)
|
||||
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
|
||||
_load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
|
||||
if model_type not in ("text", "coarse", "fine"):
|
||||
raise NotImplementedError()
|
||||
global models
|
||||
@@ -250,8 +269,7 @@ def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="tex
|
||||
device = "cuda"
|
||||
model_key = str(device) + f"__{model_type}"
|
||||
if model_key not in models or force_reload:
|
||||
if ckpt_path is None:
|
||||
ckpt_path = _get_ckpt_path(model_type)
|
||||
ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
|
||||
clean_models(model_key=model_key)
|
||||
model = _load_model_f(ckpt_path, device)
|
||||
models[model_key] = model
|
||||
@@ -272,17 +290,29 @@ def load_codec_model(use_gpu=True, force_reload=False):
|
||||
return models[model_key]
|
||||
|
||||
|
||||
def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
|
||||
def preload_models(
|
||||
text_use_gpu=True,
|
||||
text_use_small=False,
|
||||
coarse_use_gpu=True,
|
||||
coarse_use_small=False,
|
||||
fine_use_gpu=True,
|
||||
fine_use_small=False,
|
||||
codec_use_gpu=True,
|
||||
force_reload=False,
|
||||
):
|
||||
_ = load_model(
|
||||
ckpt_path=text_ckpt_path, model_type="text", use_gpu=use_gpu, force_reload=True
|
||||
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
|
||||
)
|
||||
_ = load_model(
|
||||
ckpt_path=coarse_ckpt_path, model_type="coarse", use_gpu=use_gpu, force_reload=True
|
||||
model_type="coarse",
|
||||
use_gpu=coarse_use_gpu,
|
||||
use_small=coarse_use_small,
|
||||
force_reload=force_reload,
|
||||
)
|
||||
_ = load_model(
|
||||
ckpt_path=fine_ckpt_path, model_type="fine", use_gpu=use_gpu, force_reload=True
|
||||
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
|
||||
)
|
||||
_ = load_codec_model(use_gpu=use_gpu, force_reload=True)
|
||||
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
|
||||
|
||||
|
||||
####
|
||||
@@ -320,15 +350,19 @@ def generate_text_semantic(
|
||||
max_gen_duration_s=None,
|
||||
allow_early_stop=True,
|
||||
model=None,
|
||||
use_kv_caching=False
|
||||
):
|
||||
"""Generate semantic tokens from text."""
|
||||
assert isinstance(text, str)
|
||||
text = _normalize_whitespace(text)
|
||||
assert len(text.strip()) > 0
|
||||
if history_prompt is not None:
|
||||
semantic_history = np.load(
|
||||
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
|
||||
)["semantic_prompt"]
|
||||
if history_prompt.endswith(".npz"):
|
||||
semantic_history = np.load(history_prompt)["semantic_prompt"]
|
||||
else:
|
||||
semantic_history = np.load(
|
||||
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
|
||||
)["semantic_prompt"]
|
||||
assert (
|
||||
isinstance(semantic_history, np.ndarray)
|
||||
and len(semantic_history.shape) == 1
|
||||
@@ -377,8 +411,14 @@ def generate_text_semantic(
|
||||
pbar = tqdm.tqdm(disable=silent, total=100)
|
||||
pbar_state = 0
|
||||
tot_generated_duration_s = 0
|
||||
kv_cache = None
|
||||
for n in range(n_tot_steps):
|
||||
logits = model(x, merge_context=True)
|
||||
if use_kv_caching and kv_cache is not None:
|
||||
x_input = x[:, [-1]]
|
||||
else:
|
||||
x_input = x
|
||||
|
||||
logits, kv_cache = model(x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache)
|
||||
relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
|
||||
if allow_early_stop:
|
||||
relevant_logits = torch.hstack(
|
||||
@@ -455,6 +495,7 @@ def generate_coarse(
|
||||
max_coarse_history=630, # min 60 (faster), max 630 (more context)
|
||||
sliding_window_len=60,
|
||||
model=None,
|
||||
use_kv_caching=False
|
||||
):
|
||||
"""Generate coarse audio codes from semantic tokens."""
|
||||
assert (
|
||||
@@ -469,9 +510,12 @@ def generate_coarse(
|
||||
semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
|
||||
max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
|
||||
if history_prompt is not None:
|
||||
x_history = np.load(
|
||||
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
|
||||
)
|
||||
if history_prompt.endswith(".npz"):
|
||||
x_history = np.load(history_prompt)
|
||||
else:
|
||||
x_history = np.load(
|
||||
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
|
||||
)
|
||||
x_semantic_history = x_history["semantic_prompt"]
|
||||
x_coarse_history = x_history["coarse_prompt"]
|
||||
assert (
|
||||
@@ -545,11 +589,18 @@ def generate_coarse(
|
||||
x_coarse_in[:, -max_coarse_history:],
|
||||
]
|
||||
)
|
||||
kv_cache = None
|
||||
for _ in range(sliding_window_len):
|
||||
if n_step >= n_steps:
|
||||
continue
|
||||
is_major_step = n_step % N_COARSE_CODEBOOKS == 0
|
||||
logits = model(x_in)
|
||||
|
||||
if use_kv_caching and kv_cache is not None:
|
||||
x_input = x_in[:, [-1]]
|
||||
else:
|
||||
x_input = x_in
|
||||
|
||||
logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
|
||||
logit_start_idx = (
|
||||
SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
|
||||
)
|
||||
@@ -611,9 +662,12 @@ def generate_fine(
|
||||
and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
|
||||
)
|
||||
if history_prompt is not None:
|
||||
x_fine_history = np.load(
|
||||
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
|
||||
)["fine_prompt"]
|
||||
if history_prompt.endswith(".npz"):
|
||||
x_fine_history = np.load(history_prompt)["fine_prompt"]
|
||||
else:
|
||||
x_fine_history = np.load(
|
||||
os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
|
||||
)["fine_prompt"]
|
||||
assert (
|
||||
isinstance(x_fine_history, np.ndarray)
|
||||
and len(x_fine_history.shape) == 2
|
||||
|
||||
@@ -43,7 +43,7 @@ class CausalSelfAttention(nn.Module):
|
||||
self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
|
||||
.view(1, 1, config.block_size, config.block_size))
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, past_kv=None, use_cache=False):
|
||||
B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
|
||||
|
||||
# calculate query, key, values for all heads in batch and move head forward to be the batch dim
|
||||
@@ -52,14 +52,36 @@ class CausalSelfAttention(nn.Module):
|
||||
q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
|
||||
|
||||
if past_kv is not None:
|
||||
past_key = past_kv[0]
|
||||
past_value = past_kv[1]
|
||||
k = torch.cat((past_key, k), dim=-2)
|
||||
v = torch.cat((past_value, v), dim=-2)
|
||||
|
||||
FULL_T = k.shape[-2]
|
||||
|
||||
if use_cache is True:
|
||||
present = (k, v)
|
||||
else:
|
||||
present = None
|
||||
|
||||
# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
|
||||
if self.flash:
|
||||
# efficient attention using Flash Attention CUDA kernels
|
||||
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
|
||||
if past_kv is not None:
|
||||
# When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
|
||||
# the query for the last token. scaled_dot_product_attention interprets this as the first token in the
|
||||
# sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
|
||||
# to work around this we set is_causal=False.
|
||||
is_causal = False
|
||||
else:
|
||||
is_causal = True
|
||||
|
||||
y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
|
||||
else:
|
||||
# manual implementation of attention
|
||||
att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
|
||||
att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
|
||||
att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
|
||||
att = F.softmax(att, dim=-1)
|
||||
att = self.attn_dropout(att)
|
||||
y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
|
||||
@@ -67,7 +89,7 @@ class CausalSelfAttention(nn.Module):
|
||||
|
||||
# output projection
|
||||
y = self.resid_dropout(self.c_proj(y))
|
||||
return y
|
||||
return (y, present)
|
||||
|
||||
class MLP(nn.Module):
|
||||
|
||||
@@ -95,10 +117,11 @@ class Block(nn.Module):
|
||||
self.mlp = MLP(config)
|
||||
self.layer_idx = layer_idx
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.attn(self.ln_1(x))
|
||||
def forward(self, x, past_kv=None, use_cache=False):
|
||||
attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
|
||||
x = x + attn_output
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
return (x, prev_kvs)
|
||||
|
||||
@dataclass
|
||||
class GPTConfig:
|
||||
@@ -142,33 +165,55 @@ class GPT(nn.Module):
|
||||
n_params -= self.transformer.wpe.weight.numel()
|
||||
return n_params
|
||||
|
||||
def forward(self, idx, merge_context=False):
|
||||
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
|
||||
device = idx.device
|
||||
b, t = idx.size()
|
||||
if merge_context:
|
||||
assert(idx.shape[1] >= 256+256+1)
|
||||
t = idx.shape[1] - 256
|
||||
else:
|
||||
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
||||
|
||||
# forward the GPT model itself
|
||||
if merge_context:
|
||||
tok_emb = torch.cat([
|
||||
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
|
||||
self.transformer.wte(idx[:,256+256:])
|
||||
], dim=1)
|
||||
else:
|
||||
if past_kv is not None:
|
||||
assert t == 1
|
||||
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
||||
else:
|
||||
if merge_context:
|
||||
assert(idx.shape[1] >= 256+256+1)
|
||||
t = idx.shape[1] - 256
|
||||
else:
|
||||
assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
|
||||
|
||||
# forward the GPT model itself
|
||||
if merge_context:
|
||||
tok_emb = torch.cat([
|
||||
self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
|
||||
self.transformer.wte(idx[:,256+256:])
|
||||
], dim=1)
|
||||
else:
|
||||
tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
|
||||
|
||||
if past_kv is None:
|
||||
past_length = 0
|
||||
past_kv = tuple([None] * len(self.transformer.h))
|
||||
else:
|
||||
past_length = past_kv[0][0].size(-2)
|
||||
|
||||
if position_ids is None:
|
||||
position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
|
||||
position_ids = position_ids.unsqueeze(0) # shape (1, t)
|
||||
assert position_ids.shape == (1, t)
|
||||
|
||||
pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
|
||||
|
||||
pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
|
||||
pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
|
||||
|
||||
x = self.transformer.drop(tok_emb + pos_emb)
|
||||
for block in self.transformer.h:
|
||||
x = block(x)
|
||||
|
||||
new_kv = () if use_cache else None
|
||||
|
||||
for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
|
||||
x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
|
||||
|
||||
if use_cache:
|
||||
new_kv = new_kv + (kv,)
|
||||
|
||||
x = self.transformer.ln_f(x)
|
||||
|
||||
# inference-time mini-optimization: only forward the lm_head on the very last position
|
||||
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
||||
|
||||
return logits
|
||||
return (logits, new_kv)
|
||||
|
||||
Reference in New Issue
Block a user