diff --git a/.gitignore b/.gitignore index f52000b..ff7a9ed 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,7 @@ __pycache__/ *.wav _temp/ models/ -output.npz \ No newline at end of file +wandb/ +*_output/ +output.npz +joe_biden_state_of_union/ \ No newline at end of file diff --git a/bark/generation.py b/bark/generation.py index 74fcc0d..69f166f 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -3,6 +3,7 @@ import gc import hashlib import os import re +import json from encodec import EncodecModel import funcy @@ -203,42 +204,81 @@ def _load_model(ckpt_path, device, use_small=False, model_type="text"): raise NotImplementedError() model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type model_info = REMOTE_MODEL_PATHS[model_key] - if ( - os.path.exists(ckpt_path) and - _md5(ckpt_path) != model_info["checksum"] - ): - logger.warning(f"found outdated {model_type} model, removing.") - os.remove(ckpt_path) + # if ( + # os.path.exists(ckpt_path) and + # _md5(ckpt_path) != model_info["checksum"] + # ): + # logger.warning(f"found outdated {model_type} model, removing.") + # os.remove(ckpt_path) if not os.path.exists(ckpt_path): logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.") _download(model_info["repo_id"], model_info["file_name"], ckpt_path) checkpoint = torch.load(ckpt_path, map_location=device) # this is a hack - model_args = checkpoint["model_args"] + # check if config.json is in the same directory as the checkpoint + # if so, load it + # otherwise, assume it's in the checkpoint + config_path = os.path.join(os.path.dirname(ckpt_path), "config.json") + if os.path.exists(config_path): + with open(config_path, "r") as f: + model_args = json.load(f) + else: + model_args = checkpoint["model_args"] if "input_vocab_size" not in model_args: model_args["input_vocab_size"] = model_args["vocab_size"] model_args["output_vocab_size"] = model_args["vocab_size"] del model_args["vocab_size"] - gptconf = ConfigClass(**checkpoint["model_args"]) + gptconf = ConfigClass(**model_args) model = ModelClass(gptconf) - state_dict = checkpoint["model"] + if checkpoint.get("model", None) is not None: + state_dict = checkpoint["model"] + else: + state_dict = checkpoint # fixup checkpoint unwanted_prefix = "_orig_mod." 
for k, v in list(state_dict.items()): if k.startswith(unwanted_prefix): state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k) + unwanted_suffixes = [ + "lora_right_weight", + "lora_left_weight", + "lora_right_bias", + "lora_left_bias", + ] + for k, v in list(state_dict.items()): + for suffix in unwanted_suffixes: + if k.endswith(suffix): + state_dict.pop(k) + # super hacky - should probably refactor this + if state_dict.get('lm_head.0.weight', None) is not None: + state_dict['lm_head.weight'] = state_dict.pop('lm_head.0.weight') + if state_dict.get('lm_heads.0.0.weight', None) is not None: + state_dict['lm_heads.0.weight'] = state_dict.pop('lm_heads.0.0.weight') + if state_dict.get('lm_heads.1.0.weight', None) is not None: + state_dict['lm_heads.1.weight'] = state_dict.pop('lm_heads.1.0.weight') + if state_dict.get('lm_heads.2.0.weight', None) is not None: + state_dict['lm_heads.2.weight'] = state_dict.pop('lm_heads.2.0.weight') + if state_dict.get('lm_heads.3.0.weight', None) is not None: + state_dict['lm_heads.3.weight'] = state_dict.pop('lm_heads.3.0.weight') + if state_dict.get('lm_heads.4.0.weight', None) is not None: + state_dict['lm_heads.4.weight'] = state_dict.pop('lm_heads.4.0.weight') + if state_dict.get('lm_heads.5.0.weight', None) is not None: + state_dict['lm_heads.5.weight'] = state_dict.pop('lm_heads.5.0.weight') + if state_dict.get('lm_heads.6.0.weight', None) is not None: + state_dict['lm_heads.6.weight'] = state_dict.pop('lm_heads.6.0.weight') extra_keys = set(state_dict.keys()) - set(model.state_dict().keys()) extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")]) missing_keys = set(model.state_dict().keys()) - set(state_dict.keys()) missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")]) if len(extra_keys) != 0: - raise ValueError(f"extra keys found: {extra_keys}") + print(f"extra keys found: {extra_keys}") if len(missing_keys) != 0: raise ValueError(f"missing keys: {missing_keys}") model.load_state_dict(state_dict, strict=False) n_params = model.get_num_params() - val_loss = checkpoint["best_val_loss"].item() - logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") + if checkpoint.get("best_val_loss", None) is not None: + val_loss = checkpoint["best_val_loss"].item() + logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss") model.eval() model.to(device) del checkpoint, state_dict @@ -273,8 +313,11 @@ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="te models_devices[model_key] = device device = "cpu" if model_key not in models or force_reload: - ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path) - clean_models(model_key=model_key) + if path.endswith(".ckpt") or path.endswith(".pt") or path.endswith(".bin"): + ckpt_path = path + else: + ckpt_path = _get_ckpt_path(model_type, use_small=use_small, path=path) + # clean_models(model_key=model_key) model = _load_model_f(ckpt_path, device) models[model_key] = model if model_type == "text": @@ -306,10 +349,13 @@ def load_codec_model(use_gpu=True, force_reload=False): def preload_models( text_use_gpu=True, text_use_small=False, + text_model_path=None, coarse_use_gpu=True, coarse_use_small=False, + coarse_model_path=None, fine_use_gpu=True, fine_use_small=False, + fine_model_path=None, codec_use_gpu=True, force_reload=False, path=None, @@ -320,17 +366,17 @@ def preload_models( ): logger.warning("No GPU being used. 
Careful, inference might be very slow!") _ = load_model( - model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path + model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload, path=path if text_model_path is None else text_model_path ) _ = load_model( model_type="coarse", use_gpu=coarse_use_gpu, use_small=coarse_use_small, force_reload=force_reload, - path=path, + path=path if coarse_model_path is None else coarse_model_path, ) _ = load_model( - model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path + model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload, path=path if fine_model_path is None else fine_model_path ) _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload) diff --git a/bark/model.py b/bark/model.py index b87e534..843b8ad 100644 --- a/bark/model.py +++ b/bark/model.py @@ -9,7 +9,7 @@ import torch import torch.nn as nn from torch.nn import functional as F from einops import rearrange, repeat, reduce -SEMANTIC_PAD_TOKEN = 10_000 + class LayerNorm(nn.Module): """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ @@ -167,7 +167,7 @@ class GPT(nn.Module): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None): + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, training=False): device = idx.device b, t = idx.size() if past_kv is not None: @@ -215,19 +215,9 @@ class GPT(nn.Module): x = self.transformer.ln_f(x) - if labels is not None: + if training: logits = self.lm_head(x) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - # Flatten the tokens - loss_fct = nn.CrossEntropyLoss() - shift_logits = shift_logits.view(-1, self.config.output_vocab_size) - shift_labels = shift_labels.view(-1) - # Enable model parallelism - shift_labels = shift_labels.to(shift_logits.device) - loss = loss_fct(shift_logits, shift_labels) - return logits, loss + return logits # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim diff --git a/datasets/.tmp b/datasets/.tmp new file mode 100644 index 0000000..e69de29 diff --git a/generate.ipynb b/generate.ipynb index a135af3..70b4f8f 100644 --- a/generate.ipynb +++ b/generate.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": null, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -20,7 +20,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ diff --git a/test_models.ipynb b/test_models.ipynb new file mode 100644 index 0000000..2fbb88a --- /dev/null +++ b/test_models.ipynb @@ -0,0 +1,183 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from bark.api import generate_audio\n", + "from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "semantic_path = 
\"E:/Python/bark-with-voice-clone/semantic_output/pytorch_model.bin\"\n", + "coarse_path = \"E:/Python/bark-with-voice-clone/coarse_output/pytorch_model.bin\"\n", + "fine_path = \"E:/Python/bark-with-voice-clone/fine_output/pytorch_model.bin\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "preload_models(\n", + " text_use_gpu=True,\n", + " text_use_small=False,\n", + " text_model_path=semantic_path,\n", + " coarse_use_gpu=True,\n", + " coarse_use_small=False,\n", + " coarse_model_path=coarse_path,\n", + " fine_use_gpu=True,\n", + " fine_use_small=False,\n", + " fine_model_path=fine_path,\n", + " codec_use_gpu=True,\n", + " force_reload=False,\n", + " path=\"models\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# simple generation\n", + "text_prompt = \"I am Joe Biden... and this is the finetuned semantic, coarse and fine model! [laughs] A lot better than the original!\"\n", + "audio_array = generate_audio(text_prompt, history_prompt=None, text_temp=0.7, waveform_temp=0.7)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from IPython.display import Audio\n", + "# play audio\n", + "Audio(audio_array, rate=SAMPLE_RATE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.io.wavfile import write as write_wav\n", + "# save audio\n", + "filepath = \"output/audio.wav\" # change this to your desired output path\n", + "write_wav(filepath, SAMPLE_RATE, audio_array)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def generate_with_settings(text_prompt, semantic_temp=0.7, semantic_top_k=50, semantic_top_p=0.95, coarse_temp=0.7, coarse_top_k=50, coarse_top_p=0.95, fine_temp=0.5, voice_name=None, use_semantic_history_prompt=True, use_coarse_history_prompt=True, use_fine_history_prompt=True, output_full=False):\n", + " # generation with more control\n", + " x_semantic = generate_text_semantic(\n", + " text_prompt,\n", + " history_prompt=voice_name if use_semantic_history_prompt else None,\n", + " temp=semantic_temp,\n", + " top_k=semantic_top_k,\n", + " top_p=semantic_top_p,\n", + " )\n", + "\n", + " x_coarse_gen = generate_coarse(\n", + " x_semantic,\n", + " history_prompt=voice_name if use_coarse_history_prompt else None,\n", + " temp=coarse_temp,\n", + " top_k=coarse_top_k,\n", + " top_p=coarse_top_p,\n", + " )\n", + " x_fine_gen = generate_fine(\n", + " x_coarse_gen,\n", + " history_prompt=voice_name if use_fine_history_prompt else None,\n", + " temp=fine_temp,\n", + " )\n", + "\n", + " if output_full:\n", + " full_generation = {\n", + " 'semantic_prompt': x_semantic,\n", + " 'coarse_prompt': x_coarse_gen,\n", + " 'fine_prompt': x_fine_gen,\n", + " }\n", + " return full_generation, codec_decode(x_fine_gen)\n", + " return codec_decode(x_fine_gen)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_prompt = \"I am Joe Biden... and this is the finetuned semantic, coarse and fine model! 
[laughs] A lot better than the original!\"\n", + "\n", + "audio_array = generate_with_settings(\n", + " text_prompt,\n", + " semantic_temp=0.7,\n", + " semantic_top_k=50,\n", + " semantic_top_p=0.99,\n", + " coarse_temp=0.7,\n", + " coarse_top_k=50,\n", + " coarse_top_p=0.99,\n", + " fine_temp=0.5,\n", + " voice_name=None,\n", + " use_semantic_history_prompt=True,\n", + " use_coarse_history_prompt=True,\n", + " use_fine_history_prompt=True,\n", + " output_full=False\n", + ")\n", + "\n", + "from IPython.display import Audio\n", + "# play audio\n", + "Audio(audio_array, rate=SAMPLE_RATE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from scipy.io.wavfile import write as write_wav\n", + "# save audio\n", + "filepath = \"output/audio.wav\" # change this to your desired output path\n", + "write_wav(filepath, SAMPLE_RATE, audio_array)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/train_coarse.ipynb b/train_coarse.ipynb new file mode 100644 index 0000000..e53e0c1 --- /dev/null +++ b/train_coarse.ipynb @@ -0,0 +1,936 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import os\n", + "import re\n", + "import gc\n", + "import math\n", + "import json\n", + "import hashlib\n", + "import numpy as np\n", + "import logging\n", + "import torchaudio\n", + "from tqdm.auto import tqdm\n", + "import torch.nn.functional as F\n", + "from encodec.utils import convert_audio\n", + "from accelerate import Accelerator\n", + "from accelerate.utils import set_seed\n", + "from transformers import BertTokenizer\n", + "from huggingface_hub import hf_hub_download\n", + "from packaging import version\n", + "from diffusers.optimization import get_scheduler\n", + "\n", + "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n", + "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n", + "from bark.model import GPTConfig, GPT\n", + "from bark.model_fine import FineGPT, FineGPTConfig" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training Args" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_batch_size = 8\n", + "eval_batch_size = 8\n", + "grad_accum = 1\n", + "ckpt_path = 'models/coarse_2.pt'\n", + "model_type = \"coarse\"\n", + "dataset_path = 'datasets/joe_biden_state_of_union/'\n", + "logging_dir = 'logs/'\n", + "log_with = 'wandb'\n", + "hubert_path = 'data/models/hubert/hubert.pt'\n", + "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n", + "\n", + "output_dir = 'coarse_output/'\n", + "resume_from_checkpoint = None\n", + "\n", + "checkpointing_steps = 1000\n", + "\n", + "mixed_precision = 'bf16'\n", + "bits = 16 #4 4 
and 8 bit are a work in progress\n", + "compute_dtype = torch.bfloat16\n", + "double_quant = True\n", + "quant_type = 'nf4'\n", + "\n", + "lora_dim = 64\n", + "lora_scaling = 1\n", + "lora_dropout = 0.1\n", + "lora_module_name = 'transformer.h'\n", + "optimize_lora_params_only = True\n", + "\n", + "learning_rate = 1e-4\n", + "scale_lr = False\n", + "use_8bit_adam = False\n", + "adam_beta1 = 0.9\n", + "adam_beta2 = 0.999\n", + "adam_epsilon = 1e-8\n", + "weight_decay = 0.01\n", + "\n", + "llm_int8_skip_modules = None\n", + "keep_in_fp32_modules = ['lm_head']\n", + "\n", + "lr_scheduler_type = 'linear'\n", + "lr_warmup_steps = 200\n", + "num_train_epochs = 20\n", + "max_train_steps = None\n", + "max_grad_norm = 1.0\n", + "\n", + "semantic_cross_entropy_loss_weight = 0\n", + "\n", + "seed = 741" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT_WINDOW_SIZE = 1024\n", + "\n", + "MAX_SEMANTIC_LEN = 256\n", + "\n", + "SEMANTIC_RATE_HZ = 49.9\n", + "SEMANTIC_VOCAB_SIZE = 10_000\n", + "\n", + "TEXT_ENCODING_OFFSET = 10_048\n", + "SEMANTIC_PAD_TOKEN = 10_000\n", + "TEXT_PAD_TOKEN = 129_595\n", + "SEMANTIC_INFER_TOKEN = 129_599\n", + "\n", + "MAX_COARSE_LEN = 768\n", + "\n", + "SAMPLE_RATE = 24_000\n", + "CHANNELS = 1\n", + "\n", + "COARSE_SEMANTIC_PAD_TOKEN = 12_048\n", + "COARSE_INFER_TOKEN = 12_050\n", + "\n", + "CODEBOOK_SIZE = 1024\n", + "N_COARSE_CODEBOOKS = 2\n", + "N_FINE_CODEBOOKS = 8\n", + "COARSE_RATE_HZ = 75\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n", + "\n", + "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n", + "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n", + "\n", + "\n", + "def _clear_cuda_cache():\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + " torch.cuda.synchronize()\n", + "\n", + "\n", + "def _md5(fname):\n", + " hash_md5 = hashlib.md5()\n", + " with open(fname, \"rb\") as f:\n", + " for chunk in iter(lambda: f.read(4096), b\"\"):\n", + " hash_md5.update(chunk)\n", + " return hash_md5.hexdigest()\n", + "\n", + "\n", + "def _download(from_hf_path, file_name, to_local_path):\n", + " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n", + " path = '/'.join(to_local_path.split(\"/\")[:-1])\n", + " os.makedirs(path, exist_ok=True)\n", + " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n", + " os.replace(os.path.join(path, file_name), to_local_path)\n", + "\n", + "\n", + "def _tokenize(tokenizer, text):\n", + " return tokenizer.encode(text, add_special_tokens=False)\n", + "\n", + "\n", + "def _detokenize(tokenizer, enc_text):\n", + " return tokenizer.decode(enc_text)\n", + "\n", + "\n", + "def _normalize_whitespace(text):\n", + " return re.sub(r\"\\s+\", \" \", text).strip()\n", + "\n", + "\n", + "REMOTE_MODEL_PATHS = {\n", + " \"text_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"text.pt\",\n", + " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n", + " },\n", + " \"coarse_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"coarse.pt\",\n", + " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n", + " },\n", + " \"fine_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"fine.pt\",\n", + 
" \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n", + " },\n", + " \"text\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"text_2.pt\",\n", + " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n", + " },\n", + " \"coarse\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"coarse_2.pt\",\n", + " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n", + " },\n", + " \"fine\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"fine_2.pt\",\n", + " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n", + " },\n", + "}\n", + "\n", + "\n", + "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n", + " if model_type == \"text\":\n", + " ConfigClass = GPTConfig\n", + " ModelClass = GPT\n", + " elif model_type == \"coarse\":\n", + " ConfigClass = GPTConfig\n", + " ModelClass = GPT\n", + " elif model_type == \"fine\":\n", + " ConfigClass = FineGPTConfig\n", + " ModelClass = FineGPT\n", + " else:\n", + " raise NotImplementedError()\n", + " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n", + " model_info = REMOTE_MODEL_PATHS[model_key]\n", + " if ckpt_path in [None, '']:\n", + " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n", + " if not os.path.exists(ckpt_path):\n", + " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n", + " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n", + " checkpoint = torch.load(ckpt_path, map_location=device)\n", + " # this is a hack\n", + " model_args = checkpoint[\"model_args\"]\n", + " if \"input_vocab_size\" not in model_args:\n", + " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n", + " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n", + " del model_args[\"vocab_size\"]\n", + " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n", + " model = ModelClass(gptconf)\n", + " state_dict = checkpoint[\"model\"]\n", + " # fixup checkpoint\n", + " unwanted_prefix = \"_orig_mod.\"\n", + " for k, v in list(state_dict.items()):\n", + " if k.startswith(unwanted_prefix):\n", + " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n", + " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n", + " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n", + " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n", + " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n", + " if len(extra_keys) != 0:\n", + " raise ValueError(f\"extra keys found: {extra_keys}\")\n", + " if len(missing_keys) != 0:\n", + " raise ValueError(f\"missing keys: {missing_keys}\")\n", + " model.load_state_dict(state_dict, strict=False)\n", + " n_params = model.get_num_params()\n", + " val_loss = checkpoint[\"best_val_loss\"].item()\n", + " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n", + " del checkpoint, state_dict\n", + " _clear_cuda_cache()\n", + " if model_type == \"text\":\n", + " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n", + " return model, tokenizer\n", + " return model\n", + "\n", + "\n", + "def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n", + " assert len(arr.shape) == 2\n", + " arr = arr.copy()\n", + " if offset_size is not None:\n", + " for n in range(1, arr.shape[0]):\n", + " arr[n, :] += offset_size * n\n", + " flat_arr = arr.ravel(\"F\")\n", + " return flat_arr\n", + "\n", + "\n", + "def 
load_filepaths_and_text(filename, split=\"|\"):\n", + " with open(filename, encoding='utf-8') as f:\n", + " filepaths_and_text = [line.strip().split(split) for line in f]\n", + " base = os.path.dirname(filename)\n", + " for j in range(len(filepaths_and_text)):\n", + " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n", + " return filepaths_and_text\n", + "\n", + "\n", + "class TtsDataset(torch.utils.data.Dataset):\n", + " def __init__(self, opt):\n", + " self.path = os.path.dirname(opt['path'])\n", + " self.mode = opt['mode']\n", + " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '.txt'))\n", + "\n", + " def __getitem__(self, index):\n", + " audiopath_and_text = self.audiopaths_and_text[index]\n", + " audiopath = audiopath_and_text[0]\n", + "\n", + " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n", + " semantic_tokens = tokens['semantic']\n", + " coarse_tokens = _flatten_codebooks(tokens['coarse'], offset_size=CODEBOOK_SIZE) + SEMANTIC_VOCAB_SIZE\n", + "\n", + " return torch.from_numpy(semantic_tokens), torch.from_numpy(coarse_tokens)\n", + "\n", + " def __len__(self):\n", + " return len(self.audiopaths_and_text)\n", + "\n", + "\n", + "class TtsCollater():\n", + " def __init__(self):\n", + " pass\n", + " def __call__(self, batch):\n", + " max_semantic_len = MAX_SEMANTIC_LEN\n", + " max_coarse_len = MAX_COARSE_LEN\n", + " semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS\n", + " semantic_tokens = []\n", + " coarse_tokens = []\n", + "\n", + " for b in batch:\n", + " semantic_tokens_, coarse_tokens_ = b\n", + " start_idx = None\n", + " if len(semantic_tokens_) > max_semantic_len:\n", + " start_idx = np.random.randint(0, len(semantic_tokens_) - max_semantic_len + 1)\n", + " semantic_tokens_ = semantic_tokens_[start_idx:start_idx+max_semantic_len]\n", + " semantic_tokens_ = F.pad(semantic_tokens_, (0, max_semantic_len-len(semantic_tokens_)), value=COARSE_SEMANTIC_PAD_TOKEN)\n", + " semantic_tokens_ = torch.cat([semantic_tokens_, torch.tensor([COARSE_INFER_TOKEN])])\n", + " semantic_tokens.append(semantic_tokens_)\n", + "\n", + " if start_idx is not None:\n", + " start_idx_coarse = int(start_idx * semantic_to_coarse_ratio) \n", + " coarse_tokens_ = coarse_tokens_[start_idx_coarse:start_idx_coarse+max_coarse_len]\n", + " coarse_tokens_ = F.pad(coarse_tokens_, (0, max_coarse_len-len(coarse_tokens_)), value=COARSE_SEMANTIC_PAD_TOKEN)\n", + " coarse_tokens.append(coarse_tokens_)\n", + "\n", + " return {\n", + " 'semantic_tokens': torch.stack(semantic_tokens).contiguous(),\n", + " 'coarse_tokens': torch.stack(coarse_tokens).contiguous()\n", + " }\n", + " \n", + "\n", + "accelerator = Accelerator(\n", + " gradient_accumulation_steps=grad_accum,\n", + " mixed_precision=mixed_precision,\n", + " log_with=log_with,\n", + " logging_dir=logging_dir,\n", + ")\n", + "device = accelerator.device\n", + "\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "set_seed(seed)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup Dataset (only need to do this once)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# max_duration_sec = 15.12 # the maximum allowed duration in seconds\n", + "\n", + "# path = dataset_path\n", + "\n", + "# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n", + "# from hubert.hubert_manager import HuBERTManager\n", 
+ "# hubert_manager = HuBERTManager()\n", + "# from hubert.pre_kmeans_hubert import CustomHubert\n", + "# from hubert.customtokenizer import CustomTokenizer\n", + "# hubert_manager.make_sure_hubert_installed()\n", + "# hubert_manager.make_sure_tokenizer_installed()\n", + "\n", + "# # Load the HuBERT model\n", + "# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n", + "# hubert_model.eval()\n", + "# for param in hubert_model.parameters():\n", + "# param.requires_grad = False\n", + "\n", + "# # Load the CustomTokenizer model\n", + "# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n", + "\n", + "# from bark.generation import load_codec_model\n", + "# codec_model = load_codec_model(use_gpu=True)\n", + "# codec_model.eval()\n", + "# for param in codec_model.parameters():\n", + "# param.requires_grad = False\n", + "\n", + "\n", + "# def get_duration(wav, sr):\n", + "# return wav.shape[1] / sr\n", + "\n", + "# valid_lines_train = []\n", + "# # convert wavs to semantic tokens\n", + "# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n", + "# wav, sr = torchaudio.load(wav_path)\n", + "# if not get_duration(wav, sr) > max_duration_sec:\n", + "# valid_lines_train.append((wav_path, txt))\n", + "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", + "\n", + "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", + "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", + "\n", + "# # save semantic tokens\n", + "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", + "# semantic_tokens = semantic_tokens.cpu().numpy()\n", + "\n", + "# # Extract discrete codes from EnCodec\n", + "# with torch.no_grad():\n", + "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", + "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", + "\n", + "# # move codes to cpu\n", + "# codes = codes.cpu().numpy()\n", + "\n", + "# # save tokens\n", + "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", + "\n", + "# # rewrite train.txt with valid lines\n", + "# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n", + "# for wav_path, txt in valid_lines_train:\n", + "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", + "# f.write(f'{wav_path}|{txt}\\n')\n", + "\n", + "# valid_lines_valid = []\n", + "# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n", + "# wav, sr = torchaudio.load(wav_path)\n", + "# if not get_duration(wav, sr) > max_duration_sec:\n", + "# valid_lines_valid.append((wav_path, txt))\n", + "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", + "\n", + "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", + "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", + "\n", + "# # save semantic tokens\n", + "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", + "# semantic_tokens = semantic_tokens.cpu().numpy()\n", + " \n", + "# # Extract discrete codes from EnCodec\n", + "# with torch.no_grad():\n", + "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", + "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", + "\n", + "# # move codes to cpu\n", + "# codes = codes.cpu().numpy()\n", + "\n", + "# # save 
tokens\n", + "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", + "\n", + "# # rewrite valid.txt with valid lines\n", + "# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n", + "# for wav_path, txt in valid_lines_valid:\n", + "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", + "# f.write(f'{wav_path}|{txt}\\n')\n", + "\n", + "# del hubert_model\n", + "# del hubert_tokenizer\n", + "# del codec_model\n", + "# gc.collect()\n", + "# torch.cuda.empty_cache()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if scale_lr:\n", + " learning_rate = (\n", + " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n", + " )\n", + "\n", + "if use_8bit_adam:\n", + " try:\n", + " import bitsandbytes as bnb\n", + " except ImportError:\n", + " raise ImportError(\n", + " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n", + " )\n", + "\n", + " optimizer_class = bnb.optim.AdamW8bit\n", + "else:\n", + " optimizer_class = torch.optim.AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=bits == 4,\n", + " load_in_8bit=bits == 8,\n", + " llm_int8_threshold=6.0,\n", + " llm_int8_has_fp16_weight=False,\n", + " bnb_4bit_compute_dtype=compute_dtype,\n", + " bnb_4bit_use_double_quant=double_quant,\n", + " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n", + ")\n", + "\n", + "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n", + "# if quantization_config.load_in_8bit:\n", + "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n", + "# elif quantization_config.load_in_4bit:\n", + "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n", + "\n", + "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n", + "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n", + "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n", + "# else:\n", + "# modules_to_not_convert = llm_int8_skip_modules\n", + "\n", + "# if not isinstance(modules_to_not_convert, list):\n", + "# modules_to_not_convert = [modules_to_not_convert]\n", + "\n", + "# modules_to_not_convert.extend(keep_in_fp32_modules)\n", + "\n", + "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n", + "\n", + "# if quantization_config.load_in_4bit and not supports_4bit:\n", + "# raise ValueError(\n", + "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n", + "# \" make sure you have the latest version of `bitsandbytes` installed\"\n", + "# )\n", + " \n", + "# if len(modules_to_not_convert) == 0:\n", + "# modules_to_not_convert = None\n", + "\n", + "# model = replace_with_bnb_linear(\n", + "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n", + "# )\n", + 
"\n", + "# # training in 8-bit is only available in 0.37.0+\n", + "# model._is_kbit_training_enabled = version.parse(\n", + "# importlib_metadata.version(\"bitsandbytes\")\n", + "# ) >= version.parse(\"0.37.0\")\n", + "\n", + "# model.config.quantization_config = quantization_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if bits == 4:\n", + " from accelerate.utils import CustomDtype\n", + " target_dtype = CustomDtype.INT4\n", + "elif bits == 8:\n", + " target_dtype = torch.int8\n", + "\n", + "if lora_dim > 0:\n", + " for param in model.parameters():\n", + " if param.ndim == 1:\n", + " # cast the small parameters (e.g. layernorm) to fp32 for stability\n", + " param.data = param.data.to(torch.float32)\n", + " \n", + " class CastOutputToFloat(nn.Sequential):\n", + " def forward(self, x):\n", + " return super().forward(x).to(torch.float32)\n", + "\n", + " model.lm_head = CastOutputToFloat(model.lm_head)\n", + "\n", + " model = convert_linear_layer_to_lora(model, lora_module_name,\n", + " lora_dim=lora_dim, lora_scaling=lora_scaling,\n", + " lora_dropout=lora_dropout)\n", + " if optimize_lora_params_only:\n", + " model = only_optimize_lora_parameters(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params_to_optimize = (\n", + " param for param in model.parameters() if param.requires_grad\n", + " )\n", + "\n", + "optimizer = optimizer_class(\n", + " params_to_optimize,\n", + " lr=learning_rate,\n", + " betas=(adam_beta1, adam_beta2),\n", + " weight_decay=weight_decay,\n", + " eps=adam_epsilon,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "opt_train = {\n", + " 'path': dataset_path,\n", + " 'mode': 'train',\n", + "}\n", + "\n", + "opt_val = {\n", + " 'path': dataset_path,\n", + " 'mode': 'valid',\n", + "}\n", + "\n", + "train_dataset = TtsDataset(opt_train)\n", + "validation_dataset = TtsDataset(opt_val)\n", + "\n", + "train_dataloader = torch.utils.data.DataLoader(\n", + " train_dataset,\n", + " batch_size=train_batch_size,\n", + " collate_fn=TtsCollater(),\n", + ")\n", + "\n", + "validation_dataloader = torch.utils.data.DataLoader(\n", + " validation_dataset,\n", + " batch_size=eval_batch_size,\n", + " collate_fn=TtsCollater(),\n", + ")\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss(ignore_index=COARSE_SEMANTIC_PAD_TOKEN)\n", + "\n", + "# Scheduler and math around the number of training steps.\n", + "overrode_max_train_steps = False\n", + "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", + "if max_train_steps is None:\n", + " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", + " overrode_max_train_steps = True\n", + "\n", + "lr_scheduler = get_scheduler(\n", + " lr_scheduler_type,\n", + " optimizer=optimizer,\n", + " num_warmup_steps=lr_warmup_steps * grad_accum,\n", + " num_training_steps=max_train_steps * grad_accum,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n", + " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n", + ")\n", + "accelerator.register_for_checkpointing(lr_scheduler)\n", + "\n", + "weight_dtype = torch.float32\n", + "if accelerator.mixed_precision == \"fp16\":\n", + " weight_dtype = torch.float16\n", 
+ "elif accelerator.mixed_precision == \"bf16\":\n", + " weight_dtype = torch.bfloat16" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n", + "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", + "if overrode_max_train_steps:\n", + " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", + "# Afterwards we recalculate our number of training epochs\n", + "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n", + "\n", + "# We need to initialize the trackers we use, and also store our configuration.\n", + "# The trackers initializes automatically on the main process.\n", + "if accelerator.is_main_process:\n", + " accelerator.init_trackers(\"bark_coarse\", config={})\n", + "\n", + "# Train!\n", + "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n", + "logger.info(\"***** Running training *****\")\n", + "logger.info(f\" Num examples = {len(train_dataset)}\")\n", + "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n", + "logger.info(f\" Num Epochs = {num_train_epochs}\")\n", + "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n", + "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n", + "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n", + "logger.info(f\" Total optimization steps = {max_train_steps}\")\n", + "global_step = 0\n", + "first_epoch = 0\n", + "\n", + "if resume_from_checkpoint:\n", + " if resume_from_checkpoint != \"latest\":\n", + " path = os.path.basename(resume_from_checkpoint)\n", + " else:\n", + " # Get the most recent checkpoint\n", + " dirs = os.listdir(output_dir)\n", + " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n", + " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n", + " path = dirs[-1]\n", + " accelerator.print(f\"Resuming from checkpoint {path}\")\n", + " accelerator.load_state(os.path.join(output_dir, path))\n", + " global_step = int(path.split(\"-\")[1])\n", + "\n", + " resume_global_step = global_step * grad_accum\n", + " first_epoch = resume_global_step // num_update_steps_per_epoch\n", + " resume_step = resume_global_step % num_update_steps_per_epoch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if accelerator.is_main_process:\n", + " model.eval()\n", + " validation_loss = 0.0\n", + " num_batches = 0\n", + " num_samples = 0\n", + " with torch.no_grad():\n", + " for val_step, val_batch in enumerate(validation_dataloader):\n", + " # Similar to training, process the validation batch\n", + " val_targets = val_batch['coarse_tokens'][:, 1:].contiguous()\n", + " val_coarse_inputs = val_batch['coarse_tokens'][:, :-1]\n", + " val_inputs = torch.cat([val_batch['semantic_tokens'], val_coarse_inputs], dim=1)\n", + "\n", + " # Forward pass for validation\n", + " val_logits = model(val_inputs, training=True)\n", + " val_coarse_logits = val_logits[:, val_batch['semantic_tokens'].size(1):].contiguous()\n", + "\n", + " # Calculate the validation loss\n", + " val_loss = criterion(val_coarse_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n", + " validation_loss += val_loss.item()\n", + " num_batches += 1\n", + " num_samples += val_batch['semantic_tokens'].size(0)\n", + "\n", + " 
average_validation_loss = validation_loss / num_batches\n", + " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", + " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Only show the progress bar once on each machine.\n", + "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n", + "progress_bar.set_description(\"Steps\")\n", + "\n", + "for epoch in range(first_epoch, num_train_epochs):\n", + " model.train()\n", + " for step, batch in enumerate(train_dataloader):\n", + " # Skip steps until we reach the resumed step\n", + " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n", + " if step % grad_accum == 0:\n", + " progress_bar.update(1)\n", + " continue\n", + "\n", + " with accelerator.accumulate(model):\n", + " targets = batch['coarse_tokens'][:, 1:].contiguous()\n", + " \n", + " # Remove the last coarse token from the inputs since there is no target for it.\n", + " coarse_inputs = batch['coarse_tokens'][:, :-1]\n", + "\n", + " # Combine the semantic tokens and coarse tokens and feed them into the model.\n", + " inputs = torch.cat([batch['semantic_tokens'], coarse_inputs], dim=1)\n", + " logits = model(inputs, training=True)\n", + "\n", + " # We're only interested in the logits for the coarse tokens, so we ignore the logits for the input text tokens.\n", + " coarse_logits = logits[:, batch['semantic_tokens'].size(1):].contiguous()\n", + "\n", + " # Compute the loss.\n", + " loss = criterion(coarse_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n", + "\n", + " if semantic_cross_entropy_loss_weight > 0 and semantic_cross_entropy_loss_weight is not None:\n", + " semantic_logits = logits[:, :batch['semantic_tokens'].size(1)].contiguous()\n", + " semantic_loss = criterion(\n", + " semantic_logits.view(-1, model.config.input_vocab_size),\n", + " batch['semantic_tokens'].view(-1),\n", + " )\n", + " num_semantic_logits = semantic_logits.size(1)\n", + " num_coarse_logits = coarse_logits.size(1)\n", + " loss = (\n", + " semantic_loss * num_semantic_logits * semantic_cross_entropy_loss_weight +\n", + " loss * num_coarse_logits\n", + " ) / (num_semantic_logits + num_coarse_logits)\n", + "\n", + " accelerator.backward(loss)\n", + " if accelerator.sync_gradients:\n", + " params_to_clip = (\n", + " param for param in model.parameters() if param.requires_grad\n", + " )\n", + " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + " optimizer.zero_grad()\n", + "\n", + " # Checks if the accelerator has performed an optimization step behind the scenes\n", + " if accelerator.sync_gradients:\n", + " progress_bar.update(1)\n", + " global_step += 1\n", + "\n", + " if global_step % checkpointing_steps == 0:\n", + " if accelerator.is_main_process:\n", + " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n", + " accelerator.save_state(save_path)\n", + " logger.info(f\"Saved state to {save_path}\")\n", + "\n", + " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n", + " progress_bar.set_postfix(**logs)\n", + " accelerator.log(logs, step=global_step)\n", + "\n", + 
" if global_step >= max_train_steps:\n", + " break\n", + " \n", + " accelerator.wait_for_everyone()\n", + "\n", + "if accelerator.is_main_process:\n", + " if lora_dim > 0:\n", + " model = convert_lora_to_linear_layer(model)\n", + " # save model\n", + " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n", + " \n", + " config = model.config.__dict__\n", + " # save config\n", + " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n", + " json.dump(config, f, indent=2)\n", + "\n", + "accelerator.end_training()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if accelerator.is_main_process:\n", + " model.eval()\n", + " validation_loss = 0.0\n", + " num_batches = 0\n", + " num_samples = 0\n", + " with torch.no_grad():\n", + " for val_step, val_batch in enumerate(validation_dataloader):\n", + " # Similar to training, process the validation batch\n", + " val_targets = val_batch['coarse_tokens'][:, 1:].contiguous()\n", + " val_coarse_inputs = val_batch['coarse_tokens'][:, :-1]\n", + " val_inputs = torch.cat([val_batch['semantic_tokens'], val_coarse_inputs], dim=1)\n", + "\n", + " # Forward pass for validation\n", + " val_logits = model(val_inputs, training=True)\n", + " val_coarse_logits = val_logits[:, val_batch['semantic_tokens'].size(1):].contiguous()\n", + "\n", + " # Calculate the validation loss\n", + " val_loss = criterion(val_coarse_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n", + " validation_loss += val_loss.item()\n", + " num_batches += 1\n", + " num_samples += val_batch['semantic_tokens'].size(0)\n", + "\n", + " average_validation_loss = validation_loss / num_batches\n", + " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", + " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/train_fine.ipynb b/train_fine.ipynb new file mode 100644 index 0000000..8e80d44 --- /dev/null +++ b/train_fine.ipynb @@ -0,0 +1,1141 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:WARNING: Could not find module 'C:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\Lib\\site-packages\\xformers\\_C.pyd' (or one of its dependencies). Try using the full path with constructor syntax.\n", + "Need to compile C++ extensions to get sparse attention suport. Please run python setup.py build develop\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Could not find module 'C:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\Lib\\site-packages\\xformers\\_C.pyd' (or one of its dependencies). 
Try using the full path with constructor syntax.\n", + "\n", + "===================================BUG REPORT===================================\n", + "Welcome to bitsandbytes. For bug reports, please run\n", + "\n", + "python -m bitsandbytes\n", + "\n", + " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", + "================================================================================\n", + "bin c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda116.dll\n", + "function 'cadam32bit_grad_fp32' not found\n", + "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...\n", + "CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!\n", + "CUDA SETUP: Loading binary c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda116.dll...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n", + " warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n", + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {WindowsPath('vs/workbench/api/node/extensionHostProcess')}\n", + " warn(msg)\n", + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {WindowsPath('module'), WindowsPath('/matplotlib_inline.backend_inline')}\n", + " warn(msg)\n", + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {WindowsPath('/usr/local/cuda/lib64')}\n", + " warn(msg)\n", + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: No libcudart.so found! Install CUDA or the cudatoolkit package (anaconda)!\n", + " warn(msg)\n", + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: No GPU detected! Check your CUDA paths. 
Proceeding to load CPU-only library...\n", + " warn(msg)\n" + ] + } + ], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import os\n", + "import re\n", + "import gc\n", + "import json\n", + "import math\n", + "import hashlib\n", + "import numpy as np\n", + "import logging\n", + "import torchaudio\n", + "from tqdm.auto import tqdm\n", + "import torch.nn.functional as F\n", + "from encodec.utils import convert_audio\n", + "from accelerate import Accelerator\n", + "from accelerate.utils import set_seed\n", + "from transformers import BertTokenizer\n", + "from huggingface_hub import hf_hub_download\n", + "from packaging import version\n", + "from diffusers.optimization import get_scheduler\n", + "\n", + "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n", + "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n", + "from bark.model import GPTConfig, GPT\n", + "from bark.model_fine import FineGPT, FineGPTConfig" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training Args" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "train_batch_size = 8\n", + "eval_batch_size = 8\n", + "grad_accum = 1\n", + "ckpt_path = 'models/fine_2.pt'\n", + "model_type = \"fine\"\n", + "dataset_path = 'datasets/joe_biden_state_of_union/'\n", + "logging_dir = 'logs/'\n", + "log_with = 'wandb'\n", + "hubert_path = 'data/models/hubert/hubert.pt'\n", + "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n", + "\n", + "output_dir = 'fine_output/'\n", + "resume_from_checkpoint = None\n", + "\n", + "checkpointing_steps = 1000\n", + "\n", + "mixed_precision = 'bf16'\n", + "bits = 16 #4 4 and 8 bit are a work in progress\n", + "compute_dtype = torch.bfloat16\n", + "double_quant = True\n", + "quant_type = 'nf4'\n", + "\n", + "lora_dim = 64\n", + "lora_scaling = 1\n", + "lora_dropout = 0.1\n", + "lora_module_name = 'transformer.h'\n", + "optimize_lora_params_only = True\n", + "\n", + "learning_rate = 1e-4\n", + "scale_lr = False\n", + "use_8bit_adam = False\n", + "adam_beta1 = 0.9\n", + "adam_beta2 = 0.999\n", + "adam_epsilon = 1e-8\n", + "weight_decay = 0.01\n", + "\n", + "llm_int8_skip_modules = None\n", + "keep_in_fp32_modules = ['lm_head']\n", + "\n", + "lr_scheduler_type = 'linear'\n", + "lr_warmup_steps = 100\n", + "num_train_epochs = 5\n", + "max_train_steps = None\n", + "max_grad_norm = 1.0\n", + "\n", + "semantic_cross_entropy_loss_weight = 0\n", + "\n", + "seed = 741" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define Functions" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\accelerate\\accelerator.py:258: FutureWarning: `logging_dir` is deprecated and will be removed in version 0.18.0 of 🤗 Accelerate. 
Use `project_dir` instead.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "CONTEXT_WINDOW_SIZE = 1024\n", + "\n", + "MAX_SEMANTIC_LEN = 256\n", + "\n", + "SEMANTIC_RATE_HZ = 49.9\n", + "SEMANTIC_VOCAB_SIZE = 10_000\n", + "\n", + "TEXT_ENCODING_OFFSET = 10_048\n", + "SEMANTIC_PAD_TOKEN = 10_000\n", + "TEXT_PAD_TOKEN = 129_595\n", + "SEMANTIC_INFER_TOKEN = 129_599\n", + "\n", + "MAX_COARSE_LEN = 768\n", + "\n", + "SAMPLE_RATE = 24_000\n", + "CHANNELS = 1\n", + "\n", + "COARSE_SEMANTIC_PAD_TOKEN = 12_048\n", + "COARSE_INFER_TOKEN = 12_050\n", + "\n", + "CODEBOOK_SIZE = 1024\n", + "N_COARSE_CODEBOOKS = 2\n", + "N_FINE_CODEBOOKS = 8\n", + "COARSE_RATE_HZ = 75\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n", + "\n", + "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n", + "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n", + "\n", + "\n", + "def _clear_cuda_cache():\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + " torch.cuda.synchronize()\n", + "\n", + "\n", + "def _md5(fname):\n", + " hash_md5 = hashlib.md5()\n", + " with open(fname, \"rb\") as f:\n", + " for chunk in iter(lambda: f.read(4096), b\"\"):\n", + " hash_md5.update(chunk)\n", + " return hash_md5.hexdigest()\n", + "\n", + "\n", + "def _download(from_hf_path, file_name, to_local_path):\n", + " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n", + " path = '/'.join(to_local_path.split(\"/\")[:-1])\n", + " os.makedirs(path, exist_ok=True)\n", + " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n", + " os.replace(os.path.join(path, file_name), to_local_path)\n", + "\n", + "\n", + "def _tokenize(tokenizer, text):\n", + " return tokenizer.encode(text, add_special_tokens=False)\n", + "\n", + "\n", + "def _detokenize(tokenizer, enc_text):\n", + " return tokenizer.decode(enc_text)\n", + "\n", + "\n", + "def _normalize_whitespace(text):\n", + " return re.sub(r\"\\s+\", \" \", text).strip()\n", + "\n", + "\n", + "REMOTE_MODEL_PATHS = {\n", + " \"text_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"text.pt\",\n", + " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n", + " },\n", + " \"coarse_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"coarse.pt\",\n", + " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n", + " },\n", + " \"fine_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"fine.pt\",\n", + " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n", + " },\n", + " \"text\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"text_2.pt\",\n", + " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n", + " },\n", + " \"coarse\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"coarse_2.pt\",\n", + " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n", + " },\n", + " \"fine\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"fine_2.pt\",\n", + " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n", + " },\n", + "}\n", + "\n", + "\n", + "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n", + " if model_type == \"text\":\n", + " ConfigClass = GPTConfig\n", + " ModelClass = GPT\n", + " elif model_type == \"coarse\":\n", + " ConfigClass = GPTConfig\n", + " ModelClass = GPT\n", + " elif model_type == \"fine\":\n", + " ConfigClass = FineGPTConfig\n", + " ModelClass = 
FineGPT\n", + " else:\n", + " raise NotImplementedError()\n", + " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n", + " model_info = REMOTE_MODEL_PATHS[model_key]\n", + " if ckpt_path in [None, '']:\n", + " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n", + " if not os.path.exists(ckpt_path):\n", + " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n", + " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n", + " checkpoint = torch.load(ckpt_path, map_location=device)\n", + " # this is a hack\n", + " model_args = checkpoint[\"model_args\"]\n", + " if \"input_vocab_size\" not in model_args:\n", + " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n", + " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n", + " del model_args[\"vocab_size\"]\n", + " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n", + " model = ModelClass(gptconf)\n", + " state_dict = checkpoint[\"model\"]\n", + " # fixup checkpoint\n", + " unwanted_prefix = \"_orig_mod.\"\n", + " for k, v in list(state_dict.items()):\n", + " if k.startswith(unwanted_prefix):\n", + " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n", + " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n", + " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n", + " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n", + " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n", + " if len(extra_keys) != 0:\n", + " raise ValueError(f\"extra keys found: {extra_keys}\")\n", + " if len(missing_keys) != 0:\n", + " raise ValueError(f\"missing keys: {missing_keys}\")\n", + " model.load_state_dict(state_dict, strict=False)\n", + " n_params = model.get_num_params()\n", + " val_loss = checkpoint[\"best_val_loss\"].item()\n", + " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n", + " del checkpoint, state_dict\n", + " _clear_cuda_cache()\n", + " if model_type == \"text\":\n", + " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n", + " return model, tokenizer\n", + " return model\n", + "\n", + "\n", + "def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n", + " assert len(arr.shape) == 2\n", + " arr = arr.copy()\n", + " if offset_size is not None:\n", + " for n in range(1, arr.shape[0]):\n", + " arr[n, :] += offset_size * n\n", + " flat_arr = arr.ravel(\"F\")\n", + " return flat_arr\n", + "\n", + "\n", + "def load_filepaths_and_text(filename, split=\"|\"):\n", + " with open(filename, encoding='utf-8') as f:\n", + " filepaths_and_text = [line.strip().split(split) for line in f]\n", + " base = os.path.dirname(filename)\n", + " for j in range(len(filepaths_and_text)):\n", + " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n", + " return filepaths_and_text\n", + "\n", + "\n", + "class TtsDataset(torch.utils.data.Dataset):\n", + " def __init__(self, opt):\n", + " self.path = os.path.dirname(opt['path'])\n", + " self.mode = opt['mode']\n", + " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '.txt'))\n", + "\n", + " def __getitem__(self, index):\n", + " audiopath_and_text = self.audiopaths_and_text[index]\n", + " audiopath = audiopath_and_text[0]\n", + "\n", + " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n", + " fine_tokens = tokens['fine']\n", + "\n", + " 
return torch.from_numpy(fine_tokens)\n", + "\n", + " def __len__(self):\n", + " return len(self.audiopaths_and_text)\n", + "\n", + "\n", + "class TtsCollater():\n", + " def __init__(self):\n", + " pass\n", + " def __call__(self, batch):\n", + " max_len = 1024\n", + " fine_tokens = []\n", + "\n", + " for fine_tokens_ in batch:\n", + " if fine_tokens_.shape[1] > max_len:\n", + " start_idx = np.random.randint(0, fine_tokens_.shape[1] - max_len + 1)\n", + " fine_tokens_ = fine_tokens_[:, start_idx : start_idx + max_len]\n", + "\n", + " pad_size = max_len - fine_tokens_.shape[1]\n", + " fine_tokens_ = F.pad(fine_tokens_, (0, pad_size), value=CODEBOOK_SIZE)\n", + "\n", + " fine_tokens_ = fine_tokens_.T\n", + "\n", + " fine_tokens.append(fine_tokens_)\n", + "\n", + " return {'fine_tokens': torch.stack(fine_tokens).contiguous()}\n", + " \n", + "\n", + "accelerator = Accelerator(\n", + " gradient_accumulation_steps=grad_accum,\n", + " mixed_precision=mixed_precision,\n", + " log_with=log_with,\n", + " logging_dir=logging_dir,\n", + ")\n", + "device = accelerator.device\n", + "\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "set_seed(seed)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup Dataset (only need to do this once)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# max_duration_sec = 15.12 # the maximum allowed duration in seconds\n", + "\n", + "# path = dataset_path\n", + "\n", + "# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n", + "# from hubert.hubert_manager import HuBERTManager\n", + "# hubert_manager = HuBERTManager()\n", + "# from hubert.pre_kmeans_hubert import CustomHubert\n", + "# from hubert.customtokenizer import CustomTokenizer\n", + "# hubert_manager.make_sure_hubert_installed()\n", + "# hubert_manager.make_sure_tokenizer_installed()\n", + "\n", + "# # Load the HuBERT model\n", + "# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n", + "# hubert_model.eval()\n", + "# for param in hubert_model.parameters():\n", + "# param.requires_grad = False\n", + "\n", + "# # Load the CustomTokenizer model\n", + "# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n", + "\n", + "# from bark.generation import load_codec_model\n", + "# codec_model = load_codec_model(use_gpu=True)\n", + "# codec_model.eval()\n", + "# for param in codec_model.parameters():\n", + "# param.requires_grad = False\n", + "\n", + "\n", + "# def get_duration(wav, sr):\n", + "# return wav.shape[1] / sr\n", + "\n", + "# valid_lines_train = []\n", + "# # convert wavs to semantic tokens\n", + "# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n", + "# wav, sr = torchaudio.load(wav_path)\n", + "# if not get_duration(wav, sr) > max_duration_sec:\n", + "# valid_lines_train.append((wav_path, txt))\n", + "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", + "\n", + "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", + "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", + "\n", + "# # save semantic tokens\n", + "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", + "# semantic_tokens = semantic_tokens.cpu().numpy()\n", + "\n", + "# # Extract discrete codes from EnCodec\n", + "# with torch.no_grad():\n", + "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", + "# codes = 
torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", + "\n", + "# # move codes to cpu\n", + "# codes = codes.cpu().numpy()\n", + "\n", + "# # save tokens\n", + "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", + "\n", + "# # rewrite train.txt with valid lines\n", + "# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n", + "# for wav_path, txt in valid_lines_train:\n", + "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", + "# f.write(f'{wav_path}|{txt}\\n')\n", + "\n", + "# valid_lines_valid = []\n", + "# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n", + "# wav, sr = torchaudio.load(wav_path)\n", + "# if not get_duration(wav, sr) > max_duration_sec:\n", + "# valid_lines_valid.append((wav_path, txt))\n", + "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", + "\n", + "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", + "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", + "\n", + "# # save semantic tokens\n", + "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", + "# semantic_tokens = semantic_tokens.cpu().numpy()\n", + " \n", + "# # Extract discrete codes from EnCodec\n", + "# with torch.no_grad():\n", + "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", + "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", + "\n", + "# # move codes to cpu\n", + "# codes = codes.cpu().numpy()\n", + "\n", + "# # save tokens\n", + "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", + "\n", + "# # rewrite valid.txt with valid lines\n", + "# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n", + "# for wav_path, txt in valid_lines_valid:\n", + "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", + "# f.write(f'{wav_path}|{txt}\\n')\n", + "\n", + "# del hubert_model\n", + "# del hubert_tokenizer\n", + "# del codec_model\n", + "# gc.collect()\n", + "# torch.cuda.empty_cache()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Loaded fine model with 302090240 params, val_loss=2.0786.\n" + ] + } + ], + "source": [ + "model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "if scale_lr:\n", + " learning_rate = (\n", + " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n", + " )\n", + "\n", + "if use_8bit_adam:\n", + " try:\n", + " import bitsandbytes as bnb\n", + " except ImportError:\n", + " raise ImportError(\n", + " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n", + " )\n", + "\n", + " optimizer_class = bnb.optim.AdamW8bit\n", + "else:\n", + " optimizer_class = torch.optim.AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=bits == 4,\n", + " load_in_8bit=bits == 8,\n", + " 
llm_int8_threshold=6.0,\n", + " llm_int8_has_fp16_weight=False,\n", + " bnb_4bit_compute_dtype=compute_dtype,\n", + " bnb_4bit_use_double_quant=double_quant,\n", + " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n", + ")\n", + "\n", + "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n", + "# if quantization_config.load_in_8bit:\n", + "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n", + "# elif quantization_config.load_in_4bit:\n", + "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n", + "\n", + "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n", + "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n", + "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n", + "# else:\n", + "# modules_to_not_convert = llm_int8_skip_modules\n", + "\n", + "# if not isinstance(modules_to_not_convert, list):\n", + "# modules_to_not_convert = [modules_to_not_convert]\n", + "\n", + "# modules_to_not_convert.extend(keep_in_fp32_modules)\n", + "\n", + "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n", + "\n", + "# if quantization_config.load_in_4bit and not supports_4bit:\n", + "# raise ValueError(\n", + "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n", + "# \" make sure you have the latest version of `bitsandbytes` installed\"\n", + "# )\n", + " \n", + "# if len(modules_to_not_convert) == 0:\n", + "# modules_to_not_convert = None\n", + "\n", + "# model = replace_with_bnb_linear(\n", + "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n", + "# )\n", + "\n", + "# # training in 8-bit is only available in 0.37.0+\n", + "# model._is_kbit_training_enabled = version.parse(\n", + "# importlib_metadata.version(\"bitsandbytes\")\n", + "# ) >= version.parse(\"0.37.0\")\n", + "\n", + "# model.config.quantization_config = quantization_config" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "if bits == 4:\n", + " from accelerate.utils import CustomDtype\n", + " target_dtype = CustomDtype.INT4\n", + "elif bits == 8:\n", + " target_dtype = torch.int8\n", + "\n", + "if lora_dim > 0:\n", + " for param in model.parameters():\n", + " if param.ndim == 1:\n", + " # cast the small parameters (e.g. 
layernorm) to fp32 for stability\n", + " param.data = param.data.to(torch.float32)\n", + " \n", + " class CastOutputToFloat(nn.Sequential):\n", + " def forward(self, x):\n", + " return super().forward(x).to(torch.float32)\n", + "\n", + " # model.lm_head = CastOutputToFloat(model.lm_head)\n", + " for i, lm_head in enumerate(model.lm_heads):\n", + " model.lm_heads[i] = CastOutputToFloat(lm_head)\n", + "\n", + " model = convert_linear_layer_to_lora(model, lora_module_name,\n", + " lora_dim=lora_dim, lora_scaling=lora_scaling,\n", + " lora_dropout=lora_dropout)\n", + " if optimize_lora_params_only:\n", + " model = only_optimize_lora_parameters(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "params_to_optimize = (\n", + " param for param in model.parameters() if param.requires_grad\n", + " )\n", + "\n", + "optimizer = optimizer_class(\n", + " params_to_optimize,\n", + " lr=learning_rate,\n", + " betas=(adam_beta1, adam_beta2),\n", + " weight_decay=weight_decay,\n", + " eps=adam_epsilon,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "opt_train = {\n", + " 'path': dataset_path,\n", + " 'mode': 'train',\n", + "}\n", + "\n", + "opt_val = {\n", + " 'path': dataset_path,\n", + " 'mode': 'valid',\n", + "}\n", + "\n", + "train_dataset = TtsDataset(opt_train)\n", + "validation_dataset = TtsDataset(opt_val)\n", + "\n", + "train_dataloader = torch.utils.data.DataLoader(\n", + " train_dataset,\n", + " batch_size=train_batch_size,\n", + " collate_fn=TtsCollater(),\n", + ")\n", + "\n", + "validation_dataloader = torch.utils.data.DataLoader(\n", + " validation_dataset,\n", + " batch_size=eval_batch_size,\n", + " collate_fn=TtsCollater(),\n", + ")\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss(ignore_index=COARSE_SEMANTIC_PAD_TOKEN)\n", + "\n", + "# Scheduler and math around the number of training steps.\n", + "overrode_max_train_steps = False\n", + "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", + "if max_train_steps is None:\n", + " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", + " overrode_max_train_steps = True\n", + "\n", + "lr_scheduler = get_scheduler(\n", + " lr_scheduler_type,\n", + " optimizer=optimizer,\n", + " num_warmup_steps=lr_warmup_steps * grad_accum,\n", + " num_training_steps=max_train_steps * grad_accum,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n", + " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n", + ")\n", + "accelerator.register_for_checkpointing(lr_scheduler)\n", + "\n", + "weight_dtype = torch.float32\n", + "if accelerator.mixed_precision == \"fp16\":\n", + " weight_dtype = torch.float16\n", + "elif accelerator.mixed_precision == \"bf16\":\n", + " weight_dtype = torch.bfloat16" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mfrancislabounty\u001b[0m. 
Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "data": { + "text/html": [ + "wandb version 0.15.4 is available! To upgrade, please run:\n", + " $ pip install wandb --upgrade" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.13.6" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in e:\\Python\\bark-with-voice-clone\\wandb\\run-20230629_202416-290ebk11" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run fresh-pyramid-26 to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n", + "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", + "if overrode_max_train_steps:\n", + " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", + "# Afterwards we recalculate our number of training epochs\n", + "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n", + "\n", + "# We need to initialize the trackers we use, and also store our configuration.\n", + "# The trackers initializes automatically on the main process.\n", + "if accelerator.is_main_process:\n", + " accelerator.init_trackers(\"bark_coarse\", config={})\n", + "\n", + "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n", + "logger.info(\"***** Running training *****\")\n", + "logger.info(f\" Num examples = {len(train_dataset)}\")\n", + "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n", + "logger.info(f\" Num Epochs = {num_train_epochs}\")\n", + "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n", + "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n", + "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n", + "logger.info(f\" Total optimization steps = {max_train_steps}\")\n", + "global_step = 0\n", + "first_epoch = 0\n", + "\n", + "if resume_from_checkpoint:\n", + " if resume_from_checkpoint != \"latest\":\n", + " path = os.path.basename(resume_from_checkpoint)\n", + " else:\n", + " # Get the most recent checkpoint\n", + " dirs = os.listdir(output_dir)\n", + " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n", + " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n", + " path = dirs[-1]\n", + " accelerator.print(f\"Resuming from checkpoint {path}\")\n", + " accelerator.load_state(os.path.join(output_dir, path))\n", + " global_step = int(path.split(\"-\")[1])\n", + "\n", + " resume_global_step = global_step * grad_accum\n", + " first_epoch = resume_global_step // num_update_steps_per_epoch\n", + " resume_step = resume_global_step % num_update_steps_per_epoch\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 30.702054630626332 over 82 samples and 11 batches.\n" + ] + } + ], + "source": [ + "if accelerator.is_main_process:\n", + " model.eval()\n", + " validation_loss = 0.0\n", + " num_batches = 0\n", + " num_samples = 0\n", + " with torch.no_grad():\n", + " for val_step, val_batch in enumerate(validation_dataloader):\n", + " # Similar to training, process the validation batch\n", + " fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n", + " fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n", + " fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n", + " fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n", + "\n", + " # Forward pass for validation\n", + " logits_7 = model(6, fine_tokens_input_7)\n", + " logits_8 = model(7, fine_tokens_input_8)\n", + "\n", + " # Calculate the validation loss\n", + " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), 
fine_targets_7.view(-1))\n", + " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n", + "\n", + " loss = loss_7 + loss_8\n", + " validation_loss += loss.item()\n", + " num_batches += 1\n", + " num_samples += val_batch['fine_tokens'].size(0)\n", + "\n", + " average_validation_loss = validation_loss / num_batches\n", + " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", + " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3761107b0c094d2db6532410a582408c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/205 [00:00(success)." + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "45d15ef5bc1e4729aaf827aa3380823d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "

Run history:


loss█▆█▅▅▇▆▆▆▇▆▆▅▄▃▃▄▂▃▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▆▇▇███▇▇▇▆▆▆▅▅▅▄▄▄▃▃▂▂▂▁▁

Run summary:


loss3.18231
lr0.0

" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Synced fresh-pyramid-26: https://wandb.ai/francislabounty/bark_coarse/runs/290ebk11
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Find logs at: .\\wandb\\run-20230629_202416-290ebk11\\logs" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Only show the progress bar once on each machine.\n", + "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n", + "progress_bar.set_description(\"Steps\")\n", + "\n", + "for epoch in range(first_epoch, num_train_epochs):\n", + " model.train()\n", + " for step, batch in enumerate(train_dataloader):\n", + " # Skip steps until we reach the resumed step\n", + " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n", + " if step % grad_accum == 0:\n", + " progress_bar.update(1)\n", + " continue\n", + "\n", + " with accelerator.accumulate(model):\n", + " fine_targets_7 = batch['fine_tokens'][:, :, 6]\n", + " fine_tokens_input_7 = torch.cat([batch['fine_tokens'][:, :, :6], torch.zeros_like(batch['fine_tokens'][:, :, 6:])], dim=2)\n", + " fine_targets_8 = batch['fine_tokens'][:, :, 7]\n", + " fine_tokens_input_8 = torch.cat([batch['fine_tokens'][:, :, :7], torch.zeros_like(batch['fine_tokens'][:, :, 7:])], dim=2)\n", + "\n", + " # Forward pass\n", + " logits_7 = model(6, fine_tokens_input_7)\n", + " logits_8 = model(7, fine_tokens_input_8)\n", + "\n", + " # Calculate the loss\n", + " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n", + " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n", + "\n", + " loss = loss_7 + loss_8\n", + "\n", + " accelerator.backward(loss)\n", + " if accelerator.sync_gradients:\n", + " params_to_clip = (\n", + " param for param in model.parameters() if param.requires_grad\n", + " )\n", + " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + " optimizer.zero_grad()\n", + "\n", + " # Checks if the accelerator has performed an optimization step behind the scenes\n", + " if accelerator.sync_gradients:\n", + " progress_bar.update(1)\n", + " global_step += 1\n", + "\n", + " if global_step % checkpointing_steps == 0:\n", + " if accelerator.is_main_process:\n", + " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n", + " accelerator.save_state(save_path)\n", + " logger.info(f\"Saved state to {save_path}\")\n", + "\n", + " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n", + " progress_bar.set_postfix(**logs)\n", + " accelerator.log(logs, step=global_step)\n", + "\n", + " if global_step >= max_train_steps:\n", + " break\n", + " \n", + " accelerator.wait_for_everyone()\n", + "\n", + "if accelerator.is_main_process:\n", + " if lora_dim > 0:\n", + " model = convert_lora_to_linear_layer(model)\n", + " # save model\n", + " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n", + " \n", + " config = model.config.__dict__\n", + " # save config\n", + " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n", + " json.dump(config, f, indent=2)\n", + "\n", + "accelerator.end_training()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Validation" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ 
+ { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation Loss: 3.041703635996038 over 82 samples and 11 batches.\n" + ] + } + ], + "source": [ + "if accelerator.is_main_process:\n", + " model.eval()\n", + " validation_loss = 0.0\n", + " num_batches = 0\n", + " num_samples = 0\n", + " with torch.no_grad():\n", + " for val_step, val_batch in enumerate(validation_dataloader):\n", + " # Similar to training, process the validation batch\n", + " fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n", + " fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n", + " fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n", + " fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n", + "\n", + " # Forward pass for validation\n", + " logits_7 = model(6, fine_tokens_input_7)\n", + " logits_8 = model(7, fine_tokens_input_8)\n", + "\n", + " # Calculate the validation loss\n", + " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n", + " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n", + "\n", + " loss = loss_7 + loss_8\n", + " validation_loss += loss.item()\n", + " num_batches += 1\n", + " num_samples += val_batch['fine_tokens'].size(0)\n", + "\n", + " average_validation_loss = validation_loss / num_batches\n", + " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", + " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/train_semantic.ipynb b/train_semantic.ipynb new file mode 100644 index 0000000..a61dd72 --- /dev/null +++ b/train_semantic.ipynb @@ -0,0 +1,899 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Imports" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import os\n", + "import re\n", + "import gc\n", + "import json\n", + "import math\n", + "import hashlib\n", + "import numpy as np\n", + "import logging\n", + "import torchaudio\n", + "from tqdm.auto import tqdm\n", + "import torch.nn.functional as F\n", + "from encodec.utils import convert_audio\n", + "from accelerate import Accelerator\n", + "from accelerate.utils import set_seed\n", + "from transformers import BertTokenizer\n", + "from huggingface_hub import hf_hub_download\n", + "from packaging import version\n", + "from diffusers.optimization import get_scheduler\n", + "\n", + "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n", + "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n", + "from bark.model import GPTConfig, GPT\n", + "from bark.model_fine 
import FineGPT, FineGPTConfig" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training Args" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "train_batch_size = 8\n", + "eval_batch_size = 8\n", + "grad_accum = 1\n", + "ckpt_path = 'models/text_2.pt'\n", + "model_type = \"text\"\n", + "dataset_path = 'datasets/joe_biden_state_of_union/'\n", + "logging_dir = 'logs/'\n", + "log_with = 'wandb'\n", + "hubert_path = 'data/models/hubert/hubert.pt'\n", + "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n", + "\n", + "output_dir = 'semantic_output/'\n", + "resume_from_checkpoint = None\n", + "\n", + "checkpointing_steps = 1000\n", + "\n", + "mixed_precision = 'bf16'\n", + "bits = 16 #4 4 and 8 bit are a work in progress\n", + "compute_dtype = torch.bfloat16\n", + "double_quant = True\n", + "quant_type = 'nf4'\n", + "\n", + "lora_dim = 64\n", + "lora_scaling = 32\n", + "lora_dropout = 0.1\n", + "lora_module_name = 'transformer.h'\n", + "optimize_lora_params_only = True\n", + "\n", + "learning_rate = 1e-4\n", + "scale_lr = False\n", + "use_8bit_adam = False\n", + "adam_beta1 = 0.9\n", + "adam_beta2 = 0.999\n", + "adam_epsilon = 1e-8\n", + "weight_decay = 0.01\n", + "\n", + "llm_int8_skip_modules = None\n", + "keep_in_fp32_modules = ['lm_head']\n", + "\n", + "lr_scheduler_type = 'linear'\n", + "lr_warmup_steps = 200\n", + "num_train_epochs = 20\n", + "max_train_steps = None\n", + "max_grad_norm = 1.0\n", + "\n", + "seed = 741" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Define Functions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CONTEXT_WINDOW_SIZE = 1024\n", + "\n", + "MAX_TEXT_LEN = 256\n", + "\n", + "SEMANTIC_RATE_HZ = 49.9\n", + "SEMANTIC_VOCAB_SIZE = 10_000\n", + "\n", + "TEXT_ENCODING_OFFSET = 10_048\n", + "SEMANTIC_PAD_TOKEN = 10_000\n", + "TEXT_PAD_TOKEN = 129_595\n", + "SEMANTIC_INFER_TOKEN = 129_599\n", + "\n", + "MAX_SEMANTIC_LEN = 511\n", + "\n", + "SAMPLE_RATE = 24_000\n", + "CHANNELS = 1\n", + "\n", + "logger = logging.getLogger(__name__)\n", + "\n", + "\n", + "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n", + "\n", + "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n", + "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n", + "\n", + "\n", + "def _clear_cuda_cache():\n", + " if torch.cuda.is_available():\n", + " torch.cuda.empty_cache()\n", + " torch.cuda.synchronize()\n", + "\n", + "\n", + "def _md5(fname):\n", + " hash_md5 = hashlib.md5()\n", + " with open(fname, \"rb\") as f:\n", + " for chunk in iter(lambda: f.read(4096), b\"\"):\n", + " hash_md5.update(chunk)\n", + " return hash_md5.hexdigest()\n", + "\n", + "\n", + "def _download(from_hf_path, file_name, to_local_path):\n", + " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n", + " path = '/'.join(to_local_path.split(\"/\")[:-1])\n", + " os.makedirs(path, exist_ok=True)\n", + " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n", + " os.replace(os.path.join(path, file_name), to_local_path)\n", + "\n", + "\n", + "def _tokenize(tokenizer, text):\n", + " return tokenizer.encode(text, add_special_tokens=False)\n", + "\n", + "\n", + "def _detokenize(tokenizer, enc_text):\n", + " return tokenizer.decode(enc_text)\n", + "\n", + "\n", + "def 
_normalize_whitespace(text):\n", + " return re.sub(r\"\\s+\", \" \", text).strip()\n", + "\n", + "\n", + "REMOTE_MODEL_PATHS = {\n", + " \"text_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"text.pt\",\n", + " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n", + " },\n", + " \"coarse_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"coarse.pt\",\n", + " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n", + " },\n", + " \"fine_small\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"fine.pt\",\n", + " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n", + " },\n", + " \"text\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"text_2.pt\",\n", + " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n", + " },\n", + " \"coarse\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"coarse_2.pt\",\n", + " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n", + " },\n", + " \"fine\": {\n", + " \"repo_id\": \"suno/bark\",\n", + " \"file_name\": \"fine_2.pt\",\n", + " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n", + " },\n", + "}\n", + "\n", + "\n", + "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n", + " if model_type == \"text\":\n", + " ConfigClass = GPTConfig\n", + " ModelClass = GPT\n", + " elif model_type == \"coarse\":\n", + " ConfigClass = GPTConfig\n", + " ModelClass = GPT\n", + " elif model_type == \"fine\":\n", + " ConfigClass = FineGPTConfig\n", + " ModelClass = FineGPT\n", + " else:\n", + " raise NotImplementedError()\n", + " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n", + " model_info = REMOTE_MODEL_PATHS[model_key]\n", + " if ckpt_path in [None, '']:\n", + " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n", + " if not os.path.exists(ckpt_path):\n", + " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n", + " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n", + " checkpoint = torch.load(ckpt_path, map_location=device)\n", + " # this is a hack\n", + " model_args = checkpoint[\"model_args\"]\n", + " if \"input_vocab_size\" not in model_args:\n", + " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n", + " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n", + " del model_args[\"vocab_size\"]\n", + " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n", + " model = ModelClass(gptconf)\n", + " state_dict = checkpoint[\"model\"]\n", + " # fixup checkpoint\n", + " unwanted_prefix = \"_orig_mod.\"\n", + " for k, v in list(state_dict.items()):\n", + " if k.startswith(unwanted_prefix):\n", + " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n", + " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n", + " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n", + " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n", + " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n", + " if len(extra_keys) != 0:\n", + " raise ValueError(f\"extra keys found: {extra_keys}\")\n", + " if len(missing_keys) != 0:\n", + " raise ValueError(f\"missing keys: {missing_keys}\")\n", + " model.load_state_dict(state_dict, strict=False)\n", + " n_params = model.get_num_params()\n", + " val_loss = checkpoint[\"best_val_loss\"].item()\n", + " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n", + " del 
checkpoint, state_dict\n", + " _clear_cuda_cache()\n", + " if model_type == \"text\":\n", + " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n", + " return model, tokenizer\n", + " return model\n", + "\n", + "\n", + "def load_filepaths_and_text(filename, split=\"|\"):\n", + " with open(filename, encoding='utf-8') as f:\n", + " filepaths_and_text = [line.strip().split(split) for line in f]\n", + " base = os.path.dirname(filename)\n", + " for j in range(len(filepaths_and_text)):\n", + " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n", + " return filepaths_and_text\n", + "\n", + "class TtsDataset(torch.utils.data.Dataset):\n", + " def __init__(self, opt):\n", + " self.path = os.path.dirname(opt['path'])\n", + " self.mode = opt['mode']\n", + " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '_valid.txt'))\n", + " self.tokenizer = opt['tokenizer']\n", + "\n", + " def __getitem__(self, index):\n", + " audiopath_and_text = self.audiopaths_and_text[index]\n", + " audiopath, text = audiopath_and_text[0], audiopath_and_text[1]\n", + "\n", + " input_ids = np.array(_tokenize(self.tokenizer, text)) + TEXT_ENCODING_OFFSET\n", + " input_ids = torch.from_numpy(input_ids).long()\n", + " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n", + " semantic_tokens = tokens['semantic']\n", + " semantic_tokens = torch.from_numpy(semantic_tokens).long()\n", + "\n", + " return input_ids, semantic_tokens\n", + "\n", + " def __len__(self):\n", + " return len(self.audiopaths_and_text)\n", + "\n", + "\n", + "class TtsCollater():\n", + " def __init__(self):\n", + " pass\n", + " def __call__(self, batch):\n", + " max_text_len = MAX_TEXT_LEN\n", + " max_semantic_tokens_len = MAX_SEMANTIC_LEN\n", + " texts = []\n", + " semantic_tokens = []\n", + "\n", + " for b in batch:\n", + " text, semantic_tokens_ = b\n", + " text = F.pad(text, (0, max_text_len-len(text)), value=TEXT_PAD_TOKEN)\n", + " semantic_history = torch.from_numpy(np.array([SEMANTIC_PAD_TOKEN] * 256))\n", + " text = torch.cat([text, semantic_history, torch.tensor([SEMANTIC_INFER_TOKEN])])\n", + " texts.append(text)\n", + " semantic_tokens_ = semantic_tokens_[:max_semantic_tokens_len]\n", + " semantic_tokens.append(F.pad(semantic_tokens_, (0, max_semantic_tokens_len-len(semantic_tokens_)), value=SEMANTIC_PAD_TOKEN))\n", + "\n", + " return {\n", + " 'input_ids': torch.stack(texts).contiguous(),\n", + " 'semantic_tokens': torch.stack(semantic_tokens).contiguous()\n", + " }\n", + " \n", + "\n", + "accelerator = Accelerator(\n", + " gradient_accumulation_steps=grad_accum,\n", + " mixed_precision=mixed_precision,\n", + " log_with=log_with,\n", + " logging_dir=logging_dir,\n", + ")\n", + "device = accelerator.device\n", + "\n", + "os.makedirs(output_dir, exist_ok=True)\n", + "\n", + "set_seed(seed)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup Dataset (only need to do this once)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "max_duration_sec = 15.12 # the maximum allowed duration in seconds\n", + "\n", + "path = dataset_path\n", + "\n", + "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n", + "from hubert.hubert_manager import HuBERTManager\n", + "hubert_manager = HuBERTManager()\n", + "from hubert.pre_kmeans_hubert import CustomHubert\n", + "from hubert.customtokenizer import CustomTokenizer\n", 
+ "hubert_manager.make_sure_hubert_installed()\n", + "hubert_manager.make_sure_tokenizer_installed()\n", + "\n", + "# Load the HuBERT model\n", + "hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n", + "hubert_model.eval()\n", + "for param in hubert_model.parameters():\n", + " param.requires_grad = False\n", + "\n", + "# Load the CustomTokenizer model\n", + "hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n", + "\n", + "from bark.generation import load_codec_model\n", + "codec_model = load_codec_model(use_gpu=True)\n", + "codec_model.eval()\n", + "for param in codec_model.parameters():\n", + " param.requires_grad = False\n", + "\n", + "\n", + "def get_duration(wav, sr):\n", + " return wav.shape[1] / sr\n", + "\n", + "valid_lines_train = []\n", + "# convert wavs to semantic tokens\n", + "for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n", + " wav, sr = torchaudio.load(wav_path)\n", + " if not get_duration(wav, sr) > max_duration_sec:\n", + " valid_lines_train.append((wav_path, txt))\n", + " wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", + "\n", + " semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", + " semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", + "\n", + " # save semantic tokens\n", + " os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", + " semantic_tokens = semantic_tokens.cpu().numpy()\n", + "\n", + " # Extract discrete codes from EnCodec\n", + " with torch.no_grad():\n", + " encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", + " codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", + "\n", + " # move codes to cpu\n", + " codes = codes.cpu().numpy()\n", + "\n", + " # save tokens\n", + " np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", + "\n", + "# rewrite train.txt with valid lines\n", + "with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n", + " for wav_path, txt in valid_lines_train:\n", + " wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", + " f.write(f'{wav_path}|{txt}\\n')\n", + "\n", + "valid_lines_valid = []\n", + "for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n", + " wav, sr = torchaudio.load(wav_path)\n", + " if not get_duration(wav, sr) > max_duration_sec:\n", + " valid_lines_valid.append((wav_path, txt))\n", + " wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", + "\n", + " semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", + " semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", + "\n", + " # save semantic tokens\n", + " os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", + " semantic_tokens = semantic_tokens.cpu().numpy()\n", + " \n", + " # Extract discrete codes from EnCodec\n", + " with torch.no_grad():\n", + " encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", + " codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", + "\n", + " # move codes to cpu\n", + " codes = codes.cpu().numpy()\n", + "\n", + " # save tokens\n", + " np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", + "\n", + "# rewrite valid.txt with valid lines\n", + "with 
open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n", + " for wav_path, txt in valid_lines_valid:\n", + " wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", + " f.write(f'{wav_path}|{txt}\\n')\n", + "\n", + "del hubert_model\n", + "del hubert_tokenizer\n", + "del codec_model\n", + "gc.collect()\n", + "torch.cuda.empty_cache()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model, tokenizer = _load_model(ckpt_path, device, use_small=False, model_type=model_type)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if scale_lr:\n", + " learning_rate = (\n", + " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n", + " )\n", + "\n", + "if use_8bit_adam:\n", + " try:\n", + " import bitsandbytes as bnb\n", + " except ImportError:\n", + " raise ImportError(\n", + " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n", + " )\n", + "\n", + " optimizer_class = bnb.optim.AdamW8bit\n", + "else:\n", + " optimizer_class = torch.optim.AdamW" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "quantization_config=BitsAndBytesConfig(\n", + " load_in_4bit=bits == 4,\n", + " load_in_8bit=bits == 8,\n", + " llm_int8_threshold=6.0,\n", + " llm_int8_has_fp16_weight=False,\n", + " bnb_4bit_compute_dtype=compute_dtype,\n", + " bnb_4bit_use_double_quant=double_quant,\n", + " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n", + ")\n", + "\n", + "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n", + "# if quantization_config.load_in_8bit:\n", + "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n", + "# elif quantization_config.load_in_4bit:\n", + "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n", + "\n", + "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n", + "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n", + "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n", + "# else:\n", + "# modules_to_not_convert = llm_int8_skip_modules\n", + "\n", + "# if not isinstance(modules_to_not_convert, list):\n", + "# modules_to_not_convert = [modules_to_not_convert]\n", + "\n", + "# modules_to_not_convert.extend(keep_in_fp32_modules)\n", + "\n", + "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n", + "\n", + "# if quantization_config.load_in_4bit and not supports_4bit:\n", + "# raise ValueError(\n", + "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n", + "# \" make sure you have the latest version of `bitsandbytes` installed\"\n", + "# )\n", + " \n", + "# if len(modules_to_not_convert) == 0:\n", + "# modules_to_not_convert = None\n", + "\n", + "# model = replace_with_bnb_linear(\n", + "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n", + "# )\n", + "\n", + "# # training in 8-bit is only available in 0.37.0+\n", + "# model._is_kbit_training_enabled = version.parse(\n", + "# importlib_metadata.version(\"bitsandbytes\")\n", + "# ) >= version.parse(\"0.37.0\")\n", + "\n", + "# 
model.config.quantization_config = quantization_config" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if bits == 4:\n", + " from accelerate.utils import CustomDtype\n", + " target_dtype = CustomDtype.INT4\n", + "elif bits == 8:\n", + " target_dtype = torch.int8\n", + "\n", + "if lora_dim > 0:\n", + " for param in model.parameters():\n", + " if param.ndim == 1:\n", + " # cast the small parameters (e.g. layernorm) to fp32 for stability\n", + " param.data = param.data.to(torch.float32)\n", + " \n", + " class CastOutputToFloat(nn.Sequential):\n", + " def forward(self, x):\n", + " return super().forward(x).to(torch.float32)\n", + "\n", + " model.lm_head = CastOutputToFloat(model.lm_head)\n", + "\n", + " model = convert_linear_layer_to_lora(model, lora_module_name,\n", + " lora_dim=lora_dim, lora_scaling=lora_scaling,\n", + " lora_dropout=lora_dropout)\n", + " if optimize_lora_params_only:\n", + " model = only_optimize_lora_parameters(model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "params_to_optimize = (\n", + " param for param in model.parameters() if param.requires_grad\n", + " )\n", + "\n", + "optimizer = optimizer_class(\n", + " params_to_optimize,\n", + " lr=learning_rate,\n", + " betas=(adam_beta1, adam_beta2),\n", + " weight_decay=weight_decay,\n", + " eps=adam_epsilon,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "opt_train = {\n", + " 'path': dataset_path,\n", + " 'tokenizer': tokenizer,\n", + " 'mode': 'train',\n", + "}\n", + "\n", + "opt_val = {\n", + " 'path': dataset_path,\n", + " 'tokenizer': tokenizer,\n", + " 'mode': 'valid',\n", + "}\n", + "\n", + "train_dataset = TtsDataset(opt_train)\n", + "validation_dataset = TtsDataset(opt_val)\n", + "\n", + "train_dataloader = torch.utils.data.DataLoader(\n", + " train_dataset,\n", + " batch_size=train_batch_size,\n", + " collate_fn=TtsCollater(),\n", + ")\n", + "\n", + "validation_dataloader = torch.utils.data.DataLoader(\n", + " validation_dataset,\n", + " batch_size=eval_batch_size,\n", + " collate_fn=TtsCollater(),\n", + ")\n", + "\n", + "criterion = torch.nn.CrossEntropyLoss(ignore_index=SEMANTIC_PAD_TOKEN)\n", + "\n", + "# Scheduler and math around the number of training steps.\n", + "overrode_max_train_steps = False\n", + "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", + "if max_train_steps is None:\n", + " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", + " overrode_max_train_steps = True\n", + "\n", + "lr_scheduler = get_scheduler(\n", + " lr_scheduler_type,\n", + " optimizer=optimizer,\n", + " num_warmup_steps=lr_warmup_steps * grad_accum,\n", + " num_training_steps=max_train_steps * grad_accum,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n", + " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n", + ")\n", + "accelerator.register_for_checkpointing(lr_scheduler)\n", + "\n", + "weight_dtype = torch.float32\n", + "if accelerator.mixed_precision == \"fp16\":\n", + " weight_dtype = torch.float16\n", + "elif accelerator.mixed_precision == \"bf16\":\n", + " weight_dtype = torch.bfloat16" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + 
"outputs": [], + "source": [ + "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n", + "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", + "if overrode_max_train_steps:\n", + " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", + "# Afterwards we recalculate our number of training epochs\n", + "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n", + "\n", + "# We need to initialize the trackers we use, and also store our configuration.\n", + "# The trackers initializes automatically on the main process.\n", + "if accelerator.is_main_process:\n", + " accelerator.init_trackers(\"bark_semantic\", config={})\n", + "\n", + "# Train!\n", + "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n", + "logger.info(\"***** Running training *****\")\n", + "logger.info(f\" Num examples = {len(train_dataset)}\")\n", + "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n", + "logger.info(f\" Num Epochs = {num_train_epochs}\")\n", + "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n", + "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n", + "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n", + "logger.info(f\" Total optimization steps = {max_train_steps}\")\n", + "global_step = 0\n", + "first_epoch = 0\n", + "\n", + "if resume_from_checkpoint:\n", + " if resume_from_checkpoint != \"latest\":\n", + " path = os.path.basename(resume_from_checkpoint)\n", + " else:\n", + " # Get the most recent checkpoint\n", + " dirs = os.listdir(output_dir)\n", + " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n", + " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n", + " path = dirs[-1]\n", + " accelerator.print(f\"Resuming from checkpoint {path}\")\n", + " accelerator.load_state(os.path.join(output_dir, path))\n", + " global_step = int(path.split(\"-\")[1])\n", + "\n", + " resume_global_step = global_step * grad_accum\n", + " first_epoch = resume_global_step // num_update_steps_per_epoch\n", + " resume_step = resume_global_step % num_update_steps_per_epoch\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if accelerator.is_main_process:\n", + " model.eval()\n", + " validation_loss = 0.0\n", + " num_batches = 0\n", + " num_samples = 0\n", + " with torch.no_grad():\n", + " for val_step, val_batch in enumerate(validation_dataloader):\n", + " # Similar to training, process the validation batch\n", + " val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n", + " val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n", + " val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n", + "\n", + " # Forward pass for validation\n", + " val_logits = model(val_inputs, training=True)\n", + " val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n", + "\n", + " # Calculate the validation loss\n", + " val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n", + " validation_loss += val_loss.item()\n", + " num_batches += 1\n", + " num_samples += val_batch['input_ids'].size(0)\n", + "\n", + " average_validation_loss = validation_loss / num_batches\n", + " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", 
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Only show the progress bar once on each machine.\n", + "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n", + "progress_bar.set_description(\"Steps\")\n", + "\n", + "for epoch in range(first_epoch, num_train_epochs):\n", + " model.train()\n", + " for step, batch in enumerate(train_dataloader):\n", + " # Skip steps until we reach the resumed step\n", + " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n", + " if step % grad_accum == 0:\n", + " progress_bar.update(1)\n", + " continue\n", + "\n", + " with accelerator.accumulate(model):\n", + " targets = batch['semantic_tokens'][:, 1:].contiguous()\n", + " \n", + " # Remove the last semantic token from the inputs since there is no target for it.\n", + " semantic_inputs = batch['semantic_tokens'][:, :-1]\n", + "\n", + " # Combine the text and semantic tokens and feed them into the model.\n", + " inputs = torch.cat([batch['input_ids'], semantic_inputs], dim=1)\n", + " logits = model(inputs, training=True)\n", + "\n", + " # We're only interested in the logits for the semantic tokens, so we ignore the logits for the input text tokens.\n", + " semantic_logits = logits[:, batch['input_ids'].size(1):].contiguous()\n", + "\n", + " # Compute the loss.\n", + " loss = criterion(semantic_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n", + "\n", + " accelerator.backward(loss)\n", + " if accelerator.sync_gradients:\n", + " params_to_clip = (\n", + " param for param in model.parameters() if param.requires_grad\n", + " )\n", + " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n", + " optimizer.step()\n", + " lr_scheduler.step()\n", + " optimizer.zero_grad()\n", + "\n", + " # Checks if the accelerator has performed an optimization step behind the scenes\n", + " if accelerator.sync_gradients:\n", + " progress_bar.update(1)\n", + " global_step += 1\n", + "\n", + " if global_step % checkpointing_steps == 0:\n", + " if accelerator.is_main_process:\n", + " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n", + " accelerator.save_state(save_path)\n", + " logger.info(f\"Saved state to {save_path}\")\n", + "\n", + " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n", + " progress_bar.set_postfix(**logs)\n", + " accelerator.log(logs, step=global_step)\n", + "\n", + " if global_step >= max_train_steps:\n", + " break\n", + " \n", + " accelerator.wait_for_everyone()\n", + "\n", + "if accelerator.is_main_process:\n", + " if lora_dim > 0:\n", + " model = convert_lora_to_linear_layer(model)\n", + " # save model\n", + " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n", + "\n", + " config = model.config.__dict__\n", + " # save config\n", + " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n", + " json.dump(config, f, indent=2)\n", + "\n", + "accelerator.end_training()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Validation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if accelerator.is_main_process:\n", 
+ " model.eval()\n", + " validation_loss = 0.0\n", + " num_batches = 0\n", + " num_samples = 0\n", + " with torch.no_grad():\n", + " for val_step, val_batch in enumerate(validation_dataloader):\n", + " # Similar to training, process the validation batch\n", + " val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n", + " val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n", + " val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n", + "\n", + " # Forward pass for validation\n", + " val_logits = model(val_inputs, training=True)\n", + " val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n", + "\n", + " # Calculate the validation loss\n", + " val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n", + " validation_loss += val_loss.item()\n", + " num_batches += 1\n", + " num_samples += val_batch['input_ids'].size(0)\n", + "\n", + " average_validation_loss = validation_loss / num_batches\n", + " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", + " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/utils/lora.py b/utils/lora.py index 198f1f3..0e50803 100644 --- a/utils/lora.py +++ b/utils/lora.py @@ -11,7 +11,7 @@ class LinearLayer_LoRA(nn.Module): weight, lora_dim=0, lora_scaling=1, - lora_droppout=0, + lora_dropout=0, bias=None): super(LinearLayer_LoRA, self).__init__() self.weight = weight @@ -29,8 +29,8 @@ class LinearLayer_LoRA(nn.Module): self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows)) self.lora_scaling = lora_scaling / lora_dim - if lora_droppout > 0: - self.lora_dropout = nn.Dropout(lora_droppout) + if lora_dropout > 0: + self.lora_dropout = nn.Dropout(lora_dropout) else: self.lora_dropout = nn.Identity() @@ -116,15 +116,15 @@ def convert_linear_layer_to_lora(model, part_module_name, lora_dim=0, lora_scaling=1, - lora_droppout=0): - repalce_name = [] + lora_dropout=0): + replace_name = [] for name, module in model.named_modules(): if isinstance(module, nn.Linear) and part_module_name in name: - repalce_name.append(name) - for name in repalce_name: + replace_name.append(name) + for name in replace_name: module = recursive_getattr(model, name) tmp = LinearLayer_LoRA( - module.weight, lora_dim, lora_scaling, lora_droppout, + module.weight, lora_dim, lora_scaling, lora_dropout, module.bias).to(module.weight.device).to(module.weight.dtype) recursive_setattr(model, name, tmp) return model @@ -132,11 +132,11 @@ def convert_linear_layer_to_lora(model, # convert the LoRA layer to linear layer def convert_lora_to_linear_layer(model): - repalce_name = [] + replace_name = [] for name, module in model.named_modules(): if isinstance(module, LinearLayer_LoRA): - repalce_name.append(name) - for name in repalce_name: + replace_name.append(name) + for name in replace_name: module = recursive_getattr(model, name) module.fuse_lora_weight() return model