mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2026-04-03 09:46:24 +02:00
first commit
This commit is contained in:
2
bark/__init__.py
Normal file
2
bark/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
from .api import generate_audio, text_to_semantic, semantic_to_waveform
|
||||
from .generation import SAMPLE_RATE, preload_models
|
||||
BIN
bark/__pycache__/__init__.cpython-38.pyc
Normal file
BIN
bark/__pycache__/__init__.cpython-38.pyc
Normal file
Binary file not shown.
BIN
bark/__pycache__/api.cpython-38.pyc
Normal file
BIN
bark/__pycache__/api.cpython-38.pyc
Normal file
Binary file not shown.
BIN
bark/__pycache__/generation.cpython-38.pyc
Normal file
BIN
bark/__pycache__/generation.cpython-38.pyc
Normal file
Binary file not shown.
BIN
bark/__pycache__/model.cpython-38.pyc
Normal file
BIN
bark/__pycache__/model.cpython-38.pyc
Normal file
Binary file not shown.
BIN
bark/__pycache__/model_fine.cpython-38.pyc
Normal file
BIN
bark/__pycache__/model_fine.cpython-38.pyc
Normal file
Binary file not shown.
79
bark/api.py
Normal file
79
bark/api.py
Normal file
@@ -0,0 +1,79 @@
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
|
||||
|
||||
|
||||
def text_to_semantic(
    text: str,
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
):
    """Generate semantic array from text.

    Args:
        text: text to be turned into audio
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)

    Returns:
        numpy semantic array to be fed into `semantic_to_waveform`
    """
    return generate_text_semantic(
        text,
        history_prompt=history_prompt,
        temp=temp,
    )
|
||||
|
||||
|
||||
def semantic_to_waveform(
    semantic_tokens: np.ndarray,
    history_prompt: Optional[str] = None,
    temp: float = 0.7,
):
    """Generate audio array from semantic input.

    Args:
        semantic_tokens: semantic token output from `text_to_semantic`
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)

    Returns:
        numpy audio array at sample frequency 24khz
    """
    coarse_tokens = generate_coarse(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=temp,
    )
    # NOTE(review): the fine stage runs at a fixed temperature of 0.5; the
    # `temp` argument only influences the coarse stage.
    fine_tokens = generate_fine(
        coarse_tokens,
        history_prompt=history_prompt,
        temp=0.5,
    )
    return codec_decode(fine_tokens)
|
||||
|
||||
|
||||
def generate_audio(
    text: str,
    history_prompt: Optional[str] = None,
    text_temp: float = 0.7,
    waveform_temp: float = 0.7,
):
    """Generate audio array from input text.

    Args:
        text: text to be turned into audio
        history_prompt: history choice for audio cloning
        text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)

    Returns:
        numpy audio array at sample frequency 24khz
    """
    # Two-stage pipeline: text -> semantic tokens -> waveform.
    semantic_tokens = text_to_semantic(text, history_prompt=history_prompt, temp=text_temp)
    return semantic_to_waveform(
        semantic_tokens, history_prompt=history_prompt, temp=waveform_temp
    )
|
||||
BIN
bark/assets/prompts/brylcream.npz
Normal file
BIN
bark/assets/prompts/brylcream.npz
Normal file
Binary file not shown.
BIN
bark/assets/prompts/es-woman.npz
Normal file
BIN
bark/assets/prompts/es-woman.npz
Normal file
Binary file not shown.
BIN
bark/assets/prompts/man-narrator.npz
Normal file
BIN
bark/assets/prompts/man-narrator.npz
Normal file
Binary file not shown.
693
bark/generation.py
Normal file
693
bark/generation.py
Normal file
@@ -0,0 +1,693 @@
|
||||
import contextlib
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
import requests
|
||||
import sys
|
||||
|
||||
from encodec import EncodecModel
|
||||
import funcy
|
||||
import logging
|
||||
import numpy as np
|
||||
from scipy.special import softmax
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import tqdm
|
||||
from transformers import BertTokenizer
|
||||
|
||||
from .model import GPTConfig, GPT
|
||||
from .model_fine import FineGPT, FineGPTConfig
|
||||
|
||||
# Pick an autocast implementation once at import time: bfloat16 mixed
# precision on CUDA when the hardware/build supports it, otherwise a no-op
# context manager so call sites can use `autocast()` unconditionally.
if (
    torch.cuda.is_available() and
    hasattr(torch.cuda, "amp") and
    hasattr(torch.cuda.amp, "autocast") and
    torch.cuda.is_bf16_supported()
):
    autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
else:
    @contextlib.contextmanager
    def autocast():
        # No-op stand-in with the same call signature as the CUDA version.
        yield
|
||||
|
||||
|
||||
# hold models in global scope to lazy load
|
||||
global models
|
||||
models = {}
|
||||
|
||||
|
||||
CONTEXT_WINDOW_SIZE = 1024
|
||||
|
||||
SEMANTIC_RATE_HZ = 49.9
|
||||
SEMANTIC_VOCAB_SIZE = 10_000
|
||||
|
||||
CODEBOOK_SIZE = 1024
|
||||
N_COARSE_CODEBOOKS = 2
|
||||
N_FINE_CODEBOOKS = 8
|
||||
COARSE_RATE_HZ = 75
|
||||
|
||||
SAMPLE_RATE = 24_000
|
||||
|
||||
|
||||
ALLOWED_PROMPTS = (
|
||||
"brylcream",
|
||||
"es-woman",
|
||||
"man-narrator",
|
||||
)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
os.makedirs(CACHE_DIR, exist_ok=True)
|
||||
|
||||
|
||||
REMOTE_BASE_URL = "http://s3.amazonaws.com/suno-public/bark/models/v0/"
|
||||
REMOTE_MODEL_PATHS = {
|
||||
"text": os.environ.get("SUNO_TEXT_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "text.pt")),
|
||||
"coarse": os.environ.get("SUNO_COARSE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "coarse.pt")),
|
||||
"fine": os.environ.get("SUNO_FINE_MODEL_PATH", os.path.join(REMOTE_BASE_URL, "fine.pt")),
|
||||
}
|
||||
|
||||
|
||||
def _compute_md5(s):
|
||||
m = hashlib.md5()
|
||||
m.update(s.encode("utf-8"))
|
||||
return m.hexdigest()
|
||||
|
||||
|
||||
def _get_ckpt_path(model_type):
    """Local cache path for a model checkpoint, keyed by the md5 of its remote URL."""
    remote_url = REMOTE_MODEL_PATHS[model_type]
    return os.path.join(CACHE_DIR, _compute_md5(remote_url) + ".pt")
|
||||
|
||||
|
||||
S3_BUCKET_PATH_RE = r"s3\:\/\/(.+?)\/"
|
||||
|
||||
|
||||
def _parse_s3_filepath(s3_filepath):
    """Split an ``s3://bucket/key`` path into ``(bucket_name, key)``."""
    match = re.search(S3_BUCKET_PATH_RE, s3_filepath)
    bucket_name = match.group(1)
    rel_s3_filepath = re.sub(S3_BUCKET_PATH_RE, "", s3_filepath)
    return bucket_name, rel_s3_filepath
|
||||
|
||||
|
||||
def _download(from_s3_path, to_local_path):
    """Stream a remote file to `to_local_path`, showing a progress bar.

    The file is downloaded to a temporary sibling path and renamed into place
    only on success, so an interrupted download never leaves a truncated file
    that a later `os.path.exists` check (see `_load_model`) would mistake for
    a valid checkpoint.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
        ValueError: if the received byte count does not match the
            server-reported content length.
    """
    response = requests.get(from_s3_path, stream=True)
    # Fail fast on an error response instead of saving the error page as the model.
    response.raise_for_status()
    total_size_in_bytes = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 Kibibyte
    tmp_path = to_local_path + ".part"
    progress_bar = tqdm.tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True)
    try:
        with open(tmp_path, "wb") as file:
            for data in response.iter_content(block_size):
                progress_bar.update(len(data))
                file.write(data)
    finally:
        progress_bar.close()
    if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
        os.remove(tmp_path)
        raise ValueError("ERROR, something went wrong")
    os.replace(tmp_path, to_local_path)  # atomic on POSIX and Windows
|
||||
|
||||
|
||||
class InferenceContext:
    """Temporarily set `torch.backends.cudnn.benchmark` for the duration of a block.

    Benchmarking is disabled by default because input lengths vary between
    calls, which makes cudnn's kernel auto-tuning counterproductive.
    """

    def __init__(self, benchmark=False):
        # we can't expect inputs to be the same length, so disable benchmarking by default
        self._chosen_cudnn_benchmark = benchmark
        self._cudnn_benchmark = None  # previous setting, captured on enter

    def __enter__(self):
        self._cudnn_benchmark = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Restore whatever the caller had, even if the block raised.
        torch.backends.cudnn.benchmark = self._cudnn_benchmark
|
||||
|
||||
|
||||
if torch.cuda.is_available():
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
|
||||
|
||||
@contextlib.contextmanager
def _inference_mode():
    """Combine every inference-time context manager into a single one."""
    with InferenceContext():
        with torch.inference_mode(), torch.no_grad(), autocast():
            yield
|
||||
|
||||
|
||||
def _clear_cuda_cache():
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize()
|
||||
|
||||
|
||||
def clean_models(model_key=None):
    """Evict cached model(s) from the global `models` dict and free CUDA memory.

    Args:
        model_key: specific cache key to evict; if None, evict everything.
    """
    global models
    # Materialize the key list: iterating `models.keys()` (a live view) while
    # deleting entries raises "dictionary changed size during iteration".
    model_keys = [model_key] if model_key is not None else list(models.keys())
    for k in model_keys:
        if k in models:
            del models[k]
    _clear_cuda_cache()
|
||||
|
||||
|
||||
def _load_model(ckpt_path, device, model_type="text"):
    """Load a bark GPT checkpoint from `ckpt_path` onto `device`.

    Downloads the checkpoint first if it is not on disk. For the "text" model
    a dict with both the model and its BERT tokenizer is returned; for
    "coarse"/"fine" the bare model is returned.
    """
    if "cuda" not in device:
        logger.warning("No GPU being used. Careful, Inference might be extremely slow!")
    # Pick config/model classes per stage; text and coarse share the GPT arch.
    if model_type == "text":
        ConfigClass = GPTConfig
        ModelClass = GPT
    elif model_type == "coarse":
        ConfigClass = GPTConfig
        ModelClass = GPT
    elif model_type == "fine":
        ConfigClass = FineGPTConfig
        ModelClass = FineGPT
    else:
        raise NotImplementedError()
    if not os.path.exists(ckpt_path):
        print(f"{model_type} model not found, downloading...")
        _download(REMOTE_MODEL_PATHS[model_type], ckpt_path)
    checkpoint = torch.load(ckpt_path, map_location=device)
    # this is a hack: older checkpoints store a single "vocab_size"; newer
    # configs expect separate input/output vocab sizes. Mutates the
    # checkpoint dict in place before constructing the config.
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]
    gptconf = ConfigClass(**checkpoint["model_args"])
    model = ModelClass(gptconf)
    state_dict = checkpoint["model"]
    # fixup checkpoint: strip the torch.compile wrapper prefix from key names
    unwanted_prefix = "_orig_mod."
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
    # ".attn.bias" is the causal-mask buffer; it is rebuilt at init time, so
    # mismatches on it are expected and ignored.
    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
    extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
    if len(extra_keys) != 0:
        raise ValueError(f"extra keys found: {extra_keys}")
    if len(missing_keys) != 0:
        raise ValueError(f"missing keys: {missing_keys}")
    # strict=False tolerates only the ".attn.bias" buffers excluded above.
    model.load_state_dict(state_dict, strict=False)
    n_params = model.get_num_params()
    val_loss = checkpoint["best_val_loss"].item()
    print(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
    model.eval()
    model.to(device)
    # Drop references so the checkpoint tensors can be garbage collected.
    del checkpoint, state_dict
    _clear_cuda_cache()
    if model_type == "text":
        tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        return {
            "model": model,
            "tokenizer": tokenizer,
        }
    return model
|
||||
|
||||
|
||||
def _load_codec_model(device):
    """Instantiate the 24 kHz encodec codec on `device`, ready for inference."""
    model = EncodecModel.encodec_model_24khz()
    model.set_target_bandwidth(6.0)
    model.eval()
    model.to(device)
    _clear_cuda_cache()
    return model
|
||||
|
||||
|
||||
def load_model(ckpt_path=None, use_gpu=True, force_reload=False, model_type="text"):
    """Return a cached bark model, loading (and caching) it on first use.

    Args:
        ckpt_path: optional explicit checkpoint path; defaults to the cache path.
        use_gpu: prefer CUDA when a device is present.
        force_reload: evict any cached copy and reload from disk.
        model_type: one of "text", "coarse", "fine".
    """
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()
    global models
    device = "cpu" if torch.cuda.device_count() == 0 or not use_gpu else "cuda"
    model_key = f"{device}__{model_type}"
    if force_reload or model_key not in models:
        if ckpt_path is None:
            ckpt_path = _get_ckpt_path(model_type)
        clean_models(model_key=model_key)
        models[model_key] = _load_model(ckpt_path, device, model_type=model_type)
    return models[model_key]
|
||||
|
||||
|
||||
def load_codec_model(use_gpu=True, force_reload=False):
    """Return the cached encodec codec model, loading it on first use."""
    global models
    device = "cpu" if torch.cuda.device_count() == 0 or not use_gpu else "cuda"
    model_key = device + "__codec"
    if force_reload or model_key not in models:
        clean_models(model_key=model_key)
        models[model_key] = _load_codec_model(device)
    return models[model_key]
|
||||
|
||||
|
||||
def preload_models(text_ckpt_path=None, coarse_ckpt_path=None, fine_ckpt_path=None, use_gpu=True):
    """Eagerly load all three bark models plus the encodec codec."""
    for ckpt, model_type in (
        (text_ckpt_path, "text"),
        (coarse_ckpt_path, "coarse"),
        (fine_ckpt_path, "fine"),
    ):
        load_model(ckpt_path=ckpt, model_type=model_type, use_gpu=use_gpu, force_reload=True)
    load_codec_model(use_gpu=use_gpu, force_reload=True)
|
||||
|
||||
|
||||
####
|
||||
# Generation Functionality
|
||||
####
|
||||
|
||||
|
||||
def _tokenize(tokenizer, text):
|
||||
return tokenizer.encode(text, add_special_tokens=False)
|
||||
|
||||
|
||||
def _detokenize(tokenizer, enc_text):
|
||||
return tokenizer.decode(enc_text)
|
||||
|
||||
|
||||
def _normalize_whitespace(text):
|
||||
return re.sub(r"\s+", " ", text).strip()
|
||||
|
||||
|
||||
TEXT_ENCODING_OFFSET = 10_048
|
||||
SEMANTIC_PAD_TOKEN = 10_000
|
||||
TEXT_PAD_TOKEN = 129_595
|
||||
SEMANTIC_INFER_TOKEN = 129_599
|
||||
|
||||
|
||||
def generate_text_semantic(
    text,
    history_prompt=None,
    temp=0.7,
    top_k=None,
    top_p=None,
    use_gpu=True,
    silent=False,
    min_eos_p=0.2,
    max_gen_duration_s=None,
    allow_early_stop=True,
    model=None,
):
    """Generate semantic tokens from text.

    Args:
        text: input text; whitespace is normalized before encoding.
        history_prompt: named voice prompt from ALLOWED_PROMPTS, or None.
        temp: softmax temperature for sampling.
        top_k: if set, keep only the k highest-probability tokens.
        top_p: if set, nucleus-sampling probability threshold.
        use_gpu: run on CUDA when a device is present.
        silent: suppress the progress bar.
        min_eos_p: stop early once the eos probability reaches this value
            (only when allow_early_stop is True).
        max_gen_duration_s: cap on generated duration, counted in semantic
            frames at SEMANTIC_RATE_HZ.
        allow_early_stop: permit stopping at eos before the 768-step cap.
        model: optional pre-loaded text model. NOTE: the cached model
            container is loaded regardless, because the tokenizer is always
            taken from it.

    Returns:
        1-D numpy int array of semantic token ids in [0, SEMANTIC_VOCAB_SIZE).
    """
    assert isinstance(text, str)
    text = _normalize_whitespace(text)
    assert len(text.strip()) > 0
    if history_prompt is not None:
        assert (history_prompt in ALLOWED_PROMPTS)
        # Semantic history for voice cloning, shipped as an .npz next to the package.
        semantic_history = np.load(
            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
        )["text"]
        assert (
            isinstance(semantic_history, np.ndarray)
            and len(semantic_history.shape) == 1
            and len(semantic_history) > 0
            and semantic_history.min() >= 0
            and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
        )
    else:
        semantic_history = None
    model_container = load_model(use_gpu=use_gpu, model_type="text")
    if model is None:
        model = model_container["model"]
    tokenizer = model_container["tokenizer"]
    # Text tokens are shifted by TEXT_ENCODING_OFFSET to live in their own id range.
    encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
    # Model context for text is fixed at 256 tokens: truncate with a warning, then pad.
    if len(encoded_text) > 256:
        p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
        print(f"warning, text too long, lopping of last {p}%")
        encoded_text = encoded_text[:256]
    encoded_text = np.pad(
        encoded_text,
        (0, 256 - len(encoded_text)),
        constant_values=TEXT_PAD_TOKEN,
        mode="constant",
    )
    if semantic_history is not None:
        semantic_history = semantic_history.astype(np.int64)
        # lop off if history is too long, pad if needed
        semantic_history = semantic_history[-256:]
        semantic_history = np.pad(
            semantic_history,
            (0, 256 - len(semantic_history)),
            constant_values=SEMANTIC_PAD_TOKEN,
            mode="constant",
        )
    else:
        semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
    # Prompt layout: [256 text tokens | 256 semantic history | infer marker].
    x = torch.from_numpy(
        np.hstack([encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])]).astype(np.int64)
    )[None]
    assert x.shape[1] == 256 + 256 + 1
    with _inference_mode():
        x = x.to(device)
        n_tot_steps = 768
        # custom tqdm updates since we don't know when eos will occur
        pbar = tqdm.tqdm(disable=silent, total=100)
        pbar_state = 0
        tot_generated_duration_s = 0
        for n in range(n_tot_steps):
            logits = model(x, merge_context=True)
            relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
            if allow_early_stop:
                # Append the eos logit as an extra (last) slot.
                relevant_logits = torch.hstack(
                    (relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]])  # eos
                )
            if top_p is not None:
                # faster to convert to numpy
                logits_device = relevant_logits.device
                logits_dtype = relevant_logits.type()
                relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                sorted_indices = np.argsort(relevant_logits)[::-1]
                sorted_logits = relevant_logits[sorted_indices]
                cumulative_probs = np.cumsum(softmax(sorted_logits))
                sorted_indices_to_remove = cumulative_probs > top_p
                # Shift right so the first token crossing the threshold is kept.
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                sorted_indices_to_remove[0] = False
                relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                relevant_logits = torch.from_numpy(relevant_logits)
                relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
            if top_k is not None:
                v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                relevant_logits[relevant_logits < v[-1]] = -float("Inf")
            probs = F.softmax(relevant_logits / temp, dim=-1)
            item_next = torch.multinomial(probs, num_samples=1)
            if allow_early_stop and (
                item_next == SEMANTIC_VOCAB_SIZE
                or (min_eos_p is not None and probs[-1] >= min_eos_p)
            ):
                # eos found, so break
                pbar.update(100 - pbar_state)
                break
            x = torch.cat((x, item_next[None]), dim=1)
            tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
            if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
                pbar.update(100 - pbar_state)
                break
            if n == n_tot_steps - 1:
                pbar.update(100 - pbar_state)
                break
            del logits, relevant_logits, probs, item_next
            req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))])
            if req_pbar_state > pbar_state:
                pbar.update(req_pbar_state - pbar_state)
                pbar_state = req_pbar_state
        pbar.close()
        # Strip the 513-token prompt; the remainder is the generated sequence.
        out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
    assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
    _clear_cuda_cache()
    return out
|
||||
|
||||
|
||||
def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
|
||||
assert len(arr.shape) == 2
|
||||
arr = arr.copy()
|
||||
if offset_size is not None:
|
||||
for n in range(1, arr.shape[0]):
|
||||
arr[n, :] += offset_size * n
|
||||
flat_arr = arr.ravel("F")
|
||||
return flat_arr
|
||||
|
||||
|
||||
COARSE_SEMANTIC_PAD_TOKEN = 12_048
|
||||
COARSE_INFER_TOKEN = 12_050
|
||||
|
||||
|
||||
def generate_coarse(
    x_semantic,
    history_prompt=None,
    temp=0.7,
    top_k=None,
    top_p=None,
    use_gpu=True,
    silent=False,
    max_coarse_history=630,  # min 60 (faster), max 630 (more context)
    sliding_window_len=60,
    model=None,
):
    """Generate coarse audio codes from semantic tokens.

    Args:
        x_semantic: 1-D numpy array of semantic ids in [0, SEMANTIC_VOCAB_SIZE).
        history_prompt: named voice prompt from ALLOWED_PROMPTS, or None.
        temp: softmax temperature for sampling.
        top_k / top_p: optional sampling truncation, as in generate_text_semantic.
        use_gpu: run on CUDA when a device is present.
        silent: suppress the progress bar.
        max_coarse_history: coarse context carried between windows (60..630).
        sliding_window_len: coarse tokens generated per window.
        model: optional pre-loaded coarse model.

    Returns:
        (N_COARSE_CODEBOOKS, T) numpy array of codebook ids in [0, CODEBOOK_SIZE).
    """
    assert (
        isinstance(x_semantic, np.ndarray)
        and len(x_semantic.shape) == 1
        and len(x_semantic) > 0
        and x_semantic.min() >= 0
        and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
    )
    assert 60 <= max_coarse_history <= 630
    # 256 slots are reserved for the semantic context in the 1024 window.
    assert max_coarse_history + sliding_window_len <= 1024 - 256
    # Coarse tokens per semantic token (both codebooks interleaved).
    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
    if history_prompt is not None:
        assert (history_prompt in ALLOWED_PROMPTS)
        x_history = np.load(
            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
        )
        x_semantic_history = x_history["coarse_1"]
        x_coarse_history = x_history["coarse_2"]
        assert (
            isinstance(x_semantic_history, np.ndarray)
            and len(x_semantic_history.shape) == 1
            and len(x_semantic_history) > 0
            and x_semantic_history.min() >= 0
            and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
            and isinstance(x_coarse_history, np.ndarray)
            and len(x_coarse_history.shape) == 2
            and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS
            and x_coarse_history.shape[-1] >= 0
            and x_coarse_history.min() >= 0
            and x_coarse_history.max() <= CODEBOOK_SIZE - 1
            and (
                round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
                == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
            )
        )
        # Shift coarse ids above the semantic vocab so both share one id space.
        x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
        # trim histories correctly
        n_semantic_hist_provided = np.min(
            [
                max_semantic_history,
                len(x_semantic_history) - len(x_semantic_history) % 2,
                int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
            ]
        )
        n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
        x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
        x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
        # TODO: bit of a hack for time alignment (sounds better)
        x_coarse_history = x_coarse_history[:-2]
    else:
        x_semantic_history = np.array([], dtype=np.int32)
        x_coarse_history = np.array([], dtype=np.int32)
    if model is None:
        model = load_model(use_gpu=use_gpu, model_type="coarse")
    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
    # start loop
    # Total coarse steps, rounded down to a whole number of codebook pairs.
    n_steps = int(
        round(
            np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
            * N_COARSE_CODEBOOKS
        )
    )
    assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
    x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
    x_coarse = x_coarse_history.astype(np.int32)
    base_semantic_idx = len(x_semantic_history)
    with _inference_mode():
        x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
        x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
        n_window_steps = int(np.ceil(n_steps / sliding_window_len))
        n_step = 0
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            # Semantic position corresponding to the current coarse position.
            semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
            # pad from right side
            x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
            x_in = x_in[:, :256]
            x_in = F.pad(
                x_in,
                (0, 256 - x_in.shape[-1]),
                "constant",
                COARSE_SEMANTIC_PAD_TOKEN,
            )
            # Window layout: [256 semantic | infer marker | recent coarse].
            x_in = torch.hstack(
                [
                    x_in,
                    torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
                    x_coarse_in[:, -max_coarse_history:],
                ]
            )
            for _ in range(sliding_window_len):
                if n_step >= n_steps:
                    continue
                # Even steps emit codebook 0, odd steps codebook 1; each
                # codebook's logits live in a disjoint slice of the output.
                is_major_step = n_step % N_COARSE_CODEBOOKS == 0
                logits = model(x_in)
                logit_start_idx = (
                    SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                )
                logit_end_idx = (
                    SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
                )
                relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                if top_p is not None:
                    # faster to convert to numpy
                    logits_device = relevant_logits.device
                    logits_dtype = relevant_logits.type()
                    relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                    sorted_indices = np.argsort(relevant_logits)[::-1]
                    sorted_logits = relevant_logits[sorted_indices]
                    cumulative_probs = np.cumsum(softmax(sorted_logits))
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                    sorted_indices_to_remove[0] = False
                    relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                    relevant_logits = torch.from_numpy(relevant_logits)
                    relevant_logits = relevant_logits.to(logits_device).type(logits_dtype)
                if top_k is not None:
                    v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                    relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                probs = F.softmax(relevant_logits / temp, dim=-1)
                item_next = torch.multinomial(probs, num_samples=1)
                # Undo the slice offset to get back an absolute token id.
                item_next += logit_start_idx
                x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                x_in = torch.cat((x_in, item_next[None]), dim=1)
                del logits, relevant_logits, probs, item_next
                n_step += 1
            del x_in
        del x_semantic_in
    gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
    del x_coarse_in
    assert len(gen_coarse_arr) == n_steps
    # De-interleave into (N_COARSE_CODEBOOKS, T) and undo the id-space shifts.
    gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
    for n in range(1, N_COARSE_CODEBOOKS):
        gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
    _clear_cuda_cache()
    return gen_coarse_audio_arr
|
||||
|
||||
|
||||
def generate_fine(
    x_coarse_gen,
    history_prompt=None,
    temp=0.5,
    use_gpu=True,
    silent=True,
    model=None,
):
    """Generate full audio codes from coarse audio codes.

    Fills in codebooks n_coarse..N_FINE_CODEBOOKS-1 with a non-causal model,
    sweeping a 1024-frame window over the sequence in 512-frame strides.

    Args:
        x_coarse_gen: (n_coarse, T) numpy array of codebook ids in
            [0, CODEBOOK_SIZE), with 1 <= n_coarse < N_FINE_CODEBOOKS.
        history_prompt: named voice prompt from ALLOWED_PROMPTS, or None.
        temp: sampling temperature; None means greedy argmax.
        use_gpu: run on CUDA when a device is present.
        silent: suppress the progress bar.
        model: optional pre-loaded fine model.

    Returns:
        (N_FINE_CODEBOOKS, T) numpy array of codebook ids, same T as input.
    """
    assert (
        isinstance(x_coarse_gen, np.ndarray)
        and len(x_coarse_gen.shape) == 2
        and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1
        and x_coarse_gen.shape[1] > 0
        and x_coarse_gen.min() >= 0
        and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
    )
    if history_prompt is not None:
        assert (history_prompt in ALLOWED_PROMPTS)
        x_fine_history = np.load(
            os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt}.npz")
        )["fine"]
        assert (
            isinstance(x_fine_history, np.ndarray)
            and len(x_fine_history.shape) == 2
            and x_fine_history.shape[0] == N_FINE_CODEBOOKS
            and x_fine_history.shape[1] >= 0
            and x_fine_history.min() >= 0
            and x_fine_history.max() <= CODEBOOK_SIZE - 1
        )
    else:
        x_fine_history = None
    n_coarse = x_coarse_gen.shape[0]
    if model is None:
        model = load_model(use_gpu=use_gpu, model_type="fine")
    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
    # make input arr: missing fine codebooks are filled with the pad id
    # (CODEBOOK_SIZE) and overwritten by the model below.
    in_arr = np.vstack(
        [
            x_coarse_gen,
            np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1]))
            + CODEBOOK_SIZE,  # padding
        ]
    ).astype(np.int32)
    # prepend history if available (max 512)
    if x_fine_history is not None:
        x_fine_history = x_fine_history.astype(np.int32)
        in_arr = np.hstack(
            [
                x_fine_history[:, -512:].astype(np.int32),
                in_arr,
            ]
        )
        n_history = x_fine_history[:, -512:].shape[1]
    else:
        n_history = 0
    n_remove_from_end = 0
    # need to pad if too short (since non-causal model)
    if in_arr.shape[1] < 1024:
        n_remove_from_end = 1024 - in_arr.shape[1]
        in_arr = np.hstack(
            [
                in_arr,
                np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE,
            ]
        )
    # we can be lazy about fractional loop and just keep overwriting codebooks
    n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1
    with _inference_mode():
        # Work time-major: (T, N_FINE_CODEBOOKS).
        in_arr = torch.tensor(in_arr.T).to(device)
        for n in tqdm.tqdm(range(n_loops), disable=silent):
            start_idx = np.min([n * 512, in_arr.shape[0] - 1024])
            start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512])
            rel_start_fill_idx = start_fill_idx - start_idx
            in_buffer = in_arr[start_idx : start_idx + 1024, :][None]
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                logits = model(nn, in_buffer)
                if temp is None:
                    # Greedy decoding.
                    relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE]
                    codebook_preds = torch.argmax(relevant_logits, -1)
                else:
                    relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
                    probs = F.softmax(relevant_logits, dim=-1)
                    # NOTE(review): the comprehension's `n` shadows the outer
                    # window index; here it is a frame index into `probs`.
                    codebook_preds = torch.hstack(
                        [
                            torch.multinomial(probs[n], num_samples=1)
                            for n in range(rel_start_fill_idx, 1024)
                        ]
                    )
                in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
                del logits, codebook_preds
            # transfer over info into model_in and convert to numpy
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                in_arr[
                    start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn
                ] = in_buffer[0, rel_start_fill_idx:, nn]
            del in_buffer
    gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
    del in_arr
    # Strip the prepended history and any right padding added above.
    gen_fine_arr = gen_fine_arr[:, n_history:]
    if n_remove_from_end > 0:
        gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
    assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1]
    _clear_cuda_cache()
    return gen_fine_arr
|
||||
|
||||
|
||||
def codec_decode(fine_tokens, model=None, use_gpu=True):
    """Turn quantized audio codes into audio array using encodec."""
    if model is None:
        model = load_codec_model(use_gpu=use_gpu)
    device = "cuda" if use_gpu and torch.cuda.device_count() > 0 else "cpu"
    # (n_codebooks, T) -> (n_codebooks, 1, T), matching encodec's expected layout.
    codes = torch.from_numpy(fine_tokens)[None].to(device).transpose(0, 1)
    emb = model.quantizer.decode(codes)
    out = model.decoder(emb)
    audio_arr = out.detach().cpu().numpy().squeeze()
    del codes, emb, out
    return audio_arr
|
||||
174
bark/model.py
Normal file
174
bark/model.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""
|
||||
Much of this code is adapted from Andrej Karpathy's NanoGPT
|
||||
(https://github.com/karpathy/nanoGPT)
|
||||
"""
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class LayerNorm(nn.Module):
    """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """

    def __init__(self, ndim, bias):
        super().__init__()
        # Registration order (weight, then bias) matters for state_dict keys.
        self.weight = nn.Parameter(torch.ones(ndim))
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        normalized_shape = self.weight.shape
        # F.layer_norm accepts bias=None, which nn.LayerNorm does not expose.
        return F.layer_norm(input, normalized_shape, self.weight, self.bias, 1e-5)
|
||||
|
||||
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention, batched across heads (NanoGPT-style)."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y
|
||||
|
||||
class MLP(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Linear -> Dropout.

    Expands the embedding by 4x in the hidden layer, GPT-2 style.
    """

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.gelu = nn.GELU()

    def forward(self, x):
        # Same dataflow as step-by-step rebinding, written as one expression.
        return self.dropout(self.c_proj(self.gelu(self.c_fc(x))))
|
||||
|
||||
class Block(nn.Module):
    """Pre-norm transformer block: causal attention then MLP, each residual."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
        # Stored for identification; not referenced inside forward.
        self.layer_idx = layer_idx

    def forward(self, x):
        # Pre-norm residual wiring: x + attn(ln(x)), then x + mlp(ln(x)).
        attn_out = self.attn(self.ln_1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln_2(x))
        return x + mlp_out
|
||||
|
||||
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT model defined in this file."""
    block_size: int = 1024  # maximum sequence length (size of the position-embedding table)
    input_vocab_size: int = 10_048  # size of the token-embedding table
    output_vocab_size: int = 10_048  # width of the lm_head output
    n_layer: int = 12  # number of transformer blocks
    n_head: int = 12  # attention heads per block
    n_embd: int = 768  # embedding / hidden dimensionality
    dropout: float = 0.0  # dropout probability used throughout
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
|
||||
|
||||
class GPT(nn.Module):
    """GPT-style decoder-only transformer (adapted from nanoGPT).

    Supports an optional ``merge_context`` mode in which the first 512 input
    tokens are two 256-token context segments whose token embeddings are
    summed element-wise before being concatenated with the remaining tokens.
    """

    def __init__(self, config):
        super().__init__()
        assert config.input_vocab_size is not None
        assert config.output_vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.input_vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, idx, merge_context=False):
        """Forward pass returning logits for the LAST position only.

        Args:
            idx: (b, t) integer token ids.
            merge_context: if True, treat idx[:, :512] as two 256-token
                context segments whose token embeddings are summed.

        Returns:
            (b, 1, output_vocab_size) logits for the final position.
        """
        device = idx.device
        b, t = idx.size()
        if merge_context:
            assert idx.shape[1] >= 256 + 256 + 1
            t = idx.shape[1] - 256  # effective length after merging the two context halves
            # Fix: the effective length must also respect block_size in this
            # branch; previously the guard was skipped and an overlong input
            # failed later inside wpe indexing with a confusing error.
            assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        else:
            assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

        # forward the GPT model itself
        if merge_context:
            tok_emb = torch.cat([
                # sum the two 256-token context segments element-wise, then
                # append the remaining tokens unchanged
                self.transformer.wte(idx[:, :256]) + self.transformer.wte(idx[:, 256:256 + 256]),
                self.transformer.wte(idx[:, 256 + 256:])
            ], dim=1)
        else:
            tok_emb = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)

        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)

        x = self.transformer.drop(tok_emb + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :])  # note: using list [-1] to preserve the time dim

        return logits
|
||||
149
bark/model_fine.py
Normal file
149
bark/model_fine.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
Much of this code is adapted from Andrej Karpathy's NanoGPT
|
||||
(https://github.com/karpathy/nanoGPT)
|
||||
"""
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .model import GPT, GPTConfig, MLP
|
||||
|
||||
|
||||
class NonCausalSelfAttention(nn.Module):
    """Bidirectional (non-causal) multi-head self-attention.

    Mirrors the causal attention in ``model.py`` but applies no mask, so
    every position may attend to every other position.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # Fused Q/K/V projection followed by an output projection.
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # The fused kernel is only used when dropout is disabled, matching
        # the original gating of scaled_dot_product_attention here.
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        )

    def forward(self, x):
        batch, seq_len, width = x.size()  # (B, T, C)
        head_dim = width // self.n_head

        # Project once, split into q/k/v, and move heads next to the batch dim.
        q, k, v = (
            part.view(batch, seq_len, self.n_head, head_dim).transpose(1, 2)
            for part in self.c_attn(x).split(self.n_embd, dim=2)
        )

        if self.flash:
            # Fused kernel; dropout is known to be 0.0 on this path.
            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
            )
        else:
            # Manual softmax attention with no causal mask.
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_dim))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            out = weights @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # Re-assemble all head outputs side by side.
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, width)

        # output projection
        return self.resid_dropout(self.c_proj(out))
