{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import os\n",
"import re\n",
"import gc\n",
"import json\n",
"import math\n",
"import hashlib\n",
"import numpy as np\n",
"import logging\n",
"import torchaudio\n",
"from tqdm.auto import tqdm\n",
"import torch.nn.functional as F\n",
"from encodec.utils import convert_audio\n",
"from accelerate import Accelerator\n",
"from accelerate.utils import set_seed\n",
"from transformers import BertTokenizer\n",
"from huggingface_hub import hf_hub_download\n",
"from packaging import version\n",
"from diffusers.optimization import get_scheduler\n",
"\n",
"from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
"from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
"from bark.model import GPTConfig, GPT\n",
"from bark.model_fine import FineGPT, FineGPTConfig"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training Args"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_batch_size = 8\n",
"eval_batch_size = 8\n",
"grad_accum = 2\n",
"ckpt_path = 'models/fine_2.pt'\n",
"model_type = \"fine\"\n",
"dataset_path = 'datasets/joe_biden_state_of_union/'\n",
"logging_dir = 'logs/'\n",
"log_with = 'wandb'\n",
"hubert_path = 'data/models/hubert/hubert.pt'\n",
"hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
"\n",
"output_dir = 'fine_output/'\n",
"resume_from_checkpoint = None\n",
"\n",
"checkpointing_steps = 1000\n",
"\n",
"mixed_precision = 'bf16'\n",
"bits = 16  # 16, 8, or 4; the 4- and 8-bit modes are a work in progress\n",
"compute_dtype = torch.bfloat16\n",
"double_quant = True\n",
"quant_type = 'nf4'\n",
"\n",
"lora_dim = 64\n",
"lora_scaling = 1\n",
"lora_dropout = 0.1\n",
"lora_module_name = 'transformer.h'\n",
"optimize_lora_params_only = False\n",
"\n",
"learning_rate = 1e-4\n",
"scale_lr = False\n",
"use_8bit_adam = False\n",
"adam_beta1 = 0.9\n",
"adam_beta2 = 0.999\n",
"adam_epsilon = 1e-8\n",
"weight_decay = 0.01\n",
"\n",
"llm_int8_skip_modules = None\n",
"keep_in_fp32_modules = ['lm_head']\n",
"\n",
"lr_scheduler_type = 'linear'\n",
"lr_warmup_steps = 60\n",
"num_train_epochs = 5\n",
"max_train_steps = None\n",
"max_grad_norm = 1.0\n",
"\n",
"seed = 741"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Define Functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"CONTEXT_WINDOW_SIZE = 1024\n",
"\n",
"MAX_SEMANTIC_LEN = 256\n",
"\n",
"SEMANTIC_RATE_HZ = 49.9\n",
"SEMANTIC_VOCAB_SIZE = 10_000\n",
"\n",
"TEXT_ENCODING_OFFSET = 10_048\n",
"SEMANTIC_PAD_TOKEN = 10_000\n",
"TEXT_PAD_TOKEN = 129_595\n",
"SEMANTIC_INFER_TOKEN = 129_599\n",
"\n",
"MAX_COARSE_LEN = 768\n",
"\n",
"SAMPLE_RATE = 24_000\n",
"CHANNELS = 1\n",
"\n",
"COARSE_SEMANTIC_PAD_TOKEN = 12_048\n",
"COARSE_INFER_TOKEN = 12_050\n",
"\n",
"CODEBOOK_SIZE = 1024\n",
"N_COARSE_CODEBOOKS = 2\n",
"N_FINE_CODEBOOKS = 8\n",
"COARSE_RATE_HZ = 75\n",
"\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
"\n",
"default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
"CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
"\n",
"\n",
"def _clear_cuda_cache():\n",
"    if torch.cuda.is_available():\n",
"        torch.cuda.empty_cache()\n",
"        torch.cuda.synchronize()\n",
"\n",
"\n",
"def _md5(fname):\n",
"    hash_md5 = hashlib.md5()\n",
"    with open(fname, \"rb\") as f:\n",
"        for chunk in iter(lambda: f.read(4096), b\"\"):\n",
"            hash_md5.update(chunk)\n",
"    return hash_md5.hexdigest()\n",
"\n",
"\n",
"def _download(from_hf_path, file_name, to_local_path):\n",
"    to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
"    path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
"    os.makedirs(path, exist_ok=True)\n",
"    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
"    os.replace(os.path.join(path, file_name), to_local_path)\n",
"\n",
"\n",
"def _tokenize(tokenizer, text):\n",
"    return tokenizer.encode(text, add_special_tokens=False)\n",
"\n",
"\n",
"def _detokenize(tokenizer, enc_text):\n",
"    return tokenizer.decode(enc_text)\n",
"\n",
"\n",
"def _normalize_whitespace(text):\n",
"    return re.sub(r\"\\s+\", \" \", text).strip()\n",
"\n",
"\n",
"REMOTE_MODEL_PATHS = {\n",
"    \"text_small\": {\n",
"        \"repo_id\": \"suno/bark\",\n",
"        \"file_name\": \"text.pt\",\n",
"        \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
"    },\n",
"    \"coarse_small\": {\n",
"        \"repo_id\": \"suno/bark\",\n",
"        \"file_name\": \"coarse.pt\",\n",
"        \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
"    },\n",
"    \"fine_small\": {\n",
"        \"repo_id\": \"suno/bark\",\n",
"        \"file_name\": \"fine.pt\",\n",
"        \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
"    },\n",
"    \"text\": {\n",
"        \"repo_id\": \"suno/bark\",\n",
"        \"file_name\": \"text_2.pt\",\n",
"        \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
"    },\n",
"    \"coarse\": {\n",
"        \"repo_id\": \"suno/bark\",\n",
"        \"file_name\": \"coarse_2.pt\",\n",
"        \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
"    },\n",
"    \"fine\": {\n",
"        \"repo_id\": \"suno/bark\",\n",
"        \"file_name\": \"fine_2.pt\",\n",
"        \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
"    },\n",
"}\n",
"\n",
"\n",
"def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
"    if model_type == \"text\":\n",
"        ConfigClass = GPTConfig\n",
"        ModelClass = GPT\n",
"    elif model_type == \"coarse\":\n",
"        ConfigClass = GPTConfig\n",
"        ModelClass = GPT\n",
"    elif model_type == \"fine\":\n",
"        ConfigClass = FineGPTConfig\n",
"        ModelClass = FineGPT\n",
"    else:\n",
"        raise NotImplementedError()\n",
"    model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
"    model_info = REMOTE_MODEL_PATHS[model_key]\n",
"    if ckpt_path in [None, '']:\n",
"        ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
"    if not os.path.exists(ckpt_path):\n",
"        logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
"        _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
"    checkpoint = torch.load(ckpt_path, map_location=device)\n",
"    # older checkpoints store a single vocab_size; split it into input/output vocab sizes\n",
"    model_args = checkpoint[\"model_args\"]\n",
"    if \"input_vocab_size\" not in model_args:\n",
"        model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
"        model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
"        del model_args[\"vocab_size\"]\n",
"    gptconf = ConfigClass(**model_args)\n",
"    model = ModelClass(gptconf)\n",
"    state_dict = checkpoint[\"model\"]\n",
"    # fixup checkpoint\n",
"    unwanted_prefix = \"_orig_mod.\"\n",
"    for k, v in list(state_dict.items()):\n",
"        if k.startswith(unwanted_prefix):\n",
"            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)\n",
"    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
"    extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
"    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
"    missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
"    if len(extra_keys) != 0:\n",
"        raise ValueError(f\"extra keys found: {extra_keys}\")\n",
"    if len(missing_keys) != 0:\n",
"        raise ValueError(f\"missing keys: {missing_keys}\")\n",
"    model.load_state_dict(state_dict, strict=False)\n",
"    n_params = model.get_num_params()\n",
"    val_loss = checkpoint[\"best_val_loss\"].item()\n",
"    print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
"    del checkpoint, state_dict\n",
"    _clear_cuda_cache()\n",
"    if model_type == \"text\":\n",
"        tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
"        return model, tokenizer\n",
"    return model\n",
"\n",
"\n",
"def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n",
"    assert len(arr.shape) == 2\n",
"    arr = arr.copy()\n",
"    if offset_size is not None:\n",
"        for n in range(1, arr.shape[0]):\n",
"            arr[n, :] += offset_size * n\n",
"    flat_arr = arr.ravel(\"F\")\n",
"    return flat_arr\n",
"\n",
"\n",
"def load_filepaths_and_text(filename, split=\"|\"):\n",
"    with open(filename, encoding='utf-8', errors='ignore') as f:\n",
"        filepaths_and_text = [line.strip().split(split) for line in f]\n",
"    base = os.path.dirname(filename)\n",
"    for j in range(len(filepaths_and_text)):\n",
"        filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
"    return filepaths_and_text\n",
"\n",
"\n",
"class TtsDataset(torch.utils.data.Dataset):\n",
"    def __init__(self, opt):\n",
"        self.path = os.path.dirname(opt['path'])\n",
"        self.mode = opt['mode']\n",
"        self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'], opt['mode'] + '.txt'))\n",
"\n",
"    def __getitem__(self, index):\n",
"        audiopath_and_text = self.audiopaths_and_text[index]\n",
"        audiopath = audiopath_and_text[0]\n",
"\n",
"        tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
"        fine_tokens = tokens['fine']\n",
"\n",
"        return torch.from_numpy(fine_tokens)\n",
"\n",
"    def __len__(self):\n",
"        return len(self.audiopaths_and_text)\n",
"\n",
"\n",
"class TtsCollater():\n",
"    def __init__(self):\n",
"        pass\n",
"\n",
"    def __call__(self, batch):\n",
"        max_len = 1024\n",
"        fine_tokens = []\n",
"\n",
"        for fine_tokens_ in batch:\n",
"            if fine_tokens_.shape[1] > max_len:\n",
"                start_idx = np.random.randint(0, fine_tokens_.shape[1] - max_len + 1)\n",
"                fine_tokens_ = fine_tokens_[:, start_idx : start_idx + max_len]\n",
"\n",
"            pad_size = max_len - fine_tokens_.shape[1]\n",
"            fine_tokens_ = F.pad(fine_tokens_, (0, pad_size), value=CODEBOOK_SIZE)\n",
"\n",
"            fine_tokens_ = fine_tokens_.T\n",
"\n",
"            fine_tokens.append(fine_tokens_)\n",
"\n",
"        return {'fine_tokens': torch.stack(fine_tokens).contiguous()}\n",
"\n",
"\n",
"accelerator = Accelerator(\n",
"    gradient_accumulation_steps=grad_accum,\n",
"    mixed_precision=mixed_precision,\n",
"    log_with=log_with,\n",
"    logging_dir=logging_dir,\n",
")\n",
"device = accelerator.device\n",
"\n",
"os.makedirs(output_dir, exist_ok=True)\n",
"\n",
"set_seed(seed)"
]
},
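{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on `_flatten_codebooks`: it offsets each codebook row by `offset_size * n` and then interleaves the rows column by column (Fortran-order ravel). The cell below is a minimal sketch on a made-up 2x3 array; the values are illustrative only."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy check of _flatten_codebooks on a hypothetical 2x3 array of codes.\n",
"demo = np.array([[1, 2, 3], [4, 5, 6]])\n",
"# Row 1 is shifted by CODEBOOK_SIZE (1024), then columns interleave:\n",
"# 1, 4+1024, 2, 5+1024, 3, 6+1024\n",
"print(_flatten_codebooks(demo))  # expected: [   1 1028    2 1029    3 1030]"
]
},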
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup Dataset (only need to do this once)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# max_duration_sec = 15.12  # the maximum allowed duration in seconds\n",
"\n",
"# path = dataset_path\n",
"\n",
"# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
"# from hubert.hubert_manager import HuBERTManager\n",
"# hubert_manager = HuBERTManager()\n",
"# from hubert.pre_kmeans_hubert import CustomHubert\n",
"# from hubert.customtokenizer import CustomTokenizer\n",
"# hubert_manager.make_sure_hubert_installed()\n",
"# hubert_manager.make_sure_tokenizer_installed()\n",
"\n",
"# # Load the HuBERT model\n",
"# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
"# hubert_model.eval()\n",
"# for param in hubert_model.parameters():\n",
"#     param.requires_grad = False\n",
"\n",
"# # Load the CustomTokenizer model\n",
"# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device)  # Automatically uses the right layers\n",
"\n",
"# from bark.generation import load_codec_model\n",
"# codec_model = load_codec_model(use_gpu=True)\n",
"# codec_model.eval()\n",
"# for param in codec_model.parameters():\n",
"#     param.requires_grad = False\n",
"\n",
"\n",
"# def get_duration(wav, sr):\n",
"#     return wav.shape[1] / sr\n",
"\n",
"# valid_lines_train = []\n",
"# # convert wavs to semantic tokens\n",
"# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
"#     wav, sr = torchaudio.load(wav_path)\n",
"#     if not get_duration(wav, sr) > max_duration_sec:\n",
"#         valid_lines_train.append((wav_path, txt))\n",
"#     wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
"\n",
"#     semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
"#     semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
"\n",
"#     # save semantic tokens\n",
"#     os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
"#     semantic_tokens = semantic_tokens.cpu().numpy()\n",
"\n",
"#     # Extract discrete codes from EnCodec\n",
"#     with torch.no_grad():\n",
"#         encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
"#     codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]\n",
"\n",
"#     # move codes to cpu\n",
"#     codes = codes.cpu().numpy()\n",
"\n",
"#     # save tokens\n",
"#     np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
"\n",
"# # rewrite train.txt with valid lines\n",
"# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
"#     for wav_path, txt in valid_lines_train:\n",
"#         wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
"#         f.write(f'{wav_path}|{txt}\\n')\n",
"\n",
"# valid_lines_valid = []\n",
"# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
"#     wav, sr = torchaudio.load(wav_path)\n",
"#     if not get_duration(wav, sr) > max_duration_sec:\n",
"#         valid_lines_valid.append((wav_path, txt))\n",
"#     wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
"\n",
"#     semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
"#     semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
"\n",
"#     # save semantic tokens\n",
"#     os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
"#     semantic_tokens = semantic_tokens.cpu().numpy()\n",
"\n",
"#     # Extract discrete codes from EnCodec\n",
"#     with torch.no_grad():\n",
"#         encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
"#     codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]\n",
"\n",
"#     # move codes to cpu\n",
"#     codes = codes.cpu().numpy()\n",
"\n",
"#     # save tokens\n",
"#     np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
"\n",
"# # rewrite valid.txt with valid lines\n",
"# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
"#     for wav_path, txt in valid_lines_valid:\n",
"#         wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
"#         f.write(f'{wav_path}|{txt}\\n')\n",
"\n",
"# del hubert_model\n",
"# del hubert_tokenizer\n",
"# del codec_model\n",
"# gc.collect()\n",
"# torch.cuda.empty_cache()"
]
},
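{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As an optional, hedged check after the one-time prep above (assuming it has been run and a `tokens/` folder now exists under `dataset_path`), the cell below opens one saved `.npz` and prints the array shapes. Expect `fine` to be `[n_q, T]` with `n_q = 8` EnCodec codebooks, `coarse` its first two rows, and `semantic` a 1-D token sequence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at one token file produced by the dataset setup cell; skips quietly if prep hasn't run.\n",
"tokens_dir = os.path.join(dataset_path, 'tokens')\n",
"if os.path.isdir(tokens_dir) and os.listdir(tokens_dir):\n",
"    sample = np.load(os.path.join(tokens_dir, os.listdir(tokens_dir)[0]))\n",
"    for key in sample.files:\n",
"        print(key, sample[key].shape)\n",
"else:\n",
"    print(f'no token files found in {tokens_dir}; run the dataset setup cell first')"
]
},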
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if scale_lr:\n",
"    learning_rate = (\n",
"        learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
"    )\n",
"\n",
"if use_8bit_adam:\n",
"    try:\n",
"        import bitsandbytes as bnb\n",
"    except ImportError:\n",
"        raise ImportError(\n",
"            \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
"        )\n",
"\n",
"    optimizer_class = bnb.optim.AdamW8bit\n",
"else:\n",
"    optimizer_class = torch.optim.AdamW"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"quantization_config = BitsAndBytesConfig(\n",
"    load_in_4bit=bits == 4,\n",
"    load_in_8bit=bits == 8,\n",
"    llm_int8_threshold=6.0,\n",
"    llm_int8_has_fp16_weight=False,\n",
"    bnb_4bit_compute_dtype=compute_dtype,\n",
"    bnb_4bit_use_double_quant=double_quant,\n",
"    bnb_4bit_quant_type=quant_type  # {'fp4', 'nf4'}\n",
")\n",
"\n",
"# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
"#     if quantization_config.load_in_8bit:\n",
"#         logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
"#     elif quantization_config.load_in_4bit:\n",
"#         logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
"\n",
"#     # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
"#     if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
"#         modules_to_not_convert = []  # get_keys_to_not_convert(model)\n",
"#     else:\n",
"#         modules_to_not_convert = llm_int8_skip_modules\n",
"\n",
"#     if not isinstance(modules_to_not_convert, list):\n",
"#         modules_to_not_convert = [modules_to_not_convert]\n",
"\n",
"#     modules_to_not_convert.extend(keep_in_fp32_modules)\n",
"\n",
"#     supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
"\n",
"#     if quantization_config.load_in_4bit and not supports_4bit:\n",
"#         raise ValueError(\n",
"#             \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
"#             \" make sure you have the latest version of `bitsandbytes` installed\"\n",
"#         )\n",
"\n",
"#     if len(modules_to_not_convert) == 0:\n",
"#         modules_to_not_convert = None\n",
"\n",
"#     model = replace_with_bnb_linear(\n",
"#         model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
"#     )\n",
"\n",
"#     # training in 8-bit is only available in 0.37.0+\n",
"#     model._is_kbit_training_enabled = version.parse(\n",
"#         importlib_metadata.version(\"bitsandbytes\")\n",
"#     ) >= version.parse(\"0.37.0\")\n",
"\n",
"#     model.config.quantization_config = quantization_config"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if bits == 4:\n",
"    from accelerate.utils import CustomDtype\n",
"    target_dtype = CustomDtype.INT4\n",
"elif bits == 8:\n",
"    target_dtype = torch.int8\n",
"\n",
"if lora_dim > 0:\n",
"    for param in model.parameters():\n",
"        if param.ndim == 1:\n",
"            # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
"            param.data = param.data.to(torch.float32)\n",
"\n",
"    class CastOutputToFloat(nn.Sequential):\n",
"        def forward(self, x):\n",
"            return super().forward(x).to(torch.float32)\n",
"\n",
"    # model.lm_head = CastOutputToFloat(model.lm_head)\n",
"    for i, lm_head in enumerate(model.lm_heads):\n",
"        model.lm_heads[i] = CastOutputToFloat(lm_head)\n",
"\n",
"    model = convert_linear_layer_to_lora(model, lora_module_name,\n",
"                                         lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
"                                         lora_dropout=lora_dropout)\n",
"    if optimize_lora_params_only:\n",
"        model = only_optimize_lora_parameters(model)"
]
},
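{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A small sketch to verify the LoRA conversion took effect: count trainable vs. total parameters. With `optimize_lora_params_only = False` everything stays trainable, so the two numbers only diverge when that flag is set."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Report trainable vs. total parameter counts after the (optional) LoRA conversion.\n",
"trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in model.parameters())\n",
"print(f'trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)')"
]
},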
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"params_to_optimize = (\n",
"    param for param in model.parameters() if param.requires_grad\n",
")\n",
"\n",
"optimizer = optimizer_class(\n",
"    params_to_optimize,\n",
"    lr=learning_rate,\n",
"    betas=(adam_beta1, adam_beta2),\n",
"    weight_decay=weight_decay,\n",
"    eps=adam_epsilon,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"opt_train = {\n",
"    'path': dataset_path,\n",
"    'mode': 'train',\n",
"}\n",
"\n",
"opt_val = {\n",
"    'path': dataset_path,\n",
"    'mode': 'valid',\n",
"}\n",
"\n",
"train_dataset = TtsDataset(opt_train)\n",
"validation_dataset = TtsDataset(opt_val)\n",
"\n",
"train_dataloader = torch.utils.data.DataLoader(\n",
"    train_dataset,\n",
"    batch_size=train_batch_size,\n",
"    collate_fn=TtsCollater(),\n",
")\n",
"\n",
"validation_dataloader = torch.utils.data.DataLoader(\n",
"    validation_dataset,\n",
"    batch_size=eval_batch_size,\n",
"    collate_fn=TtsCollater(),\n",
")\n",
"\n",
"# the collater pads fine tokens with CODEBOOK_SIZE, so that value is excluded from the loss\n",
"criterion = torch.nn.CrossEntropyLoss(ignore_index=CODEBOOK_SIZE)\n",
"\n",
"# Scheduler and math around the number of training steps.\n",
"overrode_max_train_steps = False\n",
"num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
"if max_train_steps is None:\n",
"    max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
"    overrode_max_train_steps = True\n",
"\n",
"lr_scheduler = get_scheduler(\n",
"    lr_scheduler_type,\n",
"    optimizer=optimizer,\n",
"    num_warmup_steps=lr_warmup_steps * grad_accum,\n",
"    num_training_steps=max_train_steps * grad_accum,\n",
")"
]
},
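{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check on the data pipeline (a sketch, assuming the token files exist): pull one batch and confirm the collater output. `fine_tokens` should be `[train_batch_size, 1024, 8]`, i.e. `[batch, sequence, codebooks]` after the padding and transpose in `TtsCollater`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Peek at one collated batch to confirm shapes before training.\n",
"sample_batch = next(iter(train_dataloader))\n",
"print(sample_batch['fine_tokens'].shape)  # expect [train_batch_size, 1024, N_FINE_CODEBOOKS]\n",
"print(sample_batch['fine_tokens'].dtype)"
]
},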
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
"    model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
")\n",
"accelerator.register_for_checkpointing(lr_scheduler)\n",
"\n",
"weight_dtype = torch.float32\n",
"if accelerator.mixed_precision == \"fp16\":\n",
"    weight_dtype = torch.float16\n",
"elif accelerator.mixed_precision == \"bf16\":\n",
"    weight_dtype = torch.bfloat16"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
"num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
"if overrode_max_train_steps:\n",
"    max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
"# Afterwards we recalculate our number of training epochs\n",
"num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
"\n",
"# We need to initialize the trackers we use, and also store our configuration.\n",
"# The trackers are initialized automatically on the main process.\n",
"if accelerator.is_main_process:\n",
"    accelerator.init_trackers(\"bark_fine\", config={})\n",
"\n",
"total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
"logger.info(\"***** Running training *****\")\n",
"logger.info(f\"  Num examples = {len(train_dataset)}\")\n",
"logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n",
"logger.info(f\"  Num Epochs = {num_train_epochs}\")\n",
"logger.info(f\"  Instantaneous batch size per device = {train_batch_size}\")\n",
"logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
"logger.info(f\"  Gradient Accumulation steps = {grad_accum}\")\n",
"logger.info(f\"  Total optimization steps = {max_train_steps}\")\n",
"global_step = 0\n",
"first_epoch = 0\n",
"\n",
"if resume_from_checkpoint:\n",
"    if resume_from_checkpoint != \"latest\":\n",
"        path = os.path.basename(resume_from_checkpoint)\n",
"    else:\n",
"        # Get the most recent checkpoint\n",
"        dirs = os.listdir(output_dir)\n",
"        dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
"        dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
"        path = dirs[-1]\n",
"    accelerator.print(f\"Resuming from checkpoint {path}\")\n",
"    accelerator.load_state(os.path.join(output_dir, path))\n",
"    global_step = int(path.split(\"-\")[1])\n",
"\n",
"    # global_step counts optimizer updates; multiply by grad_accum to count micro-batches\n",
"    resume_global_step = global_step * grad_accum\n",
"    first_epoch = global_step // num_update_steps_per_epoch\n",
"    resume_step = resume_global_step % (num_update_steps_per_epoch * grad_accum)\n"
]
},
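{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The validation cell below and the training loop share one trick: to train the head for codebook `n` (0-indexed), the batch is split into targets `[:, :, n]` and an input copy in which codebooks `n` and above are zeroed out, so the model only sees the earlier codebooks. A minimal sketch of that masking on a dummy tensor (the batch size and sequence length here are arbitrary; no model is involved):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of the codebook masking used in the validation and training cells.\n",
"# `dummy` stands in for batch['fine_tokens']: shape [batch, seq_len, n_codebooks].\n",
"dummy = torch.randint(0, CODEBOOK_SIZE, (2, 16, N_FINE_CODEBOOKS))\n",
"targets_7 = dummy[:, :, 6]  # codebook index 6 is the prediction target\n",
"inputs_7 = torch.cat([dummy[:, :, :6], torch.zeros_like(dummy[:, :, 6:])], dim=2)  # zero out codebooks 6 and 7\n",
"print(inputs_7.shape, targets_7.shape)  # expect [2, 16, 8] and [2, 16]"
]
},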
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if accelerator.is_main_process:\n",
"    model.eval()\n",
"    validation_loss = 0.0\n",
"    num_batches = 0\n",
"    num_samples = 0\n",
"    with torch.no_grad():\n",
"        for val_step, val_batch in enumerate(validation_dataloader):\n",
"            # Similar to training, process the validation batch\n",
"            fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n",
"            fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n",
"            fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n",
"            fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n",
"\n",
"            # Forward pass for validation\n",
"            logits_7 = model(6, fine_tokens_input_7)\n",
"            logits_8 = model(7, fine_tokens_input_8)\n",
"\n",
"            # Calculate the validation loss\n",
"            loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n",
"            loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n",
"\n",
"            loss = (loss_7 + loss_8) / 2\n",
"            validation_loss += loss.item()\n",
"            num_batches += 1\n",
"            num_samples += val_batch['fine_tokens'].size(0)\n",
"\n",
"    average_validation_loss = validation_loss / num_batches\n",
"    logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
"    print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Only show the progress bar once on each machine.\n",
"progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
"progress_bar.set_description(\"Steps\")\n",
"\n",
"for epoch in range(first_epoch, num_train_epochs):\n",
"    model.train()\n",
"    for step, batch in enumerate(train_dataloader):\n",
"        # Skip steps until we reach the resumed step\n",
"        if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
"            if step % grad_accum == 0:\n",
"                progress_bar.update(1)\n",
"            continue\n",
"\n",
"        with accelerator.accumulate(model):\n",
"            fine_targets_7 = batch['fine_tokens'][:, :, 6]\n",
"            fine_tokens_input_7 = torch.cat([batch['fine_tokens'][:, :, :6], torch.zeros_like(batch['fine_tokens'][:, :, 6:])], dim=2)\n",
"            fine_targets_8 = batch['fine_tokens'][:, :, 7]\n",
"            fine_tokens_input_8 = torch.cat([batch['fine_tokens'][:, :, :7], torch.zeros_like(batch['fine_tokens'][:, :, 7:])], dim=2)\n",
"\n",
"            # Forward pass\n",
"            logits_7 = model(6, fine_tokens_input_7)\n",
"            logits_8 = model(7, fine_tokens_input_8)\n",
"\n",
"            # Calculate the loss\n",
"            loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n",
"            loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n",
"\n",
"            loss = (loss_7 + loss_8) / 2\n",
"\n",
"            accelerator.backward(loss)\n",
"            if accelerator.sync_gradients:\n",
"                params_to_clip = (\n",
"                    param for param in model.parameters() if param.requires_grad\n",
"                )\n",
"                accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
"            optimizer.step()\n",
"            lr_scheduler.step()\n",
"            optimizer.zero_grad()\n",
"\n",
"        # Checks if the accelerator has performed an optimization step behind the scenes\n",
"        if accelerator.sync_gradients:\n",
"            progress_bar.update(1)\n",
"            global_step += 1\n",
"\n",
"            if global_step % checkpointing_steps == 0:\n",
"                if accelerator.is_main_process:\n",
"                    save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
"                    accelerator.save_state(save_path)\n",
"                    logger.info(f\"Saved state to {save_path}\")\n",
"\n",
"        logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
"        progress_bar.set_postfix(**logs)\n",
"        accelerator.log(logs, step=global_step)\n",
"\n",
"        if global_step >= max_train_steps:\n",
"            break\n",
"\n",
"    accelerator.wait_for_everyone()\n",
"\n",
"if accelerator.is_main_process:\n",
"    if lora_dim > 0:\n",
"        model = convert_lora_to_linear_layer(model)\n",
"    # save model\n",
"    accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
"\n",
"    config = model.config.__dict__\n",
"    # save config\n",
"    with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
"        json.dump(config, f, indent=2)\n",
"\n",
"accelerator.end_training()"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# Validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if accelerator.is_main_process:\n",
"    model.eval()\n",
"    validation_loss = 0.0\n",
"    num_batches = 0\n",
"    num_samples = 0\n",
"    with torch.no_grad():\n",
"        for val_step, val_batch in enumerate(validation_dataloader):\n",
"            # Similar to training, process the validation batch\n",
"            fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n",
"            fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n",
"            fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n",
"            fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n",
"\n",
"            # Forward pass for validation\n",
"            logits_7 = model(6, fine_tokens_input_7)\n",
"            logits_8 = model(7, fine_tokens_input_8)\n",
"\n",
"            # Calculate the validation loss\n",
"            loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n",
"            loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n",
"\n",
"            loss = (loss_7 + loss_8) / 2\n",
"            validation_loss += loss.item()\n",
"            num_batches += 1\n",
"            num_samples += val_batch['fine_tokens'].size(0)\n",
"\n",
"    average_validation_loss = validation_loss / num_batches\n",
"    logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
"    print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
]
},
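{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"To use the fine-tuned checkpoint later, it can be reloaded from `output_dir`. This is a minimal sketch, assuming the keys saved in `config.json` match the `FineGPTConfig` fields and that the state dict was saved without distributed wrapper prefixes (single-process training, as in this notebook):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Minimal reload sketch (assumptions: config.json keys match FineGPTConfig,\n",
"# and the state dict was saved without distributed wrappers).\n",
"with open(os.path.join(output_dir, 'config.json')) as f:\n",
"    saved_config = json.load(f)\n",
"reloaded = FineGPT(FineGPTConfig(**saved_config))\n",
"reloaded.load_state_dict(torch.load(os.path.join(output_dir, 'pytorch_model.bin'), map_location='cpu'))\n",
"reloaded.eval()"
]
}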
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.8"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}