{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import os\n",
    "import re\n",
    "import gc\n",
    "import json\n",
    "import math\n",
    "import hashlib\n",
    "import numpy as np\n",
    "import logging\n",
    "import torchaudio\n",
    "from tqdm.auto import tqdm\n",
    "import torch.nn.functional as F\n",
    "from encodec.utils import convert_audio\n",
    "from accelerate import Accelerator\n",
    "from accelerate.utils import set_seed\n",
    "from transformers import BertTokenizer\n",
    "from huggingface_hub import hf_hub_download\n",
    "from packaging import version\n",
    "from diffusers.optimization import get_scheduler\n",
    "\n",
    "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
    "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
    "from bark.model import GPTConfig, GPT\n",
    "from bark.model_fine import FineGPT, FineGPTConfig"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training Args"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_batch_size = 8\n",
    "eval_batch_size = 8\n",
    "grad_accum = 2\n",
    "ckpt_path = 'models/text_2.pt'\n",
    "model_type = \"text\"\n",
    "dataset_path = 'datasets/joe_biden_state_of_union/'\n",
    "logging_dir = 'logs/'\n",
    "log_with = 'wandb'\n",
    "hubert_path = 'data/models/hubert/hubert.pt'\n",
    "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
    "\n",
    "output_dir = 'semantic_output/'\n",
    "resume_from_checkpoint = None\n",
    "\n",
    "checkpointing_steps = 1000\n",
    "\n",
    "mixed_precision = 'bf16'\n",
    "bits = 16  # 4- and 8-bit modes are a work in progress\n",
    "compute_dtype = torch.bfloat16\n",
    "double_quant = True\n",
    "quant_type = 'nf4'\n",
    "\n",
    "lora_dim = 64\n",
    "lora_scaling = 1\n",
    "lora_dropout = 0.1\n",
    "lora_module_name = 'transformer.h'\n",
    "optimize_lora_params_only = False\n",
    "\n",
    "learning_rate = 1e-4\n",
    "scale_lr = False\n",
    "use_8bit_adam = False\n",
    "adam_beta1 = 0.9\n",
    "adam_beta2 = 0.999\n",
    "adam_epsilon = 1e-8\n",
    "weight_decay = 0.01\n",
    "\n",
    "llm_int8_skip_modules = None\n",
    "keep_in_fp32_modules = ['lm_head']\n",
    "\n",
    "lr_scheduler_type = 'linear'\n",
    "lr_warmup_steps = 60\n",
    "num_train_epochs = 5\n",
    "max_train_steps = None\n",
    "max_grad_norm = 1.0\n",
    "\n",
    "seed = 741"
   ]
  },
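  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note on the numbers above: each optimizer step consumes `train_batch_size * grad_accum = 8 * 2 = 16` examples per process (times `accelerator.num_processes` on multi-GPU setups). If `scale_lr` is enabled further down, the learning rate is multiplied by that same factor."
   ]
  },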
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define Functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "CONTEXT_WINDOW_SIZE = 1024\n",
    "\n",
    "MAX_TEXT_LEN = 256\n",
    "\n",
    "SEMANTIC_RATE_HZ = 49.9\n",
    "SEMANTIC_VOCAB_SIZE = 10_000\n",
    "\n",
    "TEXT_ENCODING_OFFSET = 10_048\n",
    "SEMANTIC_PAD_TOKEN = 10_000\n",
    "TEXT_PAD_TOKEN = 129_595\n",
    "SEMANTIC_INFER_TOKEN = 129_599\n",
    "\n",
    "MAX_SEMANTIC_LEN = 511\n",
    "\n",
    "SAMPLE_RATE = 24_000\n",
    "CHANNELS = 1\n",
    "\n",
    "logger = logging.getLogger(__name__)\n",
    "\n",
    "\n",
    "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
    "\n",
    "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
    "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
    "\n",
    "\n",
    "def _clear_cuda_cache():\n",
    "    if torch.cuda.is_available():\n",
    "        torch.cuda.empty_cache()\n",
    "        torch.cuda.synchronize()\n",
    "\n",
    "\n",
    "def _md5(fname):\n",
    "    hash_md5 = hashlib.md5()\n",
    "    with open(fname, \"rb\") as f:\n",
    "        for chunk in iter(lambda: f.read(4096), b\"\"):\n",
    "            hash_md5.update(chunk)\n",
    "    return hash_md5.hexdigest()\n",
    "\n",
    "\n",
    "def _download(from_hf_path, file_name, to_local_path):\n",
    "    to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
    "    path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
    "    os.makedirs(path, exist_ok=True)\n",
    "    hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
    "    os.replace(os.path.join(path, file_name), to_local_path)\n",
    "\n",
    "\n",
    "def _tokenize(tokenizer, text):\n",
    "    return tokenizer.encode(text, add_special_tokens=False)\n",
    "\n",
    "\n",
    "def _detokenize(tokenizer, enc_text):\n",
    "    return tokenizer.decode(enc_text)\n",
    "\n",
    "\n",
    "def _normalize_whitespace(text):\n",
    "    return re.sub(r\"\\s+\", \" \", text).strip()\n",
    "\n",
    "\n",
    "REMOTE_MODEL_PATHS = {\n",
    "    \"text_small\": {\n",
    "        \"repo_id\": \"suno/bark\",\n",
    "        \"file_name\": \"text.pt\",\n",
    "        \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
    "    },\n",
    "    \"coarse_small\": {\n",
    "        \"repo_id\": \"suno/bark\",\n",
    "        \"file_name\": \"coarse.pt\",\n",
    "        \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
    "    },\n",
    "    \"fine_small\": {\n",
    "        \"repo_id\": \"suno/bark\",\n",
    "        \"file_name\": \"fine.pt\",\n",
    "        \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
    "    },\n",
    "    \"text\": {\n",
    "        \"repo_id\": \"suno/bark\",\n",
    "        \"file_name\": \"text_2.pt\",\n",
    "        \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
    "    },\n",
    "    \"coarse\": {\n",
    "        \"repo_id\": \"suno/bark\",\n",
    "        \"file_name\": \"coarse_2.pt\",\n",
    "        \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
    "    },\n",
    "    \"fine\": {\n",
    "        \"repo_id\": \"suno/bark\",\n",
    "        \"file_name\": \"fine_2.pt\",\n",
    "        \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
    "    },\n",
    "}\n",
    "\n",
    "\n",
    "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
    "    if model_type == \"text\":\n",
    "        ConfigClass = GPTConfig\n",
    "        ModelClass = GPT\n",
    "    elif model_type == \"coarse\":\n",
    "        ConfigClass = GPTConfig\n",
    "        ModelClass = GPT\n",
    "    elif model_type == \"fine\":\n",
    "        ConfigClass = FineGPTConfig\n",
    "        ModelClass = FineGPT\n",
    "    else:\n",
    "        raise NotImplementedError()\n",
    "    model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
    "    model_info = REMOTE_MODEL_PATHS[model_key]\n",
    "    if ckpt_path in [None, '']:\n",
    "        ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
    "    if not os.path.exists(ckpt_path):\n",
    "        logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
    "        _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
    "    checkpoint = torch.load(ckpt_path, map_location=device)\n",
    "    # older checkpoints use a single `vocab_size`; split it into input/output vocab sizes\n",
    "    model_args = checkpoint[\"model_args\"]\n",
    "    if \"input_vocab_size\" not in model_args:\n",
    "        model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
    "        model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
    "        del model_args[\"vocab_size\"]\n",
    "    gptconf = ConfigClass(**model_args)\n",
    "    model = ModelClass(gptconf)\n",
    "    state_dict = checkpoint[\"model\"]\n",
    "    # fixup checkpoint: strip the torch.compile prefix from parameter names\n",
    "    unwanted_prefix = \"_orig_mod.\"\n",
    "    for k, v in list(state_dict.items()):\n",
    "        if k.startswith(unwanted_prefix):\n",
    "            state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)\n",
    "    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
    "    extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
    "    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
    "    missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
    "    if len(extra_keys) != 0:\n",
    "        raise ValueError(f\"extra keys found: {extra_keys}\")\n",
    "    if len(missing_keys) != 0:\n",
    "        raise ValueError(f\"missing keys: {missing_keys}\")\n",
    "    model.load_state_dict(state_dict, strict=False)\n",
    "    n_params = model.get_num_params()\n",
    "    val_loss = checkpoint[\"best_val_loss\"].item()\n",
    "    print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
    "    del checkpoint, state_dict\n",
    "    _clear_cuda_cache()\n",
    "    if model_type == \"text\":\n",
    "        tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
    "        return model, tokenizer\n",
    "    return model\n",
    "\n",
    "\n",
    "def load_filepaths_and_text(filename, split=\"|\"):\n",
    "    with open(filename, encoding='utf-8', errors='ignore') as f:\n",
    "        filepaths_and_text = [line.strip().split(split) for line in f]\n",
    "    base = os.path.dirname(filename)\n",
    "    for j in range(len(filepaths_and_text)):\n",
    "        filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
    "    return filepaths_and_text\n",
    "\n",
    "\n",
    "class TtsDataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, opt):\n",
    "        self.path = os.path.dirname(opt['path'])\n",
    "        self.mode = opt['mode']\n",
    "        self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'], opt['mode'] + '_valid.txt'))\n",
    "        self.tokenizer = opt['tokenizer']\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        audiopath_and_text = self.audiopaths_and_text[index]\n",
    "        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]\n",
    "\n",
    "        input_ids = np.array(_tokenize(self.tokenizer, text)) + TEXT_ENCODING_OFFSET\n",
    "        input_ids = torch.from_numpy(input_ids).long()\n",
    "        tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
    "        semantic_tokens = tokens['semantic']\n",
    "        semantic_tokens = torch.from_numpy(semantic_tokens).long()\n",
    "\n",
    "        return input_ids, semantic_tokens\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.audiopaths_and_text)\n",
    "\n",
    "\n",
    "class TtsCollater():\n",
    "    def __init__(self):\n",
    "        pass\n",
    "\n",
    "    def __call__(self, batch):\n",
    "        max_text_len = MAX_TEXT_LEN\n",
    "        max_semantic_tokens_len = MAX_SEMANTIC_LEN\n",
    "        texts = []\n",
    "        semantic_tokens = []\n",
    "\n",
    "        for b in batch:\n",
    "            text, semantic_tokens_ = b\n",
    "            text = F.pad(text, (0, max_text_len - len(text)), value=TEXT_PAD_TOKEN)\n",
    "            semantic_history = torch.from_numpy(np.array([SEMANTIC_PAD_TOKEN] * 256))\n",
    "            text = torch.cat([text, semantic_history, torch.tensor([SEMANTIC_INFER_TOKEN])])\n",
    "            texts.append(text)\n",
    "            semantic_tokens_ = semantic_tokens_[:max_semantic_tokens_len]\n",
    "            semantic_tokens.append(F.pad(semantic_tokens_, (0, max_semantic_tokens_len - len(semantic_tokens_)), value=SEMANTIC_PAD_TOKEN))\n",
    "\n",
    "        return {\n",
    "            'input_ids': torch.stack(texts).contiguous(),\n",
    "            'semantic_tokens': torch.stack(semantic_tokens).contiguous()\n",
    "        }\n",
    "\n",
    "\n",
    "accelerator = Accelerator(\n",
    "    gradient_accumulation_steps=grad_accum,\n",
    "    mixed_precision=mixed_precision,\n",
    "    log_with=log_with,\n",
    "    logging_dir=logging_dir,\n",
    ")\n",
    "device = accelerator.device\n",
    "\n",
    "os.makedirs(output_dir, exist_ok=True)\n",
    "\n",
    "set_seed(seed)"
   ]
  },
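  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (not part of the original pipeline): a minimal sketch that feeds dummy tensors through `TtsCollater` to confirm the padded shapes. Each batch row carries `MAX_TEXT_LEN + 256 + 1` text-side tokens (text pads, then 256 semantic-history pads, then the infer token) and `MAX_SEMANTIC_LEN` semantic tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity-check the collater on dummy data; the token values are arbitrary placeholders.\n",
    "dummy_text = torch.randint(TEXT_ENCODING_OFFSET, TEXT_ENCODING_OFFSET + 100, (37,))\n",
    "dummy_semantic = torch.randint(0, SEMANTIC_VOCAB_SIZE, (200,))\n",
    "out = TtsCollater()([(dummy_text, dummy_semantic)])\n",
    "assert out['input_ids'].shape == (1, MAX_TEXT_LEN + 256 + 1)\n",
    "assert out['semantic_tokens'].shape == (1, MAX_SEMANTIC_LEN)\n",
    "print(out['input_ids'].shape, out['semantic_tokens'].shape)"
   ]
  },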
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup Dataset (only need to do this once)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_duration_sec = 15.12  # the maximum allowed clip duration in seconds\n",
    "\n",
    "path = dataset_path\n",
    "\n",
    "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
    "from hubert.hubert_manager import HuBERTManager\n",
    "hubert_manager = HuBERTManager()\n",
    "from hubert.pre_kmeans_hubert import CustomHubert\n",
    "from hubert.customtokenizer import CustomTokenizer\n",
    "hubert_manager.make_sure_hubert_installed()\n",
    "hubert_manager.make_sure_tokenizer_installed()\n",
    "\n",
    "# Load the HuBERT model\n",
    "hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
    "hubert_model.eval()\n",
    "for param in hubert_model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "# Load the CustomTokenizer model\n",
    "hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device)  # automatically uses the right layers\n",
    "\n",
    "from bark.generation import load_codec_model\n",
    "codec_model = load_codec_model(use_gpu=True)\n",
    "codec_model.eval()\n",
    "for param in codec_model.parameters():\n",
    "    param.requires_grad = False\n",
    "\n",
    "\n",
    "def get_duration(wav, sr):\n",
    "    return wav.shape[1] / sr\n",
    "\n",
    "\n",
    "valid_lines_train = []\n",
    "# convert wavs to semantic tokens\n",
    "for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
    "    wav, sr = torchaudio.load(wav_path)\n",
    "    if get_duration(wav, sr) <= max_duration_sec:\n",
    "        valid_lines_train.append((wav_path, txt))\n",
    "    wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
    "\n",
    "    semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
    "    semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
    "\n",
    "    # save semantic tokens\n",
    "    os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
    "    semantic_tokens = semantic_tokens.cpu().numpy()\n",
    "\n",
    "    # Extract discrete codes from EnCodec\n",
    "    with torch.no_grad():\n",
    "        encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
    "    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]\n",
    "\n",
    "    # move codes to cpu\n",
    "    codes = codes.cpu().numpy()\n",
    "\n",
    "    # save tokens\n",
    "    np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
    "\n",
    "# write the lines that passed the duration filter to train_valid.txt\n",
    "with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
    "    for wav_path, txt in valid_lines_train:\n",
    "        wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
    "        f.write(f'{wav_path}|{txt}\\n')\n",
    "\n",
    "valid_lines_valid = []\n",
    "for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
    "    wav, sr = torchaudio.load(wav_path)\n",
    "    if get_duration(wav, sr) <= max_duration_sec:\n",
    "        valid_lines_valid.append((wav_path, txt))\n",
    "    wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
    "\n",
    "    semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
    "    semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
    "\n",
    "    # save semantic tokens\n",
    "    os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
    "    semantic_tokens = semantic_tokens.cpu().numpy()\n",
    "\n",
    "    # Extract discrete codes from EnCodec\n",
    "    with torch.no_grad():\n",
    "        encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
    "    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]\n",
    "\n",
    "    # move codes to cpu\n",
    "    codes = codes.cpu().numpy()\n",
    "\n",
    "    # save tokens\n",
    "    np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
    "\n",
    "# write the lines that passed the duration filter to valid_valid.txt\n",
    "with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
    "    for wav_path, txt in valid_lines_valid:\n",
    "        wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
    "        f.write(f'{wav_path}|{txt}\\n')\n",
    "\n",
    "del hubert_model\n",
    "del hubert_tokenizer\n",
    "del codec_model\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
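  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional check (not part of the original pipeline): a minimal sketch that opens one of the `.npz` files written above and prints the array shapes. Expect `semantic` to be a 1-D token sequence at roughly 50 Hz, `fine` to be `[n_q, T]` EnCodec codes, and `coarse` to be the first two codebooks of `fine`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "\n",
    "# Inspect the first token file produced by the dataset-setup cell.\n",
    "token_files = sorted(glob.glob(os.path.join(dataset_path, 'tokens', '*.npz')))\n",
    "tokens = np.load(token_files[0])\n",
    "for key in ('semantic', 'coarse', 'fine'):\n",
    "    print(key, tokens[key].shape)"
   ]
  },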
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, tokenizer = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if scale_lr:\n",
    "    learning_rate = (\n",
    "        learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
    "    )\n",
    "\n",
    "if use_8bit_adam:\n",
    "    try:\n",
    "        import bitsandbytes as bnb\n",
    "    except ImportError:\n",
    "        raise ImportError(\n",
    "            \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
    "        )\n",
    "\n",
    "    optimizer_class = bnb.optim.AdamW8bit\n",
    "else:\n",
    "    optimizer_class = torch.optim.AdamW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "quantization_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=bits == 4,\n",
    "    load_in_8bit=bits == 8,\n",
    "    llm_int8_threshold=6.0,\n",
    "    llm_int8_has_fp16_weight=False,\n",
    "    bnb_4bit_compute_dtype=compute_dtype,\n",
    "    bnb_4bit_use_double_quant=double_quant,\n",
    "    bnb_4bit_quant_type=quant_type,  # {'fp4', 'nf4'}\n",
    ")\n",
    "\n",
    "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
    "#     if quantization_config.load_in_8bit:\n",
    "#         logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
    "#     elif quantization_config.load_in_4bit:\n",
    "#         logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
    "\n",
    "#     # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
    "#     if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
    "#         modules_to_not_convert = []  # get_keys_to_not_convert(model)\n",
    "#     else:\n",
    "#         modules_to_not_convert = llm_int8_skip_modules\n",
    "\n",
    "#     if not isinstance(modules_to_not_convert, list):\n",
    "#         modules_to_not_convert = [modules_to_not_convert]\n",
    "\n",
    "#     modules_to_not_convert.extend(keep_in_fp32_modules)\n",
    "\n",
    "#     supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
    "\n",
    "#     if quantization_config.load_in_4bit and not supports_4bit:\n",
    "#         raise ValueError(\n",
    "#             \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
    "#             \" make sure you have the latest version of `bitsandbytes` installed\"\n",
    "#         )\n",
    "\n",
    "#     if len(modules_to_not_convert) == 0:\n",
    "#         modules_to_not_convert = None\n",
    "\n",
    "#     model = replace_with_bnb_linear(\n",
    "#         model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
    "#     )\n",
    "\n",
    "#     # training in 8-bit is only available in 0.37.0+\n",
    "#     model._is_kbit_training_enabled = version.parse(\n",
    "#         importlib_metadata.version(\"bitsandbytes\")\n",
    "#     ) >= version.parse(\"0.37.0\")\n",
    "\n",
    "#     model.config.quantization_config = quantization_config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if bits == 4:\n",
    "    from accelerate.utils import CustomDtype\n",
    "    target_dtype = CustomDtype.INT4\n",
    "elif bits == 8:\n",
    "    target_dtype = torch.int8\n",
    "\n",
    "if lora_dim > 0:\n",
    "    for param in model.parameters():\n",
    "        if param.ndim == 1:\n",
    "            # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
    "            param.data = param.data.to(torch.float32)\n",
    "\n",
    "    class CastOutputToFloat(nn.Sequential):\n",
    "        def forward(self, x):\n",
    "            return super().forward(x).to(torch.float32)\n",
    "\n",
    "    model.lm_head = CastOutputToFloat(model.lm_head)\n",
    "\n",
    "    model = convert_linear_layer_to_lora(model, lora_module_name,\n",
    "                                         lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
    "                                         lora_dropout=lora_dropout)\n",
    "    if optimize_lora_params_only:\n",
    "        model = only_optimize_lora_parameters(model)"
   ]
  },
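  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is a minimal standalone sketch of the LoRA reparameterization that `convert_linear_layer_to_lora` applies to the linear layers under `transformer.h` (the `ToyLoRALinear` class below is hypothetical, not the `utils.lora` implementation): the pretrained weight stays frozen and a trainable low-rank update is added, `y = W x + scaling * B A x`, with `A` of shape `(lora_dim, in_features)` and `B` of shape `(out_features, lora_dim)`. At the end of training, `convert_lora_to_linear_layer` folds the update back into the dense weight."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical illustration of LoRA; not the implementation in utils.lora.\n",
    "class ToyLoRALinear(nn.Module):\n",
    "    def __init__(self, base, r, scaling, dropout):\n",
    "        super().__init__()\n",
    "        self.base = base\n",
    "        for p in self.base.parameters():\n",
    "            p.requires_grad = False  # the pretrained weight stays frozen\n",
    "        self.lora_a = nn.Parameter(torch.randn(r, base.in_features) * 0.01)\n",
    "        self.lora_b = nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no change at step 0\n",
    "        self.scaling = scaling\n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.base(x) + self.dropout(x) @ self.lora_a.T @ self.lora_b.T * self.scaling\n",
    "\n",
    "toy = ToyLoRALinear(nn.Linear(16, 32), r=4, scaling=1.0, dropout=0.1)\n",
    "print(toy(torch.randn(2, 16)).shape)  # torch.Size([2, 32])"
   ]
  },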
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_to_optimize = (\n",
    "    param for param in model.parameters() if param.requires_grad\n",
    ")\n",
    "\n",
    "optimizer = optimizer_class(\n",
    "    params_to_optimize,\n",
    "    lr=learning_rate,\n",
    "    betas=(adam_beta1, adam_beta2),\n",
    "    weight_decay=weight_decay,\n",
    "    eps=adam_epsilon,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "opt_train = {\n",
    "    'path': dataset_path,\n",
    "    'tokenizer': tokenizer,\n",
    "    'mode': 'train',\n",
    "}\n",
    "\n",
    "opt_val = {\n",
    "    'path': dataset_path,\n",
    "    'tokenizer': tokenizer,\n",
    "    'mode': 'valid',\n",
    "}\n",
    "\n",
    "train_dataset = TtsDataset(opt_train)\n",
    "validation_dataset = TtsDataset(opt_val)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=train_batch_size,\n",
    "    collate_fn=TtsCollater(),\n",
    ")\n",
    "\n",
    "validation_dataloader = torch.utils.data.DataLoader(\n",
    "    validation_dataset,\n",
    "    batch_size=eval_batch_size,\n",
    "    collate_fn=TtsCollater(),\n",
    ")\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()  # optionally pass ignore_index=SEMANTIC_PAD_TOKEN to skip loss on padding\n",
    "\n",
    "# Scheduler and math around the number of training steps.\n",
    "overrode_max_train_steps = False\n",
    "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
    "if max_train_steps is None:\n",
    "    max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "    overrode_max_train_steps = True\n",
    "\n",
    "lr_scheduler = get_scheduler(\n",
    "    lr_scheduler_type,\n",
    "    optimizer=optimizer,\n",
    "    num_warmup_steps=lr_warmup_steps * grad_accum,\n",
    "    num_training_steps=max_train_steps * grad_accum,\n",
    ")"
   ]
  },
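  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Worked example of the step math (illustrative numbers, assuming a hypothetical 1,600 kept clips): `len(train_dataloader) = ceil(1600 / 8) = 200` batches per epoch, `num_update_steps_per_epoch = ceil(200 / 2) = 100`, and `max_train_steps = 5 * 100 = 500` optimizer steps. The scheduler receives both counts multiplied by `grad_accum` because `lr_scheduler.step()` is called once per micro-batch inside the accumulation context, not once per optimizer step."
   ]
  },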
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
    "    model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
    ")\n",
    "accelerator.register_for_checkpointing(lr_scheduler)\n",
    "\n",
    "weight_dtype = torch.float32\n",
    "if accelerator.mixed_precision == \"fp16\":\n",
    "    weight_dtype = torch.float16\n",
    "elif accelerator.mixed_precision == \"bf16\":\n",
    "    weight_dtype = torch.bfloat16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
    "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
    "if overrode_max_train_steps:\n",
    "    max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
    "# Afterwards we recalculate our number of training epochs\n",
    "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
    "\n",
    "# We need to initialize the trackers we use, and also store our configuration.\n",
    "# The trackers initialize automatically on the main process.\n",
    "if accelerator.is_main_process:\n",
    "    accelerator.init_trackers(\"bark_semantic\", config={})\n",
    "\n",
    "# Train!\n",
    "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
    "logger.info(\"***** Running training *****\")\n",
    "logger.info(f\"  Num examples = {len(train_dataset)}\")\n",
    "logger.info(f\"  Num batches each epoch = {len(train_dataloader)}\")\n",
    "logger.info(f\"  Num Epochs = {num_train_epochs}\")\n",
    "logger.info(f\"  Instantaneous batch size per device = {train_batch_size}\")\n",
    "logger.info(f\"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
    "logger.info(f\"  Gradient Accumulation steps = {grad_accum}\")\n",
    "logger.info(f\"  Total optimization steps = {max_train_steps}\")\n",
    "global_step = 0\n",
    "first_epoch = 0\n",
    "\n",
    "if resume_from_checkpoint:\n",
    "    if resume_from_checkpoint != \"latest\":\n",
    "        path = os.path.basename(resume_from_checkpoint)\n",
    "    else:\n",
    "        # Get the most recent checkpoint\n",
    "        dirs = os.listdir(output_dir)\n",
    "        dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
    "        dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
    "        path = dirs[-1]\n",
    "    accelerator.print(f\"Resuming from checkpoint {path}\")\n",
    "    accelerator.load_state(os.path.join(output_dir, path))\n",
    "    global_step = int(path.split(\"-\")[1])\n",
    "\n",
    "    resume_global_step = global_step * grad_accum\n",
    "    first_epoch = global_step // num_update_steps_per_epoch\n",
    "    # number of micro-batches to skip within the resumed epoch\n",
    "    resume_step = resume_global_step % (num_update_steps_per_epoch * grad_accum)\n"
   ]
  },
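  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Worked example of the resume arithmetic (illustrative numbers): resuming from `checkpoint-17` with `grad_accum = 2` and `num_update_steps_per_epoch = 10` gives `global_step = 17`, `resume_global_step = 34` micro-batches, `first_epoch = 17 // 10 = 1`, and `resume_step = 34 % 20 = 14`, so the training loop below skips the first 14 micro-batches of epoch 1 before resuming real work."
   ]
  },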
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if accelerator.is_main_process:\n",
    "    model.eval()\n",
    "    validation_loss = 0.0\n",
    "    num_batches = 0\n",
    "    num_samples = 0\n",
    "    with torch.no_grad():\n",
    "        for val_step, val_batch in enumerate(validation_dataloader):\n",
    "            # Similar to training, process the validation batch\n",
    "            val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n",
    "            val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n",
    "            val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n",
    "\n",
    "            # Forward pass for validation\n",
    "            val_logits = model(val_inputs, training=True)\n",
    "            val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n",
    "\n",
    "            # Calculate the validation loss\n",
    "            val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
    "            validation_loss += val_loss.item()\n",
    "            num_batches += 1\n",
    "            num_samples += val_batch['input_ids'].size(0)\n",
    "\n",
    "    average_validation_loss = validation_loss / num_batches\n",
    "    logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
    "    print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Only show the progress bar once on each machine.\n",
    "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
    "progress_bar.set_description(\"Steps\")\n",
    "\n",
    "for epoch in range(first_epoch, num_train_epochs):\n",
    "    model.train()\n",
    "    for step, batch in enumerate(train_dataloader):\n",
    "        # Skip steps until we reach the resumed step\n",
    "        if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
    "            if step % grad_accum == 0:\n",
    "                progress_bar.update(1)\n",
    "            continue\n",
    "\n",
    "        with accelerator.accumulate(model):\n",
    "            targets = batch['semantic_tokens'][:, 1:].contiguous()\n",
    "\n",
    "            # Remove the last semantic token from the inputs since there is no target for it.\n",
    "            semantic_inputs = batch['semantic_tokens'][:, :-1]\n",
    "\n",
    "            # Combine the text and semantic tokens and feed them into the model.\n",
    "            inputs = torch.cat([batch['input_ids'], semantic_inputs], dim=1)\n",
    "            logits = model(inputs, training=True)\n",
    "\n",
    "            # We're only interested in the logits for the semantic tokens, so we ignore the logits for the input text tokens.\n",
    "            semantic_logits = logits[:, batch['input_ids'].size(1):].contiguous()\n",
    "\n",
    "            # Compute the loss.\n",
    "            loss = criterion(semantic_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n",
    "\n",
    "            accelerator.backward(loss)\n",
    "            if accelerator.sync_gradients:\n",
    "                params_to_clip = (\n",
    "                    param for param in model.parameters() if param.requires_grad\n",
    "                )\n",
    "                accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
    "            optimizer.step()\n",
    "            lr_scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "        # Checks if the accelerator has performed an optimization step behind the scenes\n",
    "        if accelerator.sync_gradients:\n",
    "            progress_bar.update(1)\n",
    "            global_step += 1\n",
    "\n",
    "            if global_step % checkpointing_steps == 0:\n",
    "                if accelerator.is_main_process:\n",
    "                    save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
    "                    accelerator.save_state(save_path)\n",
    "                    logger.info(f\"Saved state to {save_path}\")\n",
    "\n",
    "        logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
    "        progress_bar.set_postfix(**logs)\n",
    "        accelerator.log(logs, step=global_step)\n",
    "\n",
    "        if global_step >= max_train_steps:\n",
    "            break\n",
    "\n",
    "    accelerator.wait_for_everyone()\n",
    "\n",
    "if accelerator.is_main_process:\n",
    "    if lora_dim > 0:\n",
    "        model = convert_lora_to_linear_layer(model)\n",
    "    # save model\n",
    "    accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
    "\n",
    "    config = model.config.__dict__\n",
    "    # save config\n",
    "    with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
    "        json.dump(config, f, indent=2)\n",
    "\n",
    "accelerator.end_training()"
   ]
  },
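  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After training, the merged weights are in `semantic_output/pytorch_model.bin` next to `config.json`. Below is a minimal reload sketch (not part of the original notebook), assuming single-process training so the saved state dict carries no distributed wrapper prefixes; it reuses the `GPT` and `GPTConfig` classes imported at the top."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: reload the fine-tuned semantic model saved by the training cell.\n",
    "with open(os.path.join(output_dir, 'config.json')) as f:\n",
    "    saved_config = json.load(f)\n",
    "\n",
    "finetuned = GPT(GPTConfig(**saved_config))\n",
    "state = torch.load(os.path.join(output_dir, 'pytorch_model.bin'), map_location='cpu')\n",
    "finetuned.load_state_dict(state, strict=False)  # strict=False tolerates non-parameter buffers such as attn.bias\n",
    "finetuned.eval()"
   ]
  },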
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if accelerator.is_main_process:\n",
    "    model.eval()\n",
    "    validation_loss = 0.0\n",
    "    num_batches = 0\n",
    "    num_samples = 0\n",
    "    with torch.no_grad():\n",
    "        for val_step, val_batch in enumerate(validation_dataloader):\n",
    "            # Similar to training, process the validation batch\n",
    "            val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n",
    "            val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n",
    "            val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n",
    "\n",
    "            # Forward pass for validation\n",
    "            val_logits = model(val_inputs, training=True)\n",
    "            val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n",
    "\n",
    "            # Calculate the validation loss\n",
    "            val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
    "            validation_loss += val_loss.item()\n",
    "            num_batches += 1\n",
    "            num_samples += val_batch['input_ids'].size(0)\n",
    "\n",
    "    average_validation_loss = validation_loss / num_batches\n",
    "    logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
    "    print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}