|
||||
|
||||
|
||||
class FineBlock(nn.Module):
    """Transformer block for the fine model: non-causal attention then MLP,
    each applied pre-norm with a residual connection.

    Uses plain ``nn.LayerNorm`` (with bias) rather than the custom LayerNorm
    used by the coarse model's blocks.
    """

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = NonCausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        # Pre-norm residual wiring, identical to Block but non-causal.
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
|
||||
|
||||
|
||||
class FineGPT(GPT):
    """Non-causal GPT variant that predicts one residual codebook at a time.

    Reuses GPT's construction, then replaces the transformer with one that
    has a separate token embedding per codebook (``wtes``) and a separate
    output head per predicted codebook (``lm_heads``), with weight tying
    between them.
    """

    def __init__(self, config):
        super().__init__(config)
        # The single-head parent lm_head is replaced by per-codebook heads.
        del self.lm_head
        self.config = config
        self.n_codes_total = config.n_codes_total
        self.transformer = nn.ModuleDict(
            dict(
                # One embedding table per codebook.
                wtes=nn.ModuleList(
                    [
                        nn.Embedding(config.input_vocab_size, config.n_embd)
                        for _ in range(config.n_codes_total)
                    ]
                ),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                # Non-causal blocks: the fine model conditions on the whole sequence.
                h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        # One output head for each codebook that this model predicts
        # (codebooks n_codes_given .. n_codes_total - 1).
        self.lm_heads = nn.ModuleList(
            [
                nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
                for _ in range(config.n_codes_given, self.n_codes_total)
            ]
        )
        # Tie each predicted codebook's embedding to its output head.
        # NOTE(review): the `i + 1` offset pairs wtes[1:] with lm_heads[0:],
        # which lines up only when n_codes_given == 1 — confirm for other values.
        for i in range(self.n_codes_total - config.n_codes_given):
            self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight

    def forward(self, pred_idx, idx):
        """Return logits for codebook ``pred_idx`` over the whole sequence.

        Args:
            pred_idx: index of the codebook to predict; must be > 0.
            idx: (b, t, n_codes_total) integer codes for all codebooks.

        Returns:
            (b, t, output_vocab_size) logits from the head for ``pred_idx``.
        """
        device = idx.device
        b, t, codes = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        assert pred_idx > 0, "cannot predict 0th codebook"
        assert codes == self.n_codes_total, (b, t, codes)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0)  # shape (1, t)

        # forward the GPT model itself
        # Embed every codebook separately; each entry is (b, t, n_embd, 1).
        tok_embs = [
            wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes)
        ]  # token embeddings of shape (b, t, n_embd)
        tok_emb = torch.cat(tok_embs, dim=-1)
        pos_emb = self.transformer.wpe(pos)  # position embeddings of shape (1, t, n_embd)
        # Sum embeddings of codebooks 0..pred_idx: condition only on codebooks
        # at or below the one being predicted.
        x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
        x = self.transformer.drop(x + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        # Select the head corresponding to the requested codebook.
        logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
        return logits

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            # Subtract every per-codebook embedding table plus positions.
            for wte in self.transformer.wtes:
                n_params -= wte.weight.numel()
            n_params -= self.transformer.wpe.weight.numel()
        return n_params
|
||||
|
||||
|
||||
@dataclass
class FineGPTConfig(GPTConfig):
    """GPTConfig extended with codebook counts for the fine model."""
    n_codes_total: int = 8  # total number of residual codebooks in the codec
    n_codes_given: int = 1  # codebooks provided as input (not predicted)
|
||||
Reference in New Issue
Block a user