{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:root:WARNING: Could not find module 'C:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\Lib\\site-packages\\xformers\\_C.pyd' (or one of its dependencies). Try using the full path with constructor syntax.\n", "Need to compile C++ extensions to get sparse attention suport. Please run python setup.py build develop\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Could not find module 'C:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\Lib\\site-packages\\xformers\\_C.pyd' (or one of its dependencies). Try using the full path with constructor syntax.\n", "\n", "===================================BUG REPORT===================================\n", "Welcome to bitsandbytes. For bug reports, please run\n", "\n", "python -m bitsandbytes\n", "\n", " and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues\n", "================================================================================\n", "bin c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda116.dll\n", "function 'cadam32bit_grad_fp32' not found\n", "CUDA_SETUP: WARNING! libcudart.so not found in any environmental path. Searching in backup paths...\n", "CUDA SETUP: WARNING! libcuda.so not found! Do you have a CUDA driver installed? If you are on a cluster, make sure you are on a CUDA machine!\n", "CUDA SETUP: Loading binary c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\libbitsandbytes_cuda116.dll...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cextension.py:34: UserWarning: The installed version of bitsandbytes was compiled without GPU support. 8-bit optimizers, 8-bit multiplication, and GPU quantization are unavailable.\n", " warn(\"The installed version of bitsandbytes was compiled without GPU support. \"\n", "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {WindowsPath('vs/workbench/api/node/extensionHostProcess')}\n", " warn(msg)\n", "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {WindowsPath('module'), WindowsPath('/matplotlib_inline.backend_inline')}\n", " warn(msg)\n", "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: The following directories listed in your path were found to be non-existent: {WindowsPath('/usr/local/cuda/lib64')}\n", " warn(msg)\n", "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: No libcudart.so found! 
Install CUDA or the cudatoolkit package (anaconda)!\n", " warn(msg)\n", "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\bitsandbytes\\cuda_setup\\main.py:149: UserWarning: WARNING: No GPU detected! Check your CUDA paths. Proceeding to load CPU-only library...\n", " warn(msg)\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "import os\n", "import re\n", "import gc\n", "import json\n", "import math\n", "import hashlib\n", "import numpy as np\n", "import logging\n", "import torchaudio\n", "from tqdm.auto import tqdm\n", "import torch.nn.functional as F\n", "from encodec.utils import convert_audio\n", "from accelerate import Accelerator\n", "from accelerate.utils import set_seed\n", "from transformers import BertTokenizer\n", "from huggingface_hub import hf_hub_download\n", "from packaging import version\n", "from diffusers.optimization import get_scheduler\n", "\n", "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n", "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n", "from bark.model import GPTConfig, GPT\n", "from bark.model_fine import FineGPT, FineGPTConfig" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Training Args" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "train_batch_size = 8\n", "eval_batch_size = 8\n", "grad_accum = 1\n", "ckpt_path = 'models/fine_2.pt'\n", "model_type = \"fine\"\n", "dataset_path = 'datasets/joe_biden_state_of_union/'\n", "logging_dir = 'logs/'\n", "log_with = 'wandb'\n", "hubert_path = 'data/models/hubert/hubert.pt'\n", "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n", "\n", "output_dir = 'fine_output/'\n", "resume_from_checkpoint = None\n", "\n", "checkpointing_steps = 1000\n", "\n", "mixed_precision = 'bf16'\n", "bits = 16 #4 4 and 8 bit are a work in progress\n", "compute_dtype = torch.bfloat16\n", "double_quant = True\n", "quant_type = 'nf4'\n", "\n", "lora_dim = 64\n", "lora_scaling = 1\n", "lora_dropout = 0.1\n", "lora_module_name = 'transformer.h'\n", "optimize_lora_params_only = True\n", "\n", "learning_rate = 1e-4\n", "scale_lr = False\n", "use_8bit_adam = False\n", "adam_beta1 = 0.9\n", "adam_beta2 = 0.999\n", "adam_epsilon = 1e-8\n", "weight_decay = 0.01\n", "\n", "llm_int8_skip_modules = None\n", "keep_in_fp32_modules = ['lm_head']\n", "\n", "lr_scheduler_type = 'linear'\n", "lr_warmup_steps = 100\n", "num_train_epochs = 5\n", "max_train_steps = None\n", "max_grad_norm = 1.0\n", "\n", "semantic_cross_entropy_loss_weight = 0\n", "\n", "seed = 741" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Define Functions" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "c:\\Users\\labou\\AppData\\Local\\Programs\\Python\\Python310\\lib\\site-packages\\accelerate\\accelerator.py:258: FutureWarning: `logging_dir` is deprecated and will be removed in version 0.18.0 of 🤗 Accelerate. 
Use `project_dir` instead.\n", " warnings.warn(\n" ] } ], "source": [ "CONTEXT_WINDOW_SIZE = 1024\n", "\n", "MAX_SEMANTIC_LEN = 256\n", "\n", "SEMANTIC_RATE_HZ = 49.9\n", "SEMANTIC_VOCAB_SIZE = 10_000\n", "\n", "TEXT_ENCODING_OFFSET = 10_048\n", "SEMANTIC_PAD_TOKEN = 10_000\n", "TEXT_PAD_TOKEN = 129_595\n", "SEMANTIC_INFER_TOKEN = 129_599\n", "\n", "MAX_COARSE_LEN = 768\n", "\n", "SAMPLE_RATE = 24_000\n", "CHANNELS = 1\n", "\n", "COARSE_SEMANTIC_PAD_TOKEN = 12_048\n", "COARSE_INFER_TOKEN = 12_050\n", "\n", "CODEBOOK_SIZE = 1024\n", "N_COARSE_CODEBOOKS = 2\n", "N_FINE_CODEBOOKS = 8\n", "COARSE_RATE_HZ = 75\n", "\n", "logger = logging.getLogger(__name__)\n", "\n", "\n", "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n", "\n", "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n", "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n", "\n", "\n", "def _clear_cuda_cache():\n", " if torch.cuda.is_available():\n", " torch.cuda.empty_cache()\n", " torch.cuda.synchronize()\n", "\n", "\n", "def _md5(fname):\n", " hash_md5 = hashlib.md5()\n", " with open(fname, \"rb\") as f:\n", " for chunk in iter(lambda: f.read(4096), b\"\"):\n", " hash_md5.update(chunk)\n", " return hash_md5.hexdigest()\n", "\n", "\n", "def _download(from_hf_path, file_name, to_local_path):\n", " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n", " path = '/'.join(to_local_path.split(\"/\")[:-1])\n", " os.makedirs(path, exist_ok=True)\n", " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n", " os.replace(os.path.join(path, file_name), to_local_path)\n", "\n", "\n", "def _tokenize(tokenizer, text):\n", " return tokenizer.encode(text, add_special_tokens=False)\n", "\n", "\n", "def _detokenize(tokenizer, enc_text):\n", " return tokenizer.decode(enc_text)\n", "\n", "\n", "def _normalize_whitespace(text):\n", " return re.sub(r\"\\s+\", \" \", text).strip()\n", "\n", "\n", "REMOTE_MODEL_PATHS = {\n", " \"text_small\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"text.pt\",\n", " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n", " },\n", " \"coarse_small\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"coarse.pt\",\n", " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n", " },\n", " \"fine_small\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"fine.pt\",\n", " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n", " },\n", " \"text\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"text_2.pt\",\n", " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n", " },\n", " \"coarse\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"coarse_2.pt\",\n", " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n", " },\n", " \"fine\": {\n", " \"repo_id\": \"suno/bark\",\n", " \"file_name\": \"fine_2.pt\",\n", " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n", " },\n", "}\n", "\n", "\n", "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n", " if model_type == \"text\":\n", " ConfigClass = GPTConfig\n", " ModelClass = GPT\n", " elif model_type == \"coarse\":\n", " ConfigClass = GPTConfig\n", " ModelClass = GPT\n", " elif model_type == \"fine\":\n", " ConfigClass = FineGPTConfig\n", " ModelClass = FineGPT\n", " else:\n", " raise NotImplementedError()\n", " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n", " model_info = REMOTE_MODEL_PATHS[model_key]\n", " if ckpt_path in [None, '']:\n", " 
ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n", " if not os.path.exists(ckpt_path):\n", " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n", " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n", " checkpoint = torch.load(ckpt_path, map_location=device)\n", " # this is a hack\n", " model_args = checkpoint[\"model_args\"]\n", " if \"input_vocab_size\" not in model_args:\n", " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n", " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n", " del model_args[\"vocab_size\"]\n", " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n", " model = ModelClass(gptconf)\n", " state_dict = checkpoint[\"model\"]\n", " # fixup checkpoint\n", " unwanted_prefix = \"_orig_mod.\"\n", " for k, v in list(state_dict.items()):\n", " if k.startswith(unwanted_prefix):\n", " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n", " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n", " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n", " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n", " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n", " if len(extra_keys) != 0:\n", " raise ValueError(f\"extra keys found: {extra_keys}\")\n", " if len(missing_keys) != 0:\n", " raise ValueError(f\"missing keys: {missing_keys}\")\n", " model.load_state_dict(state_dict, strict=False)\n", " n_params = model.get_num_params()\n", " val_loss = checkpoint[\"best_val_loss\"].item()\n", " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n", " del checkpoint, state_dict\n", " _clear_cuda_cache()\n", " if model_type == \"text\":\n", " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n", " return model, tokenizer\n", " return model\n", "\n", "\n", "def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n", " assert len(arr.shape) == 2\n", " arr = arr.copy()\n", " if offset_size is not None:\n", " for n in range(1, arr.shape[0]):\n", " arr[n, :] += offset_size * n\n", " flat_arr = arr.ravel(\"F\")\n", " return flat_arr\n", "\n", "\n", "def load_filepaths_and_text(filename, split=\"|\"):\n", " with open(filename, encoding='utf-8') as f:\n", " filepaths_and_text = [line.strip().split(split) for line in f]\n", " base = os.path.dirname(filename)\n", " for j in range(len(filepaths_and_text)):\n", " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n", " return filepaths_and_text\n", "\n", "\n", "class TtsDataset(torch.utils.data.Dataset):\n", " def __init__(self, opt):\n", " self.path = os.path.dirname(opt['path'])\n", " self.mode = opt['mode']\n", " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '.txt'))\n", "\n", " def __getitem__(self, index):\n", " audiopath_and_text = self.audiopaths_and_text[index]\n", " audiopath = audiopath_and_text[0]\n", "\n", " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n", " fine_tokens = tokens['fine']\n", "\n", " return torch.from_numpy(fine_tokens)\n", "\n", " def __len__(self):\n", " return len(self.audiopaths_and_text)\n", "\n", "\n", "class TtsCollater():\n", " def __init__(self):\n", " pass\n", " def __call__(self, batch):\n", " max_len = 1024\n", " fine_tokens = []\n", "\n", " for fine_tokens_ in batch:\n", " if fine_tokens_.shape[1] > max_len:\n", " start_idx = np.random.randint(0, 
fine_tokens_.shape[1] - max_len + 1)\n", " fine_tokens_ = fine_tokens_[:, start_idx : start_idx + max_len]\n", "\n", " pad_size = max_len - fine_tokens_.shape[1]\n", " fine_tokens_ = F.pad(fine_tokens_, (0, pad_size), value=CODEBOOK_SIZE)\n", "\n", " fine_tokens_ = fine_tokens_.T\n", "\n", " fine_tokens.append(fine_tokens_)\n", "\n", " return {'fine_tokens': torch.stack(fine_tokens).contiguous()}\n", " \n", "\n", "accelerator = Accelerator(\n", " gradient_accumulation_steps=grad_accum,\n", " mixed_precision=mixed_precision,\n", " log_with=log_with,\n", " logging_dir=logging_dir,\n", ")\n", "device = accelerator.device\n", "\n", "os.makedirs(output_dir, exist_ok=True)\n", "\n", "set_seed(seed)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Setup Dataset (only need to do this once)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "# max_duration_sec = 15.12 # the maximum allowed duration in seconds\n", "\n", "# path = dataset_path\n", "\n", "# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n", "# from hubert.hubert_manager import HuBERTManager\n", "# hubert_manager = HuBERTManager()\n", "# from hubert.pre_kmeans_hubert import CustomHubert\n", "# from hubert.customtokenizer import CustomTokenizer\n", "# hubert_manager.make_sure_hubert_installed()\n", "# hubert_manager.make_sure_tokenizer_installed()\n", "\n", "# # Load the HuBERT model\n", "# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n", "# hubert_model.eval()\n", "# for param in hubert_model.parameters():\n", "# param.requires_grad = False\n", "\n", "# # Load the CustomTokenizer model\n", "# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n", "\n", "# from bark.generation import load_codec_model\n", "# codec_model = load_codec_model(use_gpu=True)\n", "# codec_model.eval()\n", "# for param in codec_model.parameters():\n", "# param.requires_grad = False\n", "\n", "\n", "# def get_duration(wav, sr):\n", "# return wav.shape[1] / sr\n", "\n", "# valid_lines_train = []\n", "# # convert wavs to semantic tokens\n", "# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n", "# wav, sr = torchaudio.load(wav_path)\n", "# if not get_duration(wav, sr) > max_duration_sec:\n", "# valid_lines_train.append((wav_path, txt))\n", "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", "\n", "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", "\n", "# # save semantic tokens\n", "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", "# semantic_tokens = semantic_tokens.cpu().numpy()\n", "\n", "# # Extract discrete codes from EnCodec\n", "# with torch.no_grad():\n", "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", "\n", "# # move codes to cpu\n", "# codes = codes.cpu().numpy()\n", "\n", "# # save tokens\n", "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", "\n", "# # rewrite train.txt with valid lines\n", "# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n", "# for wav_path, txt in valid_lines_train:\n", "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", 
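"# # each line keeps the 'relative_wav_path|transcript' layout that load_filepaths_and_text parses back in\n",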
"# f.write(f'{wav_path}|{txt}\\n')\n", "\n", "# valid_lines_valid = []\n", "# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n", "# wav, sr = torchaudio.load(wav_path)\n", "# if not get_duration(wav, sr) > max_duration_sec:\n", "# valid_lines_valid.append((wav_path, txt))\n", "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n", "\n", "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n", "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n", "\n", "# # save semantic tokens\n", "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n", "# semantic_tokens = semantic_tokens.cpu().numpy()\n", " \n", "# # Extract discrete codes from EnCodec\n", "# with torch.no_grad():\n", "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n", "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n", "\n", "# # move codes to cpu\n", "# codes = codes.cpu().numpy()\n", "\n", "# # save tokens\n", "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n", "\n", "# # rewrite valid.txt with valid lines\n", "# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n", "# for wav_path, txt in valid_lines_valid:\n", "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n", "# f.write(f'{wav_path}|{txt}\\n')\n", "\n", "# del hubert_model\n", "# del hubert_tokenizer\n", "# del codec_model\n", "# gc.collect()\n", "# torch.cuda.empty_cache()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loaded fine model with 302090240 params, val_loss=2.0786.\n" ] } ], "source": [ "model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "if scale_lr:\n", " learning_rate = (\n", " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n", " )\n", "\n", "if use_8bit_adam:\n", " try:\n", " import bitsandbytes as bnb\n", " except ImportError:\n", " raise ImportError(\n", " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n", " )\n", "\n", " optimizer_class = bnb.optim.AdamW8bit\n", "else:\n", " optimizer_class = torch.optim.AdamW" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "quantization_config=BitsAndBytesConfig(\n", " load_in_4bit=bits == 4,\n", " load_in_8bit=bits == 8,\n", " llm_int8_threshold=6.0,\n", " llm_int8_has_fp16_weight=False,\n", " bnb_4bit_compute_dtype=compute_dtype,\n", " bnb_4bit_use_double_quant=double_quant,\n", " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n", ")\n", "\n", "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n", "# if quantization_config.load_in_8bit:\n", "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n", "# elif quantization_config.load_in_4bit:\n", "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n", "\n", "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n", "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n", "# modules_to_not_convert = [] # 
get_keys_to_not_convert(model)\n", "# else:\n", "# modules_to_not_convert = llm_int8_skip_modules\n", "\n", "# if not isinstance(modules_to_not_convert, list):\n", "# modules_to_not_convert = [modules_to_not_convert]\n", "\n", "# modules_to_not_convert.extend(keep_in_fp32_modules)\n", "\n", "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n", "\n", "# if quantization_config.load_in_4bit and not supports_4bit:\n", "# raise ValueError(\n", "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n", "# \" make sure you have the latest version of `bitsandbytes` installed\"\n", "# )\n", " \n", "# if len(modules_to_not_convert) == 0:\n", "# modules_to_not_convert = None\n", "\n", "# model = replace_with_bnb_linear(\n", "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n", "# )\n", "\n", "# # training in 8-bit is only available in 0.37.0+\n", "# model._is_kbit_training_enabled = version.parse(\n", "# importlib_metadata.version(\"bitsandbytes\")\n", "# ) >= version.parse(\"0.37.0\")\n", "\n", "# model.config.quantization_config = quantization_config" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "if bits == 4:\n", " from accelerate.utils import CustomDtype\n", " target_dtype = CustomDtype.INT4\n", "elif bits == 8:\n", " target_dtype = torch.int8\n", "\n", "if lora_dim > 0:\n", " for param in model.parameters():\n", " if param.ndim == 1:\n", " # cast the small parameters (e.g. layernorm) to fp32 for stability\n", " param.data = param.data.to(torch.float32)\n", " \n", " class CastOutputToFloat(nn.Sequential):\n", " def forward(self, x):\n", " return super().forward(x).to(torch.float32)\n", "\n", " # model.lm_head = CastOutputToFloat(model.lm_head)\n", " for i, lm_head in enumerate(model.lm_heads):\n", " model.lm_heads[i] = CastOutputToFloat(lm_head)\n", "\n", " model = convert_linear_layer_to_lora(model, lora_module_name,\n", " lora_dim=lora_dim, lora_scaling=lora_scaling,\n", " lora_dropout=lora_dropout)\n", " if optimize_lora_params_only:\n", " model = only_optimize_lora_parameters(model)" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "params_to_optimize = (\n", " param for param in model.parameters() if param.requires_grad\n", " )\n", "\n", "optimizer = optimizer_class(\n", " params_to_optimize,\n", " lr=learning_rate,\n", " betas=(adam_beta1, adam_beta2),\n", " weight_decay=weight_decay,\n", " eps=adam_epsilon,\n", ")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "opt_train = {\n", " 'path': dataset_path,\n", " 'mode': 'train',\n", "}\n", "\n", "opt_val = {\n", " 'path': dataset_path,\n", " 'mode': 'valid',\n", "}\n", "\n", "train_dataset = TtsDataset(opt_train)\n", "validation_dataset = TtsDataset(opt_val)\n", "\n", "train_dataloader = torch.utils.data.DataLoader(\n", " train_dataset,\n", " batch_size=train_batch_size,\n", " collate_fn=TtsCollater(),\n", ")\n", "\n", "validation_dataloader = torch.utils.data.DataLoader(\n", " validation_dataset,\n", " batch_size=eval_batch_size,\n", " collate_fn=TtsCollater(),\n", ")\n", "\n", "criterion = torch.nn.CrossEntropyLoss(ignore_index=COARSE_SEMANTIC_PAD_TOKEN)\n", "\n", "# Scheduler and math around the number of training steps.\n", "overrode_max_train_steps = False\n", "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", "if 
max_train_steps is None:\n", " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", " overrode_max_train_steps = True\n", "\n", "lr_scheduler = get_scheduler(\n", " lr_scheduler_type,\n", " optimizer=optimizer,\n", " num_warmup_steps=lr_warmup_steps * grad_accum,\n", " num_training_steps=max_train_steps * grad_accum,\n", ")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n", " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n", ")\n", "accelerator.register_for_checkpointing(lr_scheduler)\n", "\n", "weight_dtype = torch.float32\n", "if accelerator.mixed_precision == \"fp16\":\n", " weight_dtype = torch.float16\n", "elif accelerator.mixed_precision == \"bf16\":\n", " weight_dtype = torch.bfloat16" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "ERROR:wandb.jupyter:Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mfrancislabounty\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "wandb version 0.15.4 is available! To upgrade, please run:\n", " $ pip install wandb --upgrade" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Tracking run with wandb version 0.13.6" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in e:\\Python\\bark-with-voice-clone\\wandb\\run-20230629_202416-290ebk11" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run fresh-pyramid-26 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n", "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n", "if overrode_max_train_steps:\n", " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n", "# Afterwards we recalculate our number of training epochs\n", "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n", "\n", "# We need to initialize the trackers we use, and also store our configuration.\n", "# The trackers initializes automatically on the main process.\n", "if accelerator.is_main_process:\n", " accelerator.init_trackers(\"bark_coarse\", config={})\n", "\n", "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n", "logger.info(\"***** Running training *****\")\n", "logger.info(f\" Num examples = {len(train_dataset)}\")\n", "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n", "logger.info(f\" Num Epochs = {num_train_epochs}\")\n", "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n", "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n", "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n", "logger.info(f\" Total optimization steps = {max_train_steps}\")\n", "global_step = 0\n", "first_epoch = 0\n", "\n", "if resume_from_checkpoint:\n", " if resume_from_checkpoint != \"latest\":\n", " path = os.path.basename(resume_from_checkpoint)\n", " else:\n", " # Get the most recent checkpoint\n", " dirs = os.listdir(output_dir)\n", " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n", " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n", " path = dirs[-1]\n", " accelerator.print(f\"Resuming from checkpoint {path}\")\n", " accelerator.load_state(os.path.join(output_dir, path))\n", " global_step = int(path.split(\"-\")[1])\n", "\n", " resume_global_step = global_step * grad_accum\n", " first_epoch = resume_global_step // num_update_steps_per_epoch\n", " resume_step = resume_global_step % num_update_steps_per_epoch\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 30.702054630626332 over 82 samples and 11 batches.\n" ] } ], "source": [ "if accelerator.is_main_process:\n", " model.eval()\n", " validation_loss = 0.0\n", " num_batches = 0\n", " num_samples = 0\n", " with torch.no_grad():\n", " for val_step, val_batch in enumerate(validation_dataloader):\n", " # Similar to training, process the validation batch\n", " fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n", " fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n", " fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n", " fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n", "\n", " # Forward pass for validation\n", " logits_7 = model(6, fine_tokens_input_7)\n", " logits_8 = model(7, fine_tokens_input_8)\n", "\n", " # Calculate the validation loss\n", " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n", " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n", "\n", " loss = loss_7 + loss_8\n", " validation_loss += 
loss.item()\n", " num_batches += 1\n", " num_samples += val_batch['fine_tokens'].size(0)\n", "\n", " average_validation_loss = validation_loss / num_batches\n", " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3761107b0c094d2db6532410a582408c", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/205 [00:00(success)." ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "45d15ef5bc1e4729aaf827aa3380823d", "version_major": 2, "version_minor": 0 }, "text/plain": [ "VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.0, max…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "\n", "

Run history:
loss █▆█▅▅▇▆▆▆▇▆▆▅▄▃▃▄▂▃▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr   ▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▆▇▇███▇▇▇▆▆▆▅▅▅▄▄▄▃▃▂▂▂▁▁

Run summary:
loss 3.18231
lr   0.0

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Synced fresh-pyramid-26: https://wandb.ai/francislabounty/bark_coarse/runs/290ebk11
Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Find logs at: .\\wandb\\run-20230629_202416-290ebk11\\logs" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Only show the progress bar once on each machine.\n", "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n", "progress_bar.set_description(\"Steps\")\n", "\n", "for epoch in range(first_epoch, num_train_epochs):\n", " model.train()\n", " for step, batch in enumerate(train_dataloader):\n", " # Skip steps until we reach the resumed step\n", " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n", " if step % grad_accum == 0:\n", " progress_bar.update(1)\n", " continue\n", "\n", " with accelerator.accumulate(model):\n", " fine_targets_7 = batch['fine_tokens'][:, :, 6]\n", " fine_tokens_input_7 = torch.cat([batch['fine_tokens'][:, :, :6], torch.zeros_like(batch['fine_tokens'][:, :, 6:])], dim=2)\n", " fine_targets_8 = batch['fine_tokens'][:, :, 7]\n", " fine_tokens_input_8 = torch.cat([batch['fine_tokens'][:, :, :7], torch.zeros_like(batch['fine_tokens'][:, :, 7:])], dim=2)\n", "\n", " # Forward pass\n", " logits_7 = model(6, fine_tokens_input_7)\n", " logits_8 = model(7, fine_tokens_input_8)\n", "\n", " # Calculate the loss\n", " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n", " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n", "\n", " loss = loss_7 + loss_8\n", "\n", " accelerator.backward(loss)\n", " if accelerator.sync_gradients:\n", " params_to_clip = (\n", " param for param in model.parameters() if param.requires_grad\n", " )\n", " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", " # Checks if the accelerator has performed an optimization step behind the scenes\n", " if accelerator.sync_gradients:\n", " progress_bar.update(1)\n", " global_step += 1\n", "\n", " if global_step % checkpointing_steps == 0:\n", " if accelerator.is_main_process:\n", " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n", " accelerator.save_state(save_path)\n", " logger.info(f\"Saved state to {save_path}\")\n", "\n", " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n", " progress_bar.set_postfix(**logs)\n", " accelerator.log(logs, step=global_step)\n", "\n", " if global_step >= max_train_steps:\n", " break\n", " \n", " accelerator.wait_for_everyone()\n", "\n", "if accelerator.is_main_process:\n", " if lora_dim > 0:\n", " model = convert_lora_to_linear_layer(model)\n", " # save model\n", " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n", " \n", " config = model.config.__dict__\n", " # save config\n", " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n", " json.dump(config, f, indent=2)\n", "\n", "accelerator.end_training()" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# Validation" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation Loss: 3.041703635996038 over 82 samples and 11 batches.\n" ] } ], "source": [ "if accelerator.is_main_process:\n", " model.eval()\n", " 
validation_loss = 0.0\n", " num_batches = 0\n", " num_samples = 0\n", " with torch.no_grad():\n", " for val_step, val_batch in enumerate(validation_dataloader):\n", " # Similar to training, process the validation batch\n", " fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n", " fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n", " fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n", " fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n", "\n", " # Forward pass for validation\n", " logits_7 = model(6, fine_tokens_input_7)\n", " logits_8 = model(7, fine_tokens_input_8)\n", "\n", " # Calculate the validation loss\n", " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n", " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n", "\n", " loss = loss_7 + loss_8\n", " validation_loss += loss.item()\n", " num_batches += 1\n", " num_samples += val_batch['fine_tokens'].size(0)\n", "\n", " average_validation_loss = validation_loss / num_batches\n", " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n", " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.8" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }