Add v1 finetune support

Francis LaBounty
2023-06-29 21:48:18 -06:00
parent 572de4b707
commit 50927298a0
10 changed files with 3243 additions and 45 deletions

.gitignore vendored (5 changed lines)

@@ -2,4 +2,7 @@ __pycache__/
*.wav
_temp/
models/
output.npz
wandb/
*_output/
output.npz
joe_biden_state_of_union/

bark/generation.py

@@ -3,6 +3,7 @@ import gc
import hashlib
import os
import re
import json
from encodec import EncodecModel
import funcy
@@ -203,42 +204,81 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
raise NotImplementedError()
model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
model_info = REMOTE_MODEL_PATHS[model_key]
if (
os.path.exists(ckpt_path) and
_md5(ckpt_path) != model_info["checksum"]
):
logger.warning(f"found outdated {model_type} model, removing.")
os.remove(ckpt_path)
# if (
# os.path.exists(ckpt_path) and
# _md5(ckpt_path) != model_info["checksum"]
# ):
# logger.warning(f"found outdated {model_type} model, removing.")
# os.remove(ckpt_path)
if not os.path.exists(ckpt_path):
logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
_download(model_info["repo_id"], model_info["file_name"], ckpt_path)
checkpoint = torch.load(ckpt_path, map_location=device)
# this is a hack
model_args = checkpoint["model_args"]
# check if config.json is in the same directory as the checkpoint
# if so, load it
# otherwise, assume it's in the checkpoint
config_path = os.path.join(os.path.dirname(ckpt_path), "config.json")
if os.path.exists(config_path):
with open(config_path, "r") as f:
model_args = json.load(f)
else:
model_args = checkpoint["model_args"]
if "input_vocab_size" not in model_args:
model_args["input_vocab_size"] = model_args["vocab_size"]
model_args["output_vocab_size"] = model_args["vocab_size"]
del model_args["vocab_size"]
gptconf = ConfigClass(**checkpoint["model_args"])
gptconf = ConfigClass(**model_args)
model = ModelClass(gptconf)
state_dict = checkpoint["model"]
if checkpoint.get("model", None) is not None:
state_dict = checkpoint["model"]
else:
state_dict = checkpoint
# fixup checkpoint
unwanted_prefix = "_orig_mod."
for k, v in list(state_dict.items()):
if k.startswith(unwanted_prefix):
state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
unwanted_suffixes = [
"lora_right_weight",
"lora_left_weight",
"lora_right_bias",
"lora_left_bias",
]
for k, v in list(state_dict.items()):
for suffix in unwanted_suffixes:
if k.endswith(suffix):
state_dict.pop(k)
# super hacky - should probably refactor this
if state_dict.get('lm_head.0.weight', None) is not None:
state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight')
if state_dict.get('lm_heads.0.0.weight', None) is not None:
state_dict['lm_heads.0.weight'] = state_dict.pop('lm_heads.0.0.weight')
if state_dict.get('lm_heads.1.0.weight', None) is not None:
state_dict['lm_heads.1.weight'] = state_dict.pop('lm_heads.1.0.weight')
if state_dict.get('lm_heads.2.0.weight', None) is not None:
state_dict['lm_heads.2.weight'] = state_dict.pop('lm_heads.2.0.weight')
if state_dict.get('lm_heads.3.0.weight', None) is not None:
state_dict['lm_heads.3.weight'] = state_dict.pop('lm_heads.3.0.weight')
if state_dict.get('lm_heads.4.0.weight', None) is not None:
state_dict['lm_heads.4.weight'] = state_dict.pop('lm_heads.4.0.weight')
if state_dict.get('lm_heads.5.0.weight', None) is not None:
state_dict['lm_heads.5.weight'] = state_dict.pop('lm_heads.5.0.weight')
if state_dict.get('lm_heads.6.0.weight', None) is not None:
state_dict['lm_heads.6.weight'] = state_dict.pop('lm_heads.6.0.weight')
extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
if len(extra_keys) != 0:
raise ValueError(f"extra keys found: {extra_keys}")
print(f"extra keys found: {extra_keys}")
if len(missing_keys) != 0:
raise ValueError(f"missing keys: {missing_keys}")
model.load_state_dict(state_dict, strict=False)
n_params = model.get_num_params()
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
if checkpoint.get("best_val_loss", None) is not None:
val_loss = checkpoint["best_val_loss"].item()
logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
model.eval()
model.to(device)
del checkpoint, state_dict
@@ -273,8 +313,11 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te
models_devices[model_key] = device
device = "cpu"
if model_key not in models or force_reload:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
clean_models(model_key=model_key)
if path.endswith(".ckpt") or path.endswith(".pt") or path.endswith(".bin"):
ckpt_path = path
else:
ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path)
# clean_models(model_key=model_key)
model = _load_model_f(ckpt_path, device)
models[model_key] = model
if model_type == "text":
@@ -306,10 +349,13 @@ def load_codec_model(use_gpu=True, force_reload=False):
def preload_models(
text_use_gpu=True,
text_use_small=False,
text_model_path=None,
coarse_use_gpu=True,
coarse_use_small=False,
coarse_model_path=None,
fine_use_gpu=True,
fine_use_small=False,
fine_model_path=None,
codec_use_gpu=True,
force_reload=False,
path=None,
@@ -320,17 +366,17 @@ def preload_models(
):
logger.warning("No GPU being used. Careful, inference might be very slow!")
_ = load_model(
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path
model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path if text_model_path is None else text_model_path
)
_ = load_model(
model_type="coarse",
use_gpu=coarse_use_gpu,
use_small=coarse_use_small,
force_reload=force_reload,
path=path,
path=path if coarse_model_path is None else coarse_model_path,
)
_ = load_model(
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path
model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path if fine_model_path is None else fine_model_path
)
_ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
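A minimal usage sketch of the new path handling (the checkpoint path below is hypothetical): because load_model now treats a path ending in .ckpt/.pt/.bin as the checkpoint itself, and _load_model reads a sibling config.json when one exists, a finetuned model saved by the training notebooks can be loaded directly.

# Sketch only: coarse_output/pytorch_model.bin is a hypothetical output of
# train_coarse.ipynb, which writes pytorch_model.bin and config.json side by side.
from bark.generation import load_model

_ = load_model(
    model_type="coarse",
    use_gpu=True,
    use_small=False,
    force_reload=True,
    path="coarse_output/pytorch_model.bin",
)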

bark/model.py

@@ -9,7 +9,7 @@ import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat, reduce
SEMANTIC_PAD_TOKEN = 10_000
class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -167,7 +167,7 @@ class GPT(nn.Module):
n_params -= self.transformer.wpe.weight.numel()
return n_params
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, training=False):
device = idx.device
b, t = idx.size()
if past_kv is not None:
@@ -215,19 +215,9 @@ class GPT(nn.Module):
x = self.transformer.ln_f(x)
if labels is not None:
if training:
logits = self.lm_head(x)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return logits, loss
return logits
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
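With this change the loss is computed by the caller: passing training=True makes forward return logits for every position instead of only the last, and the training notebooks apply their own CrossEntropyLoss. A minimal sketch of that call pattern, assuming model, semantic_tokens, and coarse_tokens are already prepared as in train_coarse.ipynb:

import torch
import torch.nn as nn

# Assumed to exist already (see train_coarse.ipynb): `model` is the coarse GPT,
# `semantic_tokens` and `coarse_tokens` are (batch, seq) LongTensors.
criterion = nn.CrossEntropyLoss()  # the notebooks also set ignore_index to the pad token

inputs = torch.cat([semantic_tokens, coarse_tokens[:, :-1]], dim=1)
targets = coarse_tokens[:, 1:].contiguous()

logits = model(inputs, training=True)  # full-sequence logits, no loss computed inside
coarse_logits = logits[:, semantic_tokens.size(1):].contiguous()
loss = criterion(
    coarse_logits.view(-1, model.config.output_vocab_size),
    targets.view(-1),
)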

datasets/.tmp (new empty file)

(modified notebook, filename not shown)

@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": null,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -20,7 +20,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [

test_models.ipynb (new file, 183 lines)

@@ -0,0 +1,183 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from bark.api import generate_audio\n",
"from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"semantic_path = \"E:/Python/bark-with-voice-clone/semantic_output/pytorch_model.bin\"\n",
"coarse_path = \"E:/Python/bark-with-voice-clone/coarse_output/pytorch_model.bin\"\n",
"fine_path = \"E:/Python/bark-with-voice-clone/fine_output/pytorch_model.bin\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"preload_models(\n",
" text_use_gpu=True,\n",
" text_use_small=False,\n",
" text_model_path=semantic_path,\n",
" coarse_use_gpu=True,\n",
" coarse_use_small=False,\n",
" coarse_model_path=coarse_path,\n",
" fine_use_gpu=True,\n",
" fine_use_small=False,\n",
" fine_model_path=fine_path,\n",
" codec_use_gpu=True,\n",
" force_reload=False,\n",
" path=\"models\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# simple generation\n",
"text_prompt = \"I am Joe Biden... and this is the finetuned semantic, coarse and fine model! [laughs] A lot better than the original!\"\n",
"audio_array = generate_audio(text_prompt, history_prompt=None, text_temp=0.7, waveform_temp=0.7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Audio\n",
"# play audio\n",
"Audio(audio_array, rate=SAMPLE_RATE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.io.wavfile import write as write_wav\n",
"# save audio\n",
"filepath = \"output/audio.wav\" # change this to your desired output path\n",
"write_wav(filepath, SAMPLE_RATE, audio_array)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def generate_with_settings(text_prompt, semantic_temp=0.7, semantic_top_k=50, semantic_top_p=0.95, coarse_temp=0.7, coarse_top_k=50, coarse_top_p=0.95, fine_temp=0.5, voice_name=None, use_semantic_history_prompt=True, use_coarse_history_prompt=True, use_fine_history_prompt=True, output_full=False):\n",
" # generation with more control\n",
" x_semantic = generate_text_semantic(\n",
" text_prompt,\n",
" history_prompt=voice_name if use_semantic_history_prompt else None,\n",
" temp=semantic_temp,\n",
" top_k=semantic_top_k,\n",
" top_p=semantic_top_p,\n",
" )\n",
"\n",
" x_coarse_gen = generate_coarse(\n",
" x_semantic,\n",
" history_prompt=voice_name if use_coarse_history_prompt else None,\n",
" temp=coarse_temp,\n",
" top_k=coarse_top_k,\n",
" top_p=coarse_top_p,\n",
" )\n",
" x_fine_gen = generate_fine(\n",
" x_coarse_gen,\n",
" history_prompt=voice_name if use_fine_history_prompt else None,\n",
" temp=fine_temp,\n",
" )\n",
"\n",
" if output_full:\n",
" full_generation = {\n",
" 'semantic_prompt': x_semantic,\n",
" 'coarse_prompt': x_coarse_gen,\n",
" 'fine_prompt': x_fine_gen,\n",
" }\n",
" return full_generation, codec_decode(x_fine_gen)\n",
" return codec_decode(x_fine_gen)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text_prompt = \"I am Joe Biden... and this is the finetuned semantic, coarse and fine model! [laughs] A lot better than the original!\"\n",
"\n",
"audio_array = generate_with_settings(\n",
" text_prompt,\n",
" semantic_temp=0.7,\n",
" semantic_top_k=50,\n",
" semantic_top_p=0.99,\n",
" coarse_temp=0.7,\n",
" coarse_top_k=50,\n",
" coarse_top_p=0.99,\n",
" fine_temp=0.5,\n",
" voice_name=None,\n",
" use_semantic_history_prompt=True,\n",
" use_coarse_history_prompt=True,\n",
" use_fine_history_prompt=True,\n",
" output_full=False\n",
")\n",
"\n",
"from IPython.display import Audio\n",
"# play audio\n",
"Audio(audio_array, rate=SAMPLE_RATE)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.io.wavfile import write as write_wav\n",
"# save audio\n",
"filepath = \"output/audio.wav\" # change this to your desired output path\n",
"write_wav(filepath, SAMPLE_RATE, audio_array)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

train_coarse.ipynb (new file, 936 lines)

@@ -0,0 +1,936 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import os\n",
"import re\n",
"import gc\n",
"import math\n",
"import json\n",
"import hashlib\n",
"import numpy as np\n",
"import logging\n",
"import torchaudio\n",
"from tqdm.auto import tqdm\n",
"import torch.nn.functional as F\n",
"from encodec.utils import convert_audio\n",
"from accelerate import Accelerator\n",
"from accelerate.utils import set_seed\n",
"from transformers import BertTokenizer\n",
"from huggingface_hub import hf_hub_download\n",
"from packaging import version\n",
"from diffusers.optimization import get_scheduler\n",
"\n",
"from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
"from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
"from bark.model import GPTConfig, GPT\n",
"from bark.model_fine import FineGPT, FineGPTConfig"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training Args"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_batch_size = 8\n",
"eval_batch_size = 8\n",
"grad_accum = 1\n",
"ckpt_path = 'models/coarse_2.pt'\n",
"model_type = \"coarse\"\n",
"dataset_path = 'datasets/joe_biden_state_of_union/'\n",
"logging_dir = 'logs/'\n",
"log_with = 'wandb'\n",
"hubert_path = 'data/models/hubert/hubert.pt'\n",
"hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
"\n",
"output_dir = 'coarse_output/'\n",
"resume_from_checkpoint = None\n",
"\n",
"checkpointing_steps = 1000\n",
"\n",
"mixed_precision = 'bf16'\n",
"bits = 16 #4 4 and 8 bit are a work in progress\n",
"compute_dtype = torch.bfloat16\n",
"double_quant = True\n",
"quant_type = 'nf4'\n",
"\n",
"lora_dim = 64\n",
"lora_scaling = 1\n",
"lora_dropout = 0.1\n",
"lora_module_name = 'transformer.h'\n",
"optimize_lora_params_only = True\n",
"\n",
"learning_rate = 1e-4\n",
"scale_lr = False\n",
"use_8bit_adam = False\n",
"adam_beta1 = 0.9\n",
"adam_beta2 = 0.999\n",
"adam_epsilon = 1e-8\n",
"weight_decay = 0.01\n",
"\n",
"llm_int8_skip_modules = None\n",
"keep_in_fp32_modules = ['lm_head']\n",
"\n",
"lr_scheduler_type = 'linear'\n",
"lr_warmup_steps = 200\n",
"num_train_epochs = 20\n",
"max_train_steps = None\n",
"max_grad_norm = 1.0\n",
"\n",
"semantic_cross_entropy_loss_weight = 0\n",
"\n",
"seed = 741"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_WINDOW_SIZE = 1024\n",
"\n",
"MAX_SEMANTIC_LEN = 256\n",
"\n",
"SEMANTIC_RATE_HZ = 49.9\n",
"SEMANTIC_VOCAB_SIZE = 10_000\n",
"\n",
"TEXT_ENCODING_OFFSET = 10_048\n",
"SEMANTIC_PAD_TOKEN = 10_000\n",
"TEXT_PAD_TOKEN = 129_595\n",
"SEMANTIC_INFER_TOKEN = 129_599\n",
"\n",
"MAX_COARSE_LEN = 768\n",
"\n",
"SAMPLE_RATE = 24_000\n",
"CHANNELS = 1\n",
"\n",
"COARSE_SEMANTIC_PAD_TOKEN = 12_048\n",
"COARSE_INFER_TOKEN = 12_050\n",
"\n",
"CODEBOOK_SIZE = 1024\n",
"N_COARSE_CODEBOOKS = 2\n",
"N_FINE_CODEBOOKS = 8\n",
"COARSE_RATE_HZ = 75\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
"\n",
"default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
"CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
"\n",
"\n",
"def _clear_cuda_cache():\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.synchronize()\n",
"\n",
"\n",
"def _md5(fname):\n",
" hash_md5 = hashlib.md5()\n",
" with open(fname, \"rb\") as f:\n",
" for chunk in iter(lambda: f.read(4096), b\"\"):\n",
" hash_md5.update(chunk)\n",
" return hash_md5.hexdigest()\n",
"\n",
"\n",
"def _download(from_hf_path, file_name, to_local_path):\n",
" to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
" path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
" os.makedirs(path, exist_ok=True)\n",
" hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
" os.replace(os.path.join(path, file_name), to_local_path)\n",
"\n",
"\n",
"def _tokenize(tokenizer, text):\n",
" return tokenizer.encode(text, add_special_tokens=False)\n",
"\n",
"\n",
"def _detokenize(tokenizer, enc_text):\n",
" return tokenizer.decode(enc_text)\n",
"\n",
"\n",
"def _normalize_whitespace(text):\n",
" return re.sub(r\"\\s+\", \" \", text).strip()\n",
"\n",
"\n",
"REMOTE_MODEL_PATHS = {\n",
" \"text_small\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"text.pt\",\n",
" \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
" },\n",
" \"coarse_small\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"coarse.pt\",\n",
" \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
" },\n",
" \"fine_small\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"fine.pt\",\n",
" \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
" },\n",
" \"text\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"text_2.pt\",\n",
" \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
" },\n",
" \"coarse\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"coarse_2.pt\",\n",
" \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
" },\n",
" \"fine\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"fine_2.pt\",\n",
" \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
" },\n",
"}\n",
"\n",
"\n",
"def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
" if model_type == \"text\":\n",
" ConfigClass = GPTConfig\n",
" ModelClass = GPT\n",
" elif model_type == \"coarse\":\n",
" ConfigClass = GPTConfig\n",
" ModelClass = GPT\n",
" elif model_type == \"fine\":\n",
" ConfigClass = FineGPTConfig\n",
" ModelClass = FineGPT\n",
" else:\n",
" raise NotImplementedError()\n",
" model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
" model_info = REMOTE_MODEL_PATHS[model_key]\n",
" if ckpt_path in [None, '']:\n",
" ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
" if not os.path.exists(ckpt_path):\n",
" logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
" _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
" checkpoint = torch.load(ckpt_path, map_location=device)\n",
" # this is a hack\n",
" model_args = checkpoint[\"model_args\"]\n",
" if \"input_vocab_size\" not in model_args:\n",
" model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
" model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
" del model_args[\"vocab_size\"]\n",
" gptconf = ConfigClass(**checkpoint[\"model_args\"])\n",
" model = ModelClass(gptconf)\n",
" state_dict = checkpoint[\"model\"]\n",
" # fixup checkpoint\n",
" unwanted_prefix = \"_orig_mod.\"\n",
" for k, v in list(state_dict.items()):\n",
" if k.startswith(unwanted_prefix):\n",
" state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n",
" extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
" extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
" missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
" missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
" if len(extra_keys) != 0:\n",
" raise ValueError(f\"extra keys found: {extra_keys}\")\n",
" if len(missing_keys) != 0:\n",
" raise ValueError(f\"missing keys: {missing_keys}\")\n",
" model.load_state_dict(state_dict, strict=False)\n",
" n_params = model.get_num_params()\n",
" val_loss = checkpoint[\"best_val_loss\"].item()\n",
" print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
" del checkpoint, state_dict\n",
" _clear_cuda_cache()\n",
" if model_type == \"text\":\n",
" tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
" return model, tokenizer\n",
" return model\n",
"\n",
"\n",
"def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n",
" assert len(arr.shape) == 2\n",
" arr = arr.copy()\n",
" if offset_size is not None:\n",
" for n in range(1, arr.shape[0]):\n",
" arr[n, :] += offset_size * n\n",
" flat_arr = arr.ravel(\"F\")\n",
" return flat_arr\n",
"\n",
"\n",
"def load_filepaths_and_text(filename, split=\"|\"):\n",
" with open(filename, encoding='utf-8') as f:\n",
" filepaths_and_text = [line.strip().split(split) for line in f]\n",
" base = os.path.dirname(filename)\n",
" for j in range(len(filepaths_and_text)):\n",
" filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
" return filepaths_and_text\n",
"\n",
"\n",
"class TtsDataset(torch.utils.data.Dataset):\n",
" def __init__(self, opt):\n",
" self.path = os.path.dirname(opt['path'])\n",
" self.mode = opt['mode']\n",
" self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '.txt'))\n",
"\n",
" def __getitem__(self, index):\n",
" audiopath_and_text = self.audiopaths_and_text[index]\n",
" audiopath = audiopath_and_text[0]\n",
"\n",
" tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
" semantic_tokens = tokens['semantic']\n",
" coarse_tokens = _flatten_codebooks(tokens['coarse'], offset_size=CODEBOOK_SIZE) + SEMANTIC_VOCAB_SIZE\n",
"\n",
" return torch.from_numpy(semantic_tokens), torch.from_numpy(coarse_tokens)\n",
"\n",
" def __len__(self):\n",
" return len(self.audiopaths_and_text)\n",
"\n",
"\n",
"class TtsCollater():\n",
" def __init__(self):\n",
" pass\n",
" def __call__(self, batch):\n",
" max_semantic_len = MAX_SEMANTIC_LEN\n",
" max_coarse_len = MAX_COARSE_LEN\n",
" semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS\n",
" semantic_tokens = []\n",
" coarse_tokens = []\n",
"\n",
" for b in batch:\n",
" semantic_tokens_, coarse_tokens_ = b\n",
" start_idx = None\n",
" if len(semantic_tokens_) > max_semantic_len:\n",
" start_idx = np.random.randint(0, len(semantic_tokens_) - max_semantic_len + 1)\n",
" semantic_tokens_ = semantic_tokens_[start_idx:start_idx+max_semantic_len]\n",
" semantic_tokens_ = F.pad(semantic_tokens_, (0, max_semantic_len-len(semantic_tokens_)), value=COARSE_SEMANTIC_PAD_TOKEN)\n",
" semantic_tokens_ = torch.cat([semantic_tokens_, torch.tensor([COARSE_INFER_TOKEN])])\n",
" semantic_tokens.append(semantic_tokens_)\n",
"\n",
" if start_idx is not None:\n",
" start_idx_coarse = int(start_idx * semantic_to_coarse_ratio) \n",
" coarse_tokens_ = coarse_tokens_[start_idx_coarse:start_idx_coarse+max_coarse_len]\n",
" coarse_tokens_ = F.pad(coarse_tokens_, (0, max_coarse_len-len(coarse_tokens_)), value=COARSE_SEMANTIC_PAD_TOKEN)\n",
" coarse_tokens.append(coarse_tokens_)\n",
"\n",
" return {\n",
" 'semantic_tokens': torch.stack(semantic_tokens).contiguous(),\n",
" 'coarse_tokens': torch.stack(coarse_tokens).contiguous()\n",
" }\n",
" \n",
"\n",
"accelerator = Accelerator(\n",
" gradient_accumulation_steps=grad_accum,\n",
" mixed_precision=mixed_precision,\n",
" log_with=log_with,\n",
" logging_dir=logging_dir,\n",
")\n",
"device = accelerator.device\n",
"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"set_seed(seed)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup Dataset (only need to do this once)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# max_duration_sec = 15.12 # the maximum allowed duration in seconds\n",
"\n",
"# path = dataset_path\n",
"\n",
"# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
"# from hubert.hubert_manager import HuBERTManager\n",
"# hubert_manager = HuBERTManager()\n",
"# from hubert.pre_kmeans_hubert import CustomHubert\n",
"# from hubert.customtokenizer import CustomTokenizer\n",
"# hubert_manager.make_sure_hubert_installed()\n",
"# hubert_manager.make_sure_tokenizer_installed()\n",
"\n",
"# # Load the HuBERT model\n",
"# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
"# hubert_model.eval()\n",
"# for param in hubert_model.parameters():\n",
"# param.requires_grad = False\n",
"\n",
"# # Load the CustomTokenizer model\n",
"# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n",
"\n",
"# from bark.generation import load_codec_model\n",
"# codec_model = load_codec_model(use_gpu=True)\n",
"# codec_model.eval()\n",
"# for param in codec_model.parameters():\n",
"# param.requires_grad = False\n",
"\n",
"\n",
"# def get_duration(wav, sr):\n",
"# return wav.shape[1] / sr\n",
"\n",
"# valid_lines_train = []\n",
"# # convert wavs to semantic tokens\n",
"# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
"# wav, sr = torchaudio.load(wav_path)\n",
"# if not get_duration(wav, sr) > max_duration_sec:\n",
"# valid_lines_train.append((wav_path, txt))\n",
"# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
"\n",
"# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
"# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
"\n",
"# # save semantic tokens\n",
"# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
"# semantic_tokens = semantic_tokens.cpu().numpy()\n",
"\n",
"# # Extract discrete codes from EnCodec\n",
"# with torch.no_grad():\n",
"# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
"# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
"\n",
"# # move codes to cpu\n",
"# codes = codes.cpu().numpy()\n",
"\n",
"# # save tokens\n",
"# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
"\n",
"# # rewrite train.txt with valid lines\n",
"# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
"# for wav_path, txt in valid_lines_train:\n",
"# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
"# f.write(f'{wav_path}|{txt}\\n')\n",
"\n",
"# valid_lines_valid = []\n",
"# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
"# wav, sr = torchaudio.load(wav_path)\n",
"# if not get_duration(wav, sr) > max_duration_sec:\n",
"# valid_lines_valid.append((wav_path, txt))\n",
"# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
"\n",
"# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
"# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
"\n",
"# # save semantic tokens\n",
"# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
"# semantic_tokens = semantic_tokens.cpu().numpy()\n",
" \n",
"# # Extract discrete codes from EnCodec\n",
"# with torch.no_grad():\n",
"# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
"# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
"\n",
"# # move codes to cpu\n",
"# codes = codes.cpu().numpy()\n",
"\n",
"# # save tokens\n",
"# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
"\n",
"# # rewrite valid.txt with valid lines\n",
"# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
"# for wav_path, txt in valid_lines_valid:\n",
"# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
"# f.write(f'{wav_path}|{txt}\\n')\n",
"\n",
"# del hubert_model\n",
"# del hubert_tokenizer\n",
"# del codec_model\n",
"# gc.collect()\n",
"# torch.cuda.empty_cache()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if scale_lr:\n",
" learning_rate = (\n",
" learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
" )\n",
"\n",
"if use_8bit_adam:\n",
" try:\n",
" import bitsandbytes as bnb\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
" )\n",
"\n",
" optimizer_class = bnb.optim.AdamW8bit\n",
"else:\n",
" optimizer_class = torch.optim.AdamW"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"quantization_config=BitsAndBytesConfig(\n",
" load_in_4bit=bits == 4,\n",
" load_in_8bit=bits == 8,\n",
" llm_int8_threshold=6.0,\n",
" llm_int8_has_fp16_weight=False,\n",
" bnb_4bit_compute_dtype=compute_dtype,\n",
" bnb_4bit_use_double_quant=double_quant,\n",
" bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n",
")\n",
"\n",
"# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
"# if quantization_config.load_in_8bit:\n",
"# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
"# elif quantization_config.load_in_4bit:\n",
"# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
"\n",
"# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
"# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
"# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n",
"# else:\n",
"# modules_to_not_convert = llm_int8_skip_modules\n",
"\n",
"# if not isinstance(modules_to_not_convert, list):\n",
"# modules_to_not_convert = [modules_to_not_convert]\n",
"\n",
"# modules_to_not_convert.extend(keep_in_fp32_modules)\n",
"\n",
"# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
"\n",
"# if quantization_config.load_in_4bit and not supports_4bit:\n",
"# raise ValueError(\n",
"# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
"# \" make sure you have the latest version of `bitsandbytes` installed\"\n",
"# )\n",
" \n",
"# if len(modules_to_not_convert) == 0:\n",
"# modules_to_not_convert = None\n",
"\n",
"# model = replace_with_bnb_linear(\n",
"# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
"# )\n",
"\n",
"# # training in 8-bit is only available in 0.37.0+\n",
"# model._is_kbit_training_enabled = version.parse(\n",
"# importlib_metadata.version(\"bitsandbytes\")\n",
"# ) >= version.parse(\"0.37.0\")\n",
"\n",
"# model.config.quantization_config = quantization_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if bits == 4:\n",
" from accelerate.utils import CustomDtype\n",
" target_dtype = CustomDtype.INT4\n",
"elif bits == 8:\n",
" target_dtype = torch.int8\n",
"\n",
"if lora_dim > 0:\n",
" for param in model.parameters():\n",
" if param.ndim == 1:\n",
" # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
" param.data = param.data.to(torch.float32)\n",
" \n",
" class CastOutputToFloat(nn.Sequential):\n",
" def forward(self, x):\n",
" return super().forward(x).to(torch.float32)\n",
"\n",
" model.lm_head = CastOutputToFloat(model.lm_head)\n",
"\n",
" model = convert_linear_layer_to_lora(model, lora_module_name,\n",
" lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
" lora_dropout=lora_dropout)\n",
" if optimize_lora_params_only:\n",
" model = only_optimize_lora_parameters(model)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params_to_optimize = (\n",
" param for param in model.parameters() if param.requires_grad\n",
" )\n",
"\n",
"optimizer = optimizer_class(\n",
" params_to_optimize,\n",
" lr=learning_rate,\n",
" betas=(adam_beta1, adam_beta2),\n",
" weight_decay=weight_decay,\n",
" eps=adam_epsilon,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_train = {\n",
" 'path': dataset_path,\n",
" 'mode': 'train',\n",
"}\n",
"\n",
"opt_val = {\n",
" 'path': dataset_path,\n",
" 'mode': 'valid',\n",
"}\n",
"\n",
"train_dataset = TtsDataset(opt_train)\n",
"validation_dataset = TtsDataset(opt_val)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(\n",
" train_dataset,\n",
" batch_size=train_batch_size,\n",
" collate_fn=TtsCollater(),\n",
")\n",
"\n",
"validation_dataloader = torch.utils.data.DataLoader(\n",
" validation_dataset,\n",
" batch_size=eval_batch_size,\n",
" collate_fn=TtsCollater(),\n",
")\n",
"\n",
"criterion = torch.nn.CrossEntropyLoss(ignore_index=COARSE_SEMANTIC_PAD_TOKEN)\n",
"\n",
"# Scheduler and math around the number of training steps.\n",
"overrode_max_train_steps = False\n",
"num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
"if max_train_steps is None:\n",
" max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
" overrode_max_train_steps = True\n",
"\n",
"lr_scheduler = get_scheduler(\n",
" lr_scheduler_type,\n",
" optimizer=optimizer,\n",
" num_warmup_steps=lr_warmup_steps * grad_accum,\n",
" num_training_steps=max_train_steps * grad_accum,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
" model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
")\n",
"accelerator.register_for_checkpointing(lr_scheduler)\n",
"\n",
"weight_dtype = torch.float32\n",
"if accelerator.mixed_precision == \"fp16\":\n",
" weight_dtype = torch.float16\n",
"elif accelerator.mixed_precision == \"bf16\":\n",
" weight_dtype = torch.bfloat16"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
"num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
"if overrode_max_train_steps:\n",
" max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
"# Afterwards we recalculate our number of training epochs\n",
"num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
"\n",
"# We need to initialize the trackers we use, and also store our configuration.\n",
"# The trackers initializes automatically on the main process.\n",
"if accelerator.is_main_process:\n",
" accelerator.init_trackers(\"bark_coarse\", config={})\n",
"\n",
"# Train!\n",
"total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
"logger.info(\"***** Running training *****\")\n",
"logger.info(f\" Num examples = {len(train_dataset)}\")\n",
"logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n",
"logger.info(f\" Num Epochs = {num_train_epochs}\")\n",
"logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n",
"logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
"logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n",
"logger.info(f\" Total optimization steps = {max_train_steps}\")\n",
"global_step = 0\n",
"first_epoch = 0\n",
"\n",
"if resume_from_checkpoint:\n",
" if resume_from_checkpoint != \"latest\":\n",
" path = os.path.basename(resume_from_checkpoint)\n",
" else:\n",
" # Get the most recent checkpoint\n",
" dirs = os.listdir(output_dir)\n",
" dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
" dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
" path = dirs[-1]\n",
" accelerator.print(f\"Resuming from checkpoint {path}\")\n",
" accelerator.load_state(os.path.join(output_dir, path))\n",
" global_step = int(path.split(\"-\")[1])\n",
"\n",
" resume_global_step = global_step * grad_accum\n",
" first_epoch = resume_global_step // num_update_steps_per_epoch\n",
" resume_step = resume_global_step % num_update_steps_per_epoch\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if accelerator.is_main_process:\n",
" model.eval()\n",
" validation_loss = 0.0\n",
" num_batches = 0\n",
" num_samples = 0\n",
" with torch.no_grad():\n",
" for val_step, val_batch in enumerate(validation_dataloader):\n",
" # Similar to training, process the validation batch\n",
" val_targets = val_batch['coarse_tokens'][:, 1:].contiguous()\n",
" val_coarse_inputs = val_batch['coarse_tokens'][:, :-1]\n",
" val_inputs = torch.cat([val_batch['semantic_tokens'], val_coarse_inputs], dim=1)\n",
"\n",
" # Forward pass for validation\n",
" val_logits = model(val_inputs, training=True)\n",
" val_coarse_logits = val_logits[:, val_batch['semantic_tokens'].size(1):].contiguous()\n",
"\n",
" # Calculate the validation loss\n",
" val_loss = criterion(val_coarse_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
" validation_loss += val_loss.item()\n",
" num_batches += 1\n",
" num_samples += val_batch['semantic_tokens'].size(0)\n",
"\n",
" average_validation_loss = validation_loss / num_batches\n",
" logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
" print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Only show the progress bar once on each machine.\n",
"progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
"progress_bar.set_description(\"Steps\")\n",
"\n",
"for epoch in range(first_epoch, num_train_epochs):\n",
" model.train()\n",
" for step, batch in enumerate(train_dataloader):\n",
" # Skip steps until we reach the resumed step\n",
" if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
" if step % grad_accum == 0:\n",
" progress_bar.update(1)\n",
" continue\n",
"\n",
" with accelerator.accumulate(model):\n",
" targets = batch['coarse_tokens'][:, 1:].contiguous()\n",
" \n",
" # Remove the last coarse token from the inputs since there is no target for it.\n",
" coarse_inputs = batch['coarse_tokens'][:, :-1]\n",
"\n",
" # Combine the semantic tokens and coarse tokens and feed them into the model.\n",
" inputs = torch.cat([batch['semantic_tokens'], coarse_inputs], dim=1)\n",
" logits = model(inputs, training=True)\n",
"\n",
" # We're only interested in the logits for the coarse tokens, so we ignore the logits for the input text tokens.\n",
" coarse_logits = logits[:, batch['semantic_tokens'].size(1):].contiguous()\n",
"\n",
" # Compute the loss.\n",
" loss = criterion(coarse_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n",
"\n",
" if semantic_cross_entropy_loss_weight > 0 and semantic_cross_entropy_loss_weight is not None:\n",
" semantic_logits = logits[:, :batch['semantic_tokens'].size(1)].contiguous()\n",
" semantic_loss = criterion(\n",
" semantic_logits.view(-1, model.config.input_vocab_size),\n",
" batch['semantic_tokens'].view(-1),\n",
" )\n",
" num_semantic_logits = semantic_logits.size(1)\n",
" num_coarse_logits = coarse_logits.size(1)\n",
" loss = (\n",
" semantic_loss * num_semantic_logits * semantic_cross_entropy_loss_weight +\n",
" loss * num_coarse_logits\n",
" ) / (num_semantic_logits + num_coarse_logits)\n",
"\n",
" accelerator.backward(loss)\n",
" if accelerator.sync_gradients:\n",
" params_to_clip = (\n",
" param for param in model.parameters() if param.requires_grad\n",
" )\n",
" accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" # Checks if the accelerator has performed an optimization step behind the scenes\n",
" if accelerator.sync_gradients:\n",
" progress_bar.update(1)\n",
" global_step += 1\n",
"\n",
" if global_step % checkpointing_steps == 0:\n",
" if accelerator.is_main_process:\n",
" save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
" accelerator.save_state(save_path)\n",
" logger.info(f\"Saved state to {save_path}\")\n",
"\n",
" logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
" progress_bar.set_postfix(**logs)\n",
" accelerator.log(logs, step=global_step)\n",
"\n",
" if global_step >= max_train_steps:\n",
" break\n",
" \n",
" accelerator.wait_for_everyone()\n",
"\n",
"if accelerator.is_main_process:\n",
" if lora_dim > 0:\n",
" model = convert_lora_to_linear_layer(model)\n",
" # save model\n",
" accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
" \n",
" config = model.config.__dict__\n",
" # save config\n",
" with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
" json.dump(config, f, indent=2)\n",
"\n",
"accelerator.end_training()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if accelerator.is_main_process:\n",
" model.eval()\n",
" validation_loss = 0.0\n",
" num_batches = 0\n",
" num_samples = 0\n",
" with torch.no_grad():\n",
" for val_step, val_batch in enumerate(validation_dataloader):\n",
" # Similar to training, process the validation batch\n",
" val_targets = val_batch['coarse_tokens'][:, 1:].contiguous()\n",
" val_coarse_inputs = val_batch['coarse_tokens'][:, :-1]\n",
" val_inputs = torch.cat([val_batch['semantic_tokens'], val_coarse_inputs], dim=1)\n",
"\n",
" # Forward pass for validation\n",
" val_logits = model(val_inputs, training=True)\n",
" val_coarse_logits = val_logits[:, val_batch['semantic_tokens'].size(1):].contiguous()\n",
"\n",
" # Calculate the validation loss\n",
" val_loss = criterion(val_coarse_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
" validation_loss += val_loss.item()\n",
" num_batches += 1\n",
" num_samples += val_batch['semantic_tokens'].size(0)\n",
"\n",
" average_validation_loss = validation_loss / num_batches\n",
" logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
" print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

train_fine.ipynb (new file, 1141 lines)
File diff suppressed because it is too large

train_semantic.ipynb (new file, 899 lines)

@@ -0,0 +1,899 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import os\n",
"import re\n",
"import gc\n",
"import json\n",
"import math\n",
"import hashlib\n",
"import numpy as np\n",
"import logging\n",
"import torchaudio\n",
"from tqdm.auto import tqdm\n",
"import torch.nn.functional as F\n",
"from encodec.utils import convert_audio\n",
"from accelerate import Accelerator\n",
"from accelerate.utils import set_seed\n",
"from transformers import BertTokenizer\n",
"from huggingface_hub import hf_hub_download\n",
"from packaging import version\n",
"from diffusers.optimization import get_scheduler\n",
"\n",
"from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
"from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
"from bark.model import GPTConfig, GPT\n",
"from bark.model_fine import FineGPT, FineGPTConfig"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training Args"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_batch_size = 8\n",
"eval_batch_size = 8\n",
"grad_accum = 1\n",
"ckpt_path = 'models/text_2.pt'\n",
"model_type = \"text\"\n",
"dataset_path = 'datasets/joe_biden_state_of_union/'\n",
"logging_dir = 'logs/'\n",
"log_with = 'wandb'\n",
"hubert_path = 'data/models/hubert/hubert.pt'\n",
"hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
"\n",
"output_dir = 'semantic_output/'\n",
"resume_from_checkpoint = None\n",
"\n",
"checkpointing_steps = 1000\n",
"\n",
"mixed_precision = 'bf16'\n",
"bits = 16 #4 4 and 8 bit are a work in progress\n",
"compute_dtype = torch.bfloat16\n",
"double_quant = True\n",
"quant_type = 'nf4'\n",
"\n",
"lora_dim = 64\n",
"lora_scaling = 32\n",
"lora_dropout = 0.1\n",
"lora_module_name = 'transformer.h'\n",
"optimize_lora_params_only = True\n",
"\n",
"learning_rate = 1e-4\n",
"scale_lr = False\n",
"use_8bit_adam = False\n",
"adam_beta1 = 0.9\n",
"adam_beta2 = 0.999\n",
"adam_epsilon = 1e-8\n",
"weight_decay = 0.01\n",
"\n",
"llm_int8_skip_modules = None\n",
"keep_in_fp32_modules = ['lm_head']\n",
"\n",
"lr_scheduler_type = 'linear'\n",
"lr_warmup_steps = 200\n",
"num_train_epochs = 20\n",
"max_train_steps = None\n",
"max_grad_norm = 1.0\n",
"\n",
"seed = 741"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_WINDOW_SIZE = 1024\n",
"\n",
"MAX_TEXT_LEN = 256\n",
"\n",
"SEMANTIC_RATE_HZ = 49.9\n",
"SEMANTIC_VOCAB_SIZE = 10_000\n",
"\n",
"TEXT_ENCODING_OFFSET = 10_048\n",
"SEMANTIC_PAD_TOKEN = 10_000\n",
"TEXT_PAD_TOKEN = 129_595\n",
"SEMANTIC_INFER_TOKEN = 129_599\n",
"\n",
"MAX_SEMANTIC_LEN = 511\n",
"\n",
"SAMPLE_RATE = 24_000\n",
"CHANNELS = 1\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
"\n",
"default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
"CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
"\n",
"\n",
"def _clear_cuda_cache():\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.synchronize()\n",
"\n",
"\n",
"def _md5(fname):\n",
" hash_md5 = hashlib.md5()\n",
" with open(fname, \"rb\") as f:\n",
" for chunk in iter(lambda: f.read(4096), b\"\"):\n",
" hash_md5.update(chunk)\n",
" return hash_md5.hexdigest()\n",
"\n",
"\n",
"def _download(from_hf_path, file_name, to_local_path):\n",
" to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
" path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
" os.makedirs(path, exist_ok=True)\n",
" hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
" os.replace(os.path.join(path, file_name), to_local_path)\n",
"\n",
"\n",
"def _tokenize(tokenizer, text):\n",
" return tokenizer.encode(text, add_special_tokens=False)\n",
"\n",
"\n",
"def _detokenize(tokenizer, enc_text):\n",
" return tokenizer.decode(enc_text)\n",
"\n",
"\n",
"def _normalize_whitespace(text):\n",
" return re.sub(r\"\\s+\", \" \", text).strip()\n",
"\n",
"\n",
"REMOTE_MODEL_PATHS = {\n",
" \"text_small\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"text.pt\",\n",
" \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
" },\n",
" \"coarse_small\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"coarse.pt\",\n",
" \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
" },\n",
" \"fine_small\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"fine.pt\",\n",
" \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
" },\n",
" \"text\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"text_2.pt\",\n",
" \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
" },\n",
" \"coarse\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"coarse_2.pt\",\n",
" \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
" },\n",
" \"fine\": {\n",
" \"repo_id\": \"suno/bark\",\n",
" \"file_name\": \"fine_2.pt\",\n",
" \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
" },\n",
"}\n",
"\n",
"\n",
"def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
" if model_type == \"text\":\n",
" ConfigClass = GPTConfig\n",
" ModelClass = GPT\n",
" elif model_type == \"coarse\":\n",
" ConfigClass = GPTConfig\n",
" ModelClass = GPT\n",
" elif model_type == \"fine\":\n",
" ConfigClass = FineGPTConfig\n",
" ModelClass = FineGPT\n",
" else:\n",
" raise NotImplementedError()\n",
" model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
" model_info = REMOTE_MODEL_PATHS[model_key]\n",
" if ckpt_path in [None, '']:\n",
" ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
" if not os.path.exists(ckpt_path):\n",
" logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
" _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
" checkpoint = torch.load(ckpt_path, map_location=device)\n",
" # this is a hack\n",
" model_args = checkpoint[\"model_args\"]\n",
" if \"input_vocab_size\" not in model_args:\n",
" model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
" model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
" del model_args[\"vocab_size\"]\n",
" gptconf = ConfigClass(**checkpoint[\"model_args\"])\n",
" model = ModelClass(gptconf)\n",
" state_dict = checkpoint[\"model\"]\n",
" # fixup checkpoint\n",
" unwanted_prefix = \"_orig_mod.\"\n",
" for k, v in list(state_dict.items()):\n",
" if k.startswith(unwanted_prefix):\n",
" state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n",
" extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
" extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
" missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
" missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
" if len(extra_keys) != 0:\n",
" raise ValueError(f\"extra keys found: {extra_keys}\")\n",
" if len(missing_keys) != 0:\n",
" raise ValueError(f\"missing keys: {missing_keys}\")\n",
" model.load_state_dict(state_dict, strict=False)\n",
" n_params = model.get_num_params()\n",
" val_loss = checkpoint[\"best_val_loss\"].item()\n",
" print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
" del checkpoint, state_dict\n",
" _clear_cuda_cache()\n",
" if model_type == \"text\":\n",
" tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
" return model, tokenizer\n",
" return model\n",
"\n",
"\n",
"def load_filepaths_and_text(filename, split=\"|\"):\n",
" with open(filename, encoding='utf-8') as f:\n",
" filepaths_and_text = [line.strip().split(split) for line in f]\n",
" base = os.path.dirname(filename)\n",
" for j in range(len(filepaths_and_text)):\n",
" filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
" return filepaths_and_text\n",
"\n",
"class TtsDataset(torch.utils.data.Dataset):\n",
" def __init__(self, opt):\n",
" self.path = os.path.dirname(opt['path'])\n",
" self.mode = opt['mode']\n",
" self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '_valid.txt'))\n",
" self.tokenizer = opt['tokenizer']\n",
"\n",
" def __getitem__(self, index):\n",
" audiopath_and_text = self.audiopaths_and_text[index]\n",
" audiopath, text = audiopath_and_text[0], audiopath_and_text[1]\n",
"\n",
" input_ids = np.array(_tokenize(self.tokenizer, text)) + TEXT_ENCODING_OFFSET\n",
" input_ids = torch.from_numpy(input_ids).long()\n",
" tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
" semantic_tokens = tokens['semantic']\n",
" semantic_tokens = torch.from_numpy(semantic_tokens).long()\n",
"\n",
" return input_ids, semantic_tokens\n",
"\n",
" def __len__(self):\n",
" return len(self.audiopaths_and_text)\n",
"\n",
"\n",
"class TtsCollater():\n",
" def __init__(self):\n",
" pass\n",
" def __call__(self, batch):\n",
" max_text_len = MAX_TEXT_LEN\n",
" max_semantic_tokens_len = MAX_SEMANTIC_LEN\n",
" texts = []\n",
" semantic_tokens = []\n",
"\n",
" for b in batch:\n",
" text, semantic_tokens_ = b\n",
" text = F.pad(text, (0, max_text_len-len(text)), value=TEXT_PAD_TOKEN)\n",
" semantic_history = torch.from_numpy(np.array([SEMANTIC_PAD_TOKEN] * 256))\n",
" text = torch.cat([text, semantic_history, torch.tensor([SEMANTIC_INFER_TOKEN])])\n",
" texts.append(text)\n",
" semantic_tokens_ = semantic_tokens_[:max_semantic_tokens_len]\n",
" semantic_tokens.append(F.pad(semantic_tokens_, (0, max_semantic_tokens_len-len(semantic_tokens_)), value=SEMANTIC_PAD_TOKEN))\n",
"\n",
" return {\n",
" 'input_ids': torch.stack(texts).contiguous(),\n",
" 'semantic_tokens': torch.stack(semantic_tokens).contiguous()\n",
" }\n",
" \n",
"\n",
"accelerator = Accelerator(\n",
" gradient_accumulation_steps=grad_accum,\n",
" mixed_precision=mixed_precision,\n",
" log_with=log_with,\n",
" logging_dir=logging_dir,\n",
")\n",
"device = accelerator.device\n",
"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"set_seed(seed)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup Dataset (only need to do this once)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"max_duration_sec = 15.12 # the maximum allowed duration in seconds\n",
"\n",
"path = dataset_path\n",
"\n",
"# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
"from hubert.hubert_manager import HuBERTManager\n",
"hubert_manager = HuBERTManager()\n",
"from hubert.pre_kmeans_hubert import CustomHubert\n",
"from hubert.customtokenizer import CustomTokenizer\n",
"hubert_manager.make_sure_hubert_installed()\n",
"hubert_manager.make_sure_tokenizer_installed()\n",
"\n",
"# Load the HuBERT model\n",
"hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
"hubert_model.eval()\n",
"for param in hubert_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"# Load the CustomTokenizer model\n",
"hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n",
"\n",
"from bark.generation import load_codec_model\n",
"codec_model = load_codec_model(use_gpu=True)\n",
"codec_model.eval()\n",
"for param in codec_model.parameters():\n",
" param.requires_grad = False\n",
"\n",
"\n",
"def get_duration(wav, sr):\n",
" return wav.shape[1] / sr\n",
"\n",
"valid_lines_train = []\n",
"# convert wavs to semantic tokens\n",
"for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
" wav, sr = torchaudio.load(wav_path)\n",
" if not get_duration(wav, sr) > max_duration_sec:\n",
" valid_lines_train.append((wav_path, txt))\n",
" wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
"\n",
" semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
" semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
"\n",
" # save semantic tokens\n",
" os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
" semantic_tokens = semantic_tokens.cpu().numpy()\n",
"\n",
" # Extract discrete codes from EnCodec\n",
" with torch.no_grad():\n",
" encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
" codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
"\n",
" # move codes to cpu\n",
" codes = codes.cpu().numpy()\n",
"\n",
" # save tokens\n",
" np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
"\n",
"# rewrite train.txt with valid lines\n",
"with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
" for wav_path, txt in valid_lines_train:\n",
" wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
" f.write(f'{wav_path}|{txt}\\n')\n",
"\n",
"valid_lines_valid = []\n",
"for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
" wav, sr = torchaudio.load(wav_path)\n",
" if not get_duration(wav, sr) > max_duration_sec:\n",
" valid_lines_valid.append((wav_path, txt))\n",
" wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
"\n",
" semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
" semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
"\n",
" # save semantic tokens\n",
" os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
" semantic_tokens = semantic_tokens.cpu().numpy()\n",
" \n",
" # Extract discrete codes from EnCodec\n",
" with torch.no_grad():\n",
" encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
" codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
"\n",
" # move codes to cpu\n",
" codes = codes.cpu().numpy()\n",
"\n",
" # save tokens\n",
" np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
"\n",
"# rewrite valid.txt with valid lines\n",
"with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
" for wav_path, txt in valid_lines_valid:\n",
" wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
" f.write(f'{wav_path}|{txt}\\n')\n",
"\n",
"del hubert_model\n",
"del hubert_tokenizer\n",
"del codec_model\n",
"gc.collect()\n",
"torch.cuda.empty_cache()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, tokenizer = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if scale_lr:\n",
" learning_rate = (\n",
" learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
" )\n",
"\n",
"if use_8bit_adam:\n",
" try:\n",
" import bitsandbytes as bnb\n",
" except ImportError:\n",
" raise ImportError(\n",
" \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
" )\n",
"\n",
" optimizer_class = bnb.optim.AdamW8bit\n",
"else:\n",
" optimizer_class = torch.optim.AdamW"
]
},
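  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "When `scale_lr` is enabled, the base learning rate is multiplied by the effective global batch size (gradient accumulation × per-device batch size × number of processes). A quick check with hypothetical numbers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical numbers, only to illustrate the scaling rule in the cell above.\n",
    "example_base_lr = 1e-4\n",
    "example_grad_accum = 2\n",
    "example_batch_size = 8\n",
    "example_num_processes = 1\n",
    "example_scaled_lr = example_base_lr * example_grad_accum * example_batch_size * example_num_processes\n",
    "print(f'{example_scaled_lr:.6f}')  # 0.001600"
   ]
  },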
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"quantization_config=BitsAndBytesConfig(\n",
" load_in_4bit=bits == 4,\n",
" load_in_8bit=bits == 8,\n",
" llm_int8_threshold=6.0,\n",
" llm_int8_has_fp16_weight=False,\n",
" bnb_4bit_compute_dtype=compute_dtype,\n",
" bnb_4bit_use_double_quant=double_quant,\n",
" bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n",
")\n",
"\n",
"# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
"# if quantization_config.load_in_8bit:\n",
"# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
"# elif quantization_config.load_in_4bit:\n",
"# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
"\n",
"# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
"# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
"# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n",
"# else:\n",
"# modules_to_not_convert = llm_int8_skip_modules\n",
"\n",
"# if not isinstance(modules_to_not_convert, list):\n",
"# modules_to_not_convert = [modules_to_not_convert]\n",
"\n",
"# modules_to_not_convert.extend(keep_in_fp32_modules)\n",
"\n",
"# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
"\n",
"# if quantization_config.load_in_4bit and not supports_4bit:\n",
"# raise ValueError(\n",
"# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
"# \" make sure you have the latest version of `bitsandbytes` installed\"\n",
"# )\n",
" \n",
"# if len(modules_to_not_convert) == 0:\n",
"# modules_to_not_convert = None\n",
"\n",
"# model = replace_with_bnb_linear(\n",
"# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
"# )\n",
"\n",
"# # training in 8-bit is only available in 0.37.0+\n",
"# model._is_kbit_training_enabled = version.parse(\n",
"# importlib_metadata.version(\"bitsandbytes\")\n",
"# ) >= version.parse(\"0.37.0\")\n",
"\n",
"# model.config.quantization_config = quantization_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if bits == 4:\n",
" from accelerate.utils import CustomDtype\n",
" target_dtype = CustomDtype.INT4\n",
"elif bits == 8:\n",
" target_dtype = torch.int8\n",
"\n",
"if lora_dim > 0:\n",
" for param in model.parameters():\n",
" if param.ndim == 1:\n",
" # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
" param.data = param.data.to(torch.float32)\n",
" \n",
" class CastOutputToFloat(nn.Sequential):\n",
" def forward(self, x):\n",
" return super().forward(x).to(torch.float32)\n",
"\n",
" model.lm_head = CastOutputToFloat(model.lm_head)\n",
"\n",
" model = convert_linear_layer_to_lora(model, lora_module_name,\n",
" lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
" lora_dropout=lora_dropout)\n",
" if optimize_lora_params_only:\n",
" model = only_optimize_lora_parameters(model)"
]
},
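  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the idea behind the LoRA conversion above (a hypothetical toy class, not the project's `LinearLayer_LoRA`): the frozen base weight is kept as-is and a trainable low-rank update, scaled by `lora_scaling / lora_dim`, is added to its output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "class TinyLoRALinear(nn.Module):\n",
    "    # Hypothetical illustration only: a frozen linear layer plus a trainable\n",
    "    # low-rank update, scaled the same way as lora_scaling / lora_dim above.\n",
    "    def __init__(self, in_features, out_features, lora_dim=4, lora_scaling=1.0):\n",
    "        super().__init__()\n",
    "        self.base = nn.Linear(in_features, out_features, bias=False)\n",
    "        self.base.weight.requires_grad = False  # the pretrained weight stays frozen\n",
    "        self.lora_right = nn.Parameter(torch.randn(in_features, lora_dim) * 0.02)\n",
    "        self.lora_left = nn.Parameter(torch.zeros(lora_dim, out_features))\n",
    "        self.scaling = lora_scaling / lora_dim\n",
    "\n",
    "    def forward(self, x):\n",
    "        # base output plus the scaled low-rank correction (zero at initialization)\n",
    "        return self.base(x) + self.scaling * (x @ self.lora_right @ self.lora_left)\n",
    "\n",
    "tiny = TinyLoRALinear(16, 32)\n",
    "print(tiny(torch.randn(2, 16)).shape)  # torch.Size([2, 32])"
   ]
  },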
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params_to_optimize = (\n",
" param for param in model.parameters() if param.requires_grad\n",
" )\n",
"\n",
"optimizer = optimizer_class(\n",
" params_to_optimize,\n",
" lr=learning_rate,\n",
" betas=(adam_beta1, adam_beta2),\n",
" weight_decay=weight_decay,\n",
" eps=adam_epsilon,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_train = {\n",
" 'path': dataset_path,\n",
" 'tokenizer': tokenizer,\n",
" 'mode': 'train',\n",
"}\n",
"\n",
"opt_val = {\n",
" 'path': dataset_path,\n",
" 'tokenizer': tokenizer,\n",
" 'mode': 'valid',\n",
"}\n",
"\n",
"train_dataset = TtsDataset(opt_train)\n",
"validation_dataset = TtsDataset(opt_val)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(\n",
" train_dataset,\n",
" batch_size=train_batch_size,\n",
" collate_fn=TtsCollater(),\n",
")\n",
"\n",
"validation_dataloader = torch.utils.data.DataLoader(\n",
" validation_dataset,\n",
" batch_size=eval_batch_size,\n",
" collate_fn=TtsCollater(),\n",
")\n",
"\n",
"criterion = torch.nn.CrossEntropyLoss(ignore_index=SEMANTIC_PAD_TOKEN)\n",
"\n",
"# Scheduler and math around the number of training steps.\n",
"overrode_max_train_steps = False\n",
"num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
"if max_train_steps is None:\n",
" max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
" overrode_max_train_steps = True\n",
"\n",
"lr_scheduler = get_scheduler(\n",
" lr_scheduler_type,\n",
" optimizer=optimizer,\n",
" num_warmup_steps=lr_warmup_steps * grad_accum,\n",
" num_training_steps=max_train_steps * grad_accum,\n",
")"
]
},
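  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`max_train_steps` counts optimizer updates rather than raw batches, so it is derived from the dataloader length divided by the gradient-accumulation factor. A quick check with hypothetical numbers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "# Hypothetical numbers, only to illustrate the step math used above.\n",
    "example_num_batches = 125   # stand-in for len(train_dataloader)\n",
    "example_grad_accum = 2\n",
    "example_num_epochs = 10\n",
    "example_updates_per_epoch = math.ceil(example_num_batches / example_grad_accum)\n",
    "example_max_steps = example_num_epochs * example_updates_per_epoch\n",
    "print(example_updates_per_epoch, example_max_steps)  # 63 630"
   ]
  },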
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
" model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
")\n",
"accelerator.register_for_checkpointing(lr_scheduler)\n",
"\n",
"weight_dtype = torch.float32\n",
"if accelerator.mixed_precision == \"fp16\":\n",
" weight_dtype = torch.float16\n",
"elif accelerator.mixed_precision == \"bf16\":\n",
" weight_dtype = torch.bfloat16"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
"num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
"if overrode_max_train_steps:\n",
" max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
"# Afterwards we recalculate our number of training epochs\n",
"num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
"\n",
"# We need to initialize the trackers we use, and also store our configuration.\n",
"# The trackers initializes automatically on the main process.\n",
"if accelerator.is_main_process:\n",
" accelerator.init_trackers(\"bark_semantic\", config={})\n",
"\n",
"# Train!\n",
"total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
"logger.info(\"***** Running training *****\")\n",
"logger.info(f\" Num examples = {len(train_dataset)}\")\n",
"logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n",
"logger.info(f\" Num Epochs = {num_train_epochs}\")\n",
"logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n",
"logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
"logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n",
"logger.info(f\" Total optimization steps = {max_train_steps}\")\n",
"global_step = 0\n",
"first_epoch = 0\n",
"\n",
"if resume_from_checkpoint:\n",
" if resume_from_checkpoint != \"latest\":\n",
" path = os.path.basename(resume_from_checkpoint)\n",
" else:\n",
" # Get the most recent checkpoint\n",
" dirs = os.listdir(output_dir)\n",
" dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
" dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
" path = dirs[-1]\n",
" accelerator.print(f\"Resuming from checkpoint {path}\")\n",
" accelerator.load_state(os.path.join(output_dir, path))\n",
" global_step = int(path.split(\"-\")[1])\n",
"\n",
" resume_global_step = global_step * grad_accum\n",
" first_epoch = resume_global_step // num_update_steps_per_epoch\n",
" resume_step = resume_global_step % num_update_steps_per_epoch\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if accelerator.is_main_process:\n",
" model.eval()\n",
" validation_loss = 0.0\n",
" num_batches = 0\n",
" num_samples = 0\n",
" with torch.no_grad():\n",
" for val_step, val_batch in enumerate(validation_dataloader):\n",
" # Similar to training, process the validation batch\n",
" val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n",
" val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n",
" val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n",
"\n",
" # Forward pass for validation\n",
" val_logits = model(val_inputs, training=True)\n",
" val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n",
"\n",
" # Calculate the validation loss\n",
" val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
" validation_loss += val_loss.item()\n",
" num_batches += 1\n",
" num_samples += val_batch['input_ids'].size(0)\n",
"\n",
" average_validation_loss = validation_loss / num_batches\n",
" logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
" print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Only show the progress bar once on each machine.\n",
"progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
"progress_bar.set_description(\"Steps\")\n",
"\n",
"for epoch in range(first_epoch, num_train_epochs):\n",
" model.train()\n",
" for step, batch in enumerate(train_dataloader):\n",
" # Skip steps until we reach the resumed step\n",
" if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
" if step % grad_accum == 0:\n",
" progress_bar.update(1)\n",
" continue\n",
"\n",
" with accelerator.accumulate(model):\n",
" targets = batch['semantic_tokens'][:, 1:].contiguous()\n",
" \n",
" # Remove the last semantic token from the inputs since there is no target for it.\n",
" semantic_inputs = batch['semantic_tokens'][:, :-1]\n",
"\n",
" # Combine the text and semantic tokens and feed them into the model.\n",
" inputs = torch.cat([batch['input_ids'], semantic_inputs], dim=1)\n",
" logits = model(inputs, training=True)\n",
"\n",
" # We're only interested in the logits for the semantic tokens, so we ignore the logits for the input text tokens.\n",
" semantic_logits = logits[:, batch['input_ids'].size(1):].contiguous()\n",
"\n",
" # Compute the loss.\n",
" loss = criterion(semantic_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n",
"\n",
" accelerator.backward(loss)\n",
" if accelerator.sync_gradients:\n",
" params_to_clip = (\n",
" param for param in model.parameters() if param.requires_grad\n",
" )\n",
" accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
" optimizer.step()\n",
" lr_scheduler.step()\n",
" optimizer.zero_grad()\n",
"\n",
" # Checks if the accelerator has performed an optimization step behind the scenes\n",
" if accelerator.sync_gradients:\n",
" progress_bar.update(1)\n",
" global_step += 1\n",
"\n",
" if global_step % checkpointing_steps == 0:\n",
" if accelerator.is_main_process:\n",
" save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
" accelerator.save_state(save_path)\n",
" logger.info(f\"Saved state to {save_path}\")\n",
"\n",
" logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
" progress_bar.set_postfix(**logs)\n",
" accelerator.log(logs, step=global_step)\n",
"\n",
" if global_step >= max_train_steps:\n",
" break\n",
" \n",
" accelerator.wait_for_everyone()\n",
"\n",
"if accelerator.is_main_process:\n",
" if lora_dim > 0:\n",
" model = convert_lora_to_linear_layer(model)\n",
" # save model\n",
" accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
"\n",
" config = model.config.__dict__\n",
" # save config\n",
" with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
" json.dump(config, f, indent=2)\n",
"\n",
"accelerator.end_training()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if accelerator.is_main_process:\n",
" model.eval()\n",
" validation_loss = 0.0\n",
" num_batches = 0\n",
" num_samples = 0\n",
" with torch.no_grad():\n",
" for val_step, val_batch in enumerate(validation_dataloader):\n",
" # Similar to training, process the validation batch\n",
" val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n",
" val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n",
" val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n",
"\n",
" # Forward pass for validation\n",
" val_logits = model(val_inputs, training=True)\n",
" val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n",
"\n",
" # Calculate the validation loss\n",
" val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
" validation_loss += val_loss.item()\n",
" num_batches += 1\n",
" num_samples += val_batch['input_ids'].size(0)\n",
"\n",
" average_validation_loss = validation_loss / num_batches\n",
" logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
" print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@@ -11,7 +11,7 @@ class LinearLayer_LoRA(nn.Module):
weight,
lora_dim=0,
lora_scaling=1,
lora_droppout=0,
lora_dropout=0,
bias=None):
super(LinearLayer_LoRA, self).__init__()
self.weight = weight
@@ -29,8 +29,8 @@ class LinearLayer_LoRA(nn.Module):
self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
self.lora_scaling = lora_scaling / lora_dim
if lora_droppout > 0:
self.lora_dropout = nn.Dropout(lora_droppout)
if lora_dropout > 0:
self.lora_dropout = nn.Dropout(lora_dropout)
else:
self.lora_dropout = nn.Identity()
@@ -116,15 +116,15 @@ def convert_linear_layer_to_lora(model,
part_module_name,
lora_dim=0,
lora_scaling=1,
lora_droppout=0):
repalce_name = []
lora_dropout=0):
replace_name = []
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and part_module_name in name:
repalce_name.append(name)
for name in repalce_name:
replace_name.append(name)
for name in replace_name:
module = recursive_getattr(model, name)
tmp = LinearLayer_LoRA(
module.weight, lora_dim, lora_scaling, lora_droppout,
module.weight, lora_dim, lora_scaling, lora_dropout,
module.bias).to(module.weight.device).to(module.weight.dtype)
recursive_setattr(model, name, tmp)
return model
@@ -132,11 +132,11 @@ def convert_linear_layer_to_lora(model,
# convert the LoRA layer to linear layer
def convert_lora_to_linear_layer(model):
repalce_name = []
replace_name = []
for name, module in model.named_modules():
if isinstance(module, LinearLayer_LoRA):
repalce_name.append(name)
for name in repalce_name:
replace_name.append(name)
for name in replace_name:
module = recursive_getattr(model, name)
module.fuse_lora_weight()
return model