diff --git a/.gitignore b/.gitignore index 1cf5bdf..f52000b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ __pycache__/ *.wav _temp/ -models/ \ No newline at end of file +models/ +output.npz \ No newline at end of file diff --git a/README.md b/README.md index 9f16492..b53476d 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,9 @@ You will get the best results by making generations with your cloned voice until - [BARK text to speech @ SERP AI](https://serp.ai/tools/bark-text-to-speech-ai-voice-clone-app/) +# Shoutouts +- Huge shoutout to [gitmylo](https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer/) for the semantic-token generation solution that enables better voice clones and fine-tunes + ------------------------------------------------------------------- # Original README.md diff --git a/bark/generation.py b/bark/generation.py index c9d4317..74fcc0d 100644 --- a/bark/generation.py +++ b/bark/generation.py @@ -60,12 +60,12 @@ CUR_PATH = os.path.dirname(os.path.abspath(__file__)) default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache") -CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0") +CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "serp", "bark_v0") -USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False) -GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False) -OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False) +USE_SMALL_MODELS = os.environ.get("SERP_USE_SMALL_MODELS", False) +GLOBAL_ENABLE_MPS = os.environ.get("SERP_ENABLE_MPS", False) +OFFLOAD_CPU = os.environ.get("SERP_OFFLOAD_CPU", False) REMOTE_MODEL_PATHS = { diff --git a/bark/model.py b/bark/model.py index 457b49e..b87e534 100644 --- a/bark/model.py +++ b/bark/model.py @@ -8,6 +8,8 @@ from dataclasses import dataclass import torch import torch.nn as nn from torch.nn import functional as F +from einops import rearrange, repeat, reduce +SEMANTIC_PAD_TOKEN = 10_000 class LayerNorm(nn.Module): """ LayerNorm but with an optional bias.
PyTorch doesn't support simply bias=False """ @@ -165,7 +167,7 @@ class GPT(nn.Module): n_params -= self.transformer.wpe.weight.numel() return n_params - def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False): + def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None): device = idx.device b, t = idx.size() if past_kv is not None: @@ -212,6 +214,21 @@ class GPT(nn.Module): x = self.transformer.ln_f(x) + + if labels is not None: + logits = self.lm_head(x) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = nn.CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.output_vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + return logits, loss + # inference-time mini-optimization: only forward the lm_head on the very last position logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim diff --git a/clone_voice.ipynb b/clone_voice.ipynb index 7ab0dbe..eb2495c 100644 --- a/clone_voice.ipynb +++ b/clone_voice.ipynb @@ -12,7 +12,39 @@ "import torchaudio\n", "import torch\n", "\n", - "model = load_codec_model(use_gpu=True)" + "device = 'cuda' # or 'cpu'\n", + "model = load_codec_model(use_gpu=True if device == 'cuda' else False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n", + "from hubert.hubert_manager import HuBERTManager\n", + "hubert_manager = HuBERTManager()\n", + "hubert_manager.make_sure_hubert_installed()\n", + "hubert_manager.make_sure_tokenizer_installed()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer \n", + "# Load HuBERT for semantic tokens\n", + "from hubert.pre_kmeans_hubert import CustomHubert\n", + "from hubert.customtokenizer import CustomTokenizer\n", + "\n", + "# Load the HuBERT model\n", + "hubert_model = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(device)\n", + "\n", + "# Load the CustomTokenizer model\n", + "tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(device) # Automatically uses the right layers" ] }, { @@ -22,11 +54,20 @@ "outputs": [], "source": [ "# Load and pre-process the audio waveform\n", - "audio_filepath = 'audio.wav' # the audio you want to clone (will get truncated so 5-10 seconds is probably fine, existing samples that I checked are around 7 seconds)\n", - "device = 'cuda' # or 'cpu'\n", + "audio_filepath = 'audio.wav' # the audio you want to clone (under 13 seconds)\n", "wav, sr = torchaudio.load(audio_filepath)\n", "wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n", - "wav = wav.unsqueeze(0).to(device)" + "wav = wav.to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)\n", + "semantic_tokens = tokenizer.get_token(semantic_vectors)" ] }, { @@ -37,31 +78,10 @@ "source": [ "# Extract discrete codes from EnCodec\n", "with torch.no_grad():\n", - " encoded_frames = 
model.encode(wav)\n", + " encoded_frames = model.encode(wav.unsqueeze(0))\n", "codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "text = \"Transcription of the audio you are cloning\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# get seconds of audio\n", - "seconds = wav.shape[-1] / model.sample_rate\n", - "# generate semantic tokens\n", - "semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) # not 100% sure on this part" - ] - }, { "cell_type": "code", "execution_count": null, @@ -69,7 +89,9 @@ "outputs": [], "source": [ "# move codes to cpu\n", - "codes = codes.cpu().numpy()" + "codes = codes.cpu().numpy()\n", + "# move semantic tokens to cpu\n", + "semantic_tokens = semantic_tokens.cpu().numpy()" ] }, { @@ -121,10 +143,7 @@ "\n", "# Enter your prompt and speaker here\n", "text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n", - "voice_name = \"speaker_0\" # use your custom voice name here if you have one\n", - "\n", - "# load the tokenizer\n", - "tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")" + "voice_name = \"output\" # use your custom voice name here if you have one" ] }, { diff --git a/hubert/__init__.py b/hubert/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/hubert/customtokenizer.py b/hubert/customtokenizer.py new file mode 100644 index 0000000..7f807d3 --- /dev/null +++ b/hubert/customtokenizer.py @@ -0,0 +1,184 @@ +# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer + +import json +import os.path +from zipfile import ZipFile + +import numpy +import torch +from torch import nn, optim +from torch.serialization import MAP_LOCATION + + +class CustomTokenizer(nn.Module): + def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0): + super(CustomTokenizer, self).__init__() + next_size = input_size + if version == 0: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + next_size = hidden_size + if version == 1: + self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True) + self.intermediate = nn.Linear(hidden_size, 4096) + next_size = 4096 + + self.fc = nn.Linear(next_size, output_size) + self.softmax = nn.LogSoftmax(dim=1) + self.optimizer: optim.Optimizer = None + self.lossfunc = nn.CrossEntropyLoss() + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + def forward(self, x): + x, _ = self.lstm(x) + if self.version == 1: + x = self.intermediate(x) + x = self.fc(x) + x = self.softmax(x) + return x + + @torch.no_grad() + def get_token(self, x): + """ + Used to get the token for the first + :param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model. + :return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model. 
+ """ + return torch.argmax(self(x), dim=1) + + def prepare_training(self): + self.optimizer = optim.Adam(self.parameters(), 0.001) + + def train_step(self, x_train, y_train, log_loss=False): + # y_train = y_train[:-1] + # y_train = y_train[1:] + + optimizer = self.optimizer + lossfunc = self.lossfunc + # Zero the gradients + self.zero_grad() + + # Forward pass + y_pred = self(x_train) + + y_train_len = len(y_train) + y_pred_len = y_pred.shape[0] + + if y_train_len > y_pred_len: + diff = y_train_len - y_pred_len + y_train = y_train[diff:] + elif y_train_len < y_pred_len: + diff = y_pred_len - y_train_len + y_pred = y_pred[:-diff, :] + + y_train_hot = torch.zeros(len(y_train), self.output_size) + y_train_hot[range(len(y_train)), y_train] = 1 + y_train_hot = y_train_hot.to('cuda') + + # Calculate the loss + loss = lossfunc(y_pred, y_train_hot) + + # Print loss + if log_loss: + print('Loss', loss.item()) + + # Backward pass + loss.backward() + + # Update the weights + optimizer.step() + + def save(self, path): + info_path = os.path.basename(path) + '/.info' + torch.save(self.state_dict(), path) + data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version) + with ZipFile(path, 'a') as model_zip: + model_zip.writestr(info_path, data_from_model.save()) + model_zip.close() + + @staticmethod + def load_from_checkpoint(path, map_location: MAP_LOCATION = None): + old = True + with ZipFile(path) as model_zip: + filesMatch = [file for file in model_zip.namelist() if file.endswith('/.info')] + file = filesMatch[0] if filesMatch else None + if file: + old = False + data_from_model = Data.load(model_zip.read(file).decode('utf-8')) + model_zip.close() + if old: + model = CustomTokenizer() + else: + model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version) + model.load_state_dict(torch.load(path, map_location)) + return model + + + +class Data: + input_size: int + hidden_size: int + output_size: int + version: int + + def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0): + self.input_size = input_size + self.hidden_size = hidden_size + self.output_size = output_size + self.version = version + + @staticmethod + def load(string): + data = json.loads(string) + return Data(data['input_size'], data['hidden_size'], data['output_size'], data['version']) + + def save(self): + data = { + 'input_size': self.input_size, + 'hidden_size': self.hidden_size, + 'output_size': self.output_size, + 'version': self.version, + } + return json.dumps(data) + + +def auto_train(data_path, save_path='model.pth', load_model: str | None = None, save_epochs=1): + data_x, data_y = [], [] + + if load_model and os.path.isfile(load_model): + print('Loading model from', load_model) + model_training = CustomTokenizer.load_from_checkpoint(load_model, 'cuda') + else: + print('Creating new model.') + model_training = CustomTokenizer(version=1).to('cuda') # Settings for the model to run without lstm + save_path = os.path.join(data_path, save_path) + base_save_path = '.'.join(save_path.split('.')[:-1]) + + sem_string = '_semantic.npy' + feat_string = '_semantic_features.npy' + + ready = os.path.join(data_path, 'ready') + for input_file in os.listdir(ready): + full_path = os.path.join(ready, input_file) + if input_file.endswith(sem_string): + data_y.append(numpy.load(full_path)) + elif input_file.endswith(feat_string): + data_x.append(numpy.load(full_path)) + model_training.prepare_training() + + epoch = 1 + + 
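+    # NOTE: the loop below runs until interrupted; after every `save_epochs` full passes over the data it overwrites `save_path` and also writes an epoch-numbered checkpoint.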
while 1: + for i in range(save_epochs): + j = 0 + for x, y in zip(data_x, data_y): + model_training.train_step(torch.tensor(x).to('cuda'), torch.tensor(y).to('cuda'), j % 50 == 0) # Print loss every 50 steps + j += 1 + save_p = save_path + save_p_2 = f'{base_save_path}_epoch_{epoch}.pth' + model_training.save(save_p) + model_training.save(save_p_2) + print(f'Epoch {epoch} completed') + epoch += 1 diff --git a/hubert/hubert_manager.py b/hubert/hubert_manager.py new file mode 100644 index 0000000..f843c95 --- /dev/null +++ b/hubert/hubert_manager.py @@ -0,0 +1,35 @@ +# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer + +import os.path +import shutil +import urllib.request + +import huggingface_hub + + +class HuBERTManager: + @staticmethod + def make_sure_hubert_installed(download_url: str = 'https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt', file_name: str = 'hubert.pt'): + install_dir = os.path.join('data', 'models', 'hubert') + if not os.path.isdir(install_dir): + os.makedirs(install_dir, exist_ok=True) + install_file = os.path.join(install_dir, file_name) + if not os.path.isfile(install_file): + print('Downloading HuBERT base model') + urllib.request.urlretrieve(download_url, install_file) + print('Downloaded HuBERT') + return install_file + + + @staticmethod + def make_sure_tokenizer_installed(model: str = 'quantifier_hubert_base_ls960_14.pth', repo: str = 'GitMylo/bark-voice-cloning', local_file: str = 'tokenizer.pth'): + install_dir = os.path.join('data', 'models', 'hubert') + if not os.path.isdir(install_dir): + os.makedirs(install_dir, exist_ok=True) + install_file = os.path.join(install_dir, local_file) + if not os.path.isfile(install_file): + print('Downloading HuBERT custom tokenizer') + huggingface_hub.hf_hub_download(repo, model, local_dir=install_dir, local_dir_use_symlinks=False) + shutil.move(os.path.join(install_dir, model), install_file) + print('Downloaded tokenizer') + return install_file diff --git a/hubert/pre_kmeans_hubert.py b/hubert/pre_kmeans_hubert.py new file mode 100644 index 0000000..93f82fe --- /dev/null +++ b/hubert/pre_kmeans_hubert.py @@ -0,0 +1,94 @@ +# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer + +from pathlib import Path + +import torch +from torch import nn +from einops import pack, unpack + +import joblib + +import fairseq + +from torchaudio.functional import resample + +from audiolm_pytorch.utils import curtail_to_multiple + +import logging +logging.root.setLevel(logging.ERROR) + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +class CustomHubert(nn.Module): + """ + checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert + or you can train your own + """ + + def __init__( + self, + checkpoint_path, + target_sample_hz=16000, + seq_len_multiple_of=None, + output_layer=9 + ): + super().__init__() + self.target_sample_hz = target_sample_hz + self.seq_len_multiple_of = seq_len_multiple_of + self.output_layer = output_layer + + model_path = Path(checkpoint_path) + + assert model_path.exists(), f'path {checkpoint_path} does not exist' + + checkpoint = torch.load(checkpoint_path) + load_model_input = {checkpoint_path: checkpoint} + model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input) + + self.model = model[0] + self.model.eval() + + @property + def groups(self): + return 1 + + @torch.no_grad() + def forward( + self, + wav_input, + flatten=True, + 
input_sample_hz=None + ): + device = wav_input.device + + if exists(input_sample_hz): + wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz) + + if exists(self.seq_len_multiple_of): + wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of) + + embed = self.model( + wav_input, + features_only=True, + mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code + output_layer=self.output_layer + ) + + embed, packed_shape = pack([embed['x']], '* d') + + # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy()) + + codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long() + + if flatten: + return codebook_indices + + codebook_indices, = unpack(codebook_indices, packed_shape, '*') + return codebook_indices diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/bitsandbytes.py b/utils/bitsandbytes.py new file mode 100644 index 0000000..6264b74 --- /dev/null +++ b/utils/bitsandbytes.py @@ -0,0 +1,508 @@ +# From https://github.com/huggingface/transformers/blob/e45e756d22206ca8fa9fb057c8c3d8fa79bf81c6/src/transformers/utils/bitsandbytes.py + +import warnings +import sys +import importlib.util +from copy import deepcopy +import copy +import json +import os +from dataclasses import dataclass + +from typing import Any, Tuple, Union, Dict + +from packaging import version + +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]: + # Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version + package_exists = importlib.util.find_spec(pkg_name) is not None + package_version = "N/A" + if package_exists: + try: + package_version = importlib_metadata.version(pkg_name) + package_exists = True + except importlib_metadata.PackageNotFoundError: + package_exists = False + if return_version: + return package_exists, package_version + else: + return package_exists + +_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True) +_bitsandbytes_available = _is_package_available("bitsandbytes") +_torch_available, _torch_version = _is_package_available("torch", return_version=True) + +def is_accelerate_available(check_partial_state=False): + if check_partial_state: + return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.19.0") + return _accelerate_available + +def is_bitsandbytes_available(): + return _bitsandbytes_available + +def is_torch_available(): + return _torch_available + +if is_bitsandbytes_available(): + import bitsandbytes as bnb + import torch + import torch.nn as nn + +if is_accelerate_available(): + from accelerate import init_empty_weights + from accelerate.utils import find_tied_parameters + + +def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The + function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the + class `Int8Params` from `bitsandbytes`. 
+ + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + fp16_statistics (`torch.HalfTensor`, *optional*): + The list of fp16 statistics to set on the module, used for serialization. + """ + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + is_4bit = False + is_8bit = False + if is_buffer or not is_bitsandbytes_available(): + is_8bit = False + is_4bit = False + else: + is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit) + is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params) + + if is_8bit or is_4bit: + param = module._parameters[tensor_name] + if param.device.type != "cuda": + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to("cpu") + if value.dtype == torch.int8: + is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse( + "0.37.2" + ) + if not is_8bit_serializable: + raise ValueError( + "Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. " + "Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`." + ) + else: + new_value = torch.tensor(value, device="cpu") + + kwargs = old_value.__dict__ + if is_8bit: + new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device) + elif is_4bit: + new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module.weight, "SCB", fp16_statistics.to(device)) + + else: + if value is None: + new_value = old_value.to(device) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + + if is_buffer: + module._buffers[tensor_name] = new_value + else: + new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad) + module._parameters[tensor_name] = new_value + + +def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None): + """ + A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes` + library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8(): + 8-bit Matrix Multiplication for Transformers at Scale`. 
Make sure `bitsandbytes` compiled with the correct CUDA + version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/ + bitsandbytes` + + The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should + be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no + CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a + matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16 + (0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no + predictive degradation is possible for very large models (>=176B parameters). + + Parameters: + model (`torch.nn.Module`): + Input model or `torch.nn.Module` as the function is run recursively. + modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`): + Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision + for numerical stability reasons. + current_key_name (`List[`str`]`, *optional*): + An array to track the current key of the recursion. This is used to check whether the current key (part of + it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or + `disk`). + """ + modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert + for name, module in model.named_children(): + if current_key_name is None: + current_key_name = [] + + if isinstance(module, nn.Linear) and name not in modules_to_not_convert: + # Check if the current key is not in the `modules_to_not_convert` + if not any(key in ".".join(current_key_name) for key in modules_to_not_convert): + with init_empty_weights(): + if quantization_config.quantization_method() == "llm_int8": + model._modules[name] = bnb.nn.Linear8bitLt( + module.in_features, + module.out_features, + module.bias is not None, + has_fp16_weights=quantization_config.llm_int8_has_fp16_weight, + threshold=quantization_config.llm_int8_threshold, + ) + else: + if ( + quantization_config.llm_int8_skip_modules is not None + and name in quantization_config.llm_int8_skip_modules + ): + pass + else: + model._modules[name] = bnb.nn.Linear4bit( + module.in_features, + module.out_features, + module.bias is not None, + quantization_config.bnb_4bit_compute_dtype, + compress_statistics=quantization_config.bnb_4bit_use_double_quant, + quant_type=quantization_config.bnb_4bit_quant_type, + ) + # Force requires grad to False to avoid unexpected errors + model._modules[name].requires_grad_(False) + # Remove the last key for recursion + if len(list(module.children())) > 0: + replace_with_bnb_linear( + module, + modules_to_not_convert, + current_key_name, + quantization_config, + ) + return model + + +# For backward compatibility +def replace_8bit_linear(*args, **kwargs): + warnings.warn( + "`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead", + FutureWarning, + ) + return replace_with_bnb_linear(*args, **kwargs) + + +# For backward compatiblity +def set_module_8bit_tensor_to_device(*args, **kwargs): + warnings.warn( + "`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead", + FutureWarning, + ) + return 
set_module_quantized_tensor_to_device(*args, **kwargs) + + +def get_keys_to_not_convert(model): + r""" + An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules + we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want + to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in + int8. + + Parameters: + model (`torch.nn.Module`): + Input model + """ + # Create a copy of the model and tie the weights, then + # check if it contains tied weights + tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager` + tied_model.tie_weights() + + tied_params = find_tied_parameters(tied_model) + # For compatibility with Accelerate < 0.18 + if isinstance(tied_params, dict): + tied_keys = list(tied_params.values()) + else: + tied_keys = sum([x[1:] for x in tied_params], []) + has_tied_params = len(tied_keys) > 0 + + # Check if it is a base model + is_base_model = not hasattr(model, model.base_model_prefix) + + # Ignore this for base models (BertModel, GPT2Model, etc.) + if (not has_tied_params) and is_base_model: + return [] + + # otherwise they have an attached head + list_modules = list(model.named_parameters()) + list_last_module = [list_modules[-1][0]] + + # add last module together with tied weights + intersection = set(list_last_module) - set(tied_keys) + list_untouched = tied_keys + list(intersection) + + # remove ".weight" from the keys + names_to_remove = [".weight", ".bias"] + filtered_module_names = [] + for name in list_untouched: + for name_to_remove in names_to_remove: + if name_to_remove in name: + name = name.replace(name_to_remove, "") + filtered_module_names.append(name) + + return filtered_module_names + +#!/usr/bin/env python +# coding=utf-8 + +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + + +if is_torch_available(): + import torch + + +@dataclass +class BitsAndBytesConfig: + """ + This is a wrapper class about all possible attributes and features that you can play with a model that has been + loaded using `bitsandbytes`. + + This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive. + + Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`, + then more arguments will be added to this class. + + Args: + load_in_8bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 8-bit quantization with LLM.int8(). + load_in_4bit (`bool`, *optional*, defaults to `False`): + This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from + `bitsandbytes`. 
+ llm_int8_threshold (`float`, *optional*, defaults to 6): + This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix + Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value + that is above this threshold will be considered an outlier and the operation on those values will be done + in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but + there are some exceptional systematic outliers that are very differently distributed for large models. + These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of + magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6, + but a lower threshold might be needed for more unstable models (small models, fine-tuning). + llm_int8_skip_modules (`List[str]`, *optional*): + An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as + Jukebox that has several heads in different places and not necessarily at the last position. For example + for `CausalLM` models, the last `lm_head` is kept in its original `dtype`. + llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`): + This flag is used for advanced use cases and users that are aware of this feature. If you want to split + your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use + this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8 + operations will not be run on CPU. + llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`): + This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not + have to be converted back and forth for the backward pass. + bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`): + This sets the computational type which might be different than the input time. For example, inputs might be + fp32, but computation can be set to bf16 for speedups. + bnb_4bit_quant_type (`str`, {fp4, fn4}, defaults to `fp4`): + This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types + which are specified by `fp4` or `fn4`. + bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`): + This flag is used for nested quantization where the quantization constants from the first quantization are + quantized again. + kwargs (`Dict[str, Any]`, *optional*): + Additional parameters from which to initialize the configuration object. 
+ """ + + def __init__( + self, + load_in_8bit=False, + load_in_4bit=False, + llm_int8_threshold=6.0, + llm_int8_skip_modules=None, + llm_int8_enable_fp32_cpu_offload=False, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=None, + bnb_4bit_quant_type="fp4", + bnb_4bit_use_double_quant=False, + **kwargs, + ): + self.load_in_8bit = load_in_8bit + self.load_in_4bit = load_in_4bit + self.llm_int8_threshold = llm_int8_threshold + self.llm_int8_skip_modules = llm_int8_skip_modules + self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload + self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight + self.bnb_4bit_quant_type = bnb_4bit_quant_type + self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant + + if bnb_4bit_compute_dtype is None: + self.bnb_4bit_compute_dtype = torch.float32 + elif isinstance(bnb_4bit_compute_dtype, str): + self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype) + elif isinstance(bnb_4bit_compute_dtype, torch.dtype): + self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype + else: + raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype") + + self.post_init() + + def post_init(self): + r""" + Safety checker that arguments are correct - also replaces some NoneType arguments with their default values. + """ + if not isinstance(self.llm_int8_threshold, float): + raise ValueError("llm_int8_threshold must be a float") + + if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list): + raise ValueError("llm_int8_skip_modules must be a list of strings") + if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool): + raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean") + + if not isinstance(self.llm_int8_has_fp16_weight, bool): + raise ValueError("llm_int8_has_fp16_weight must be a boolean") + + if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype): + raise ValueError("bnb_4bit_compute_dtype must be torch.dtype") + + if not isinstance(self.bnb_4bit_quant_type, str): + raise ValueError("bnb_4bit_quant_type must be a string") + + if not isinstance(self.bnb_4bit_use_double_quant, bool): + raise ValueError("bnb_4bit_use_double_quant must be a boolean") + + if self.load_in_4bit and not version.parse(importlib_metadata.version("bitsandbytes")) >= version.parse( + "0.39.0" + ): + raise ValueError( + "4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version" + ) + + def is_quantizable(self): + r""" + Returns `True` if the model is quantizable, `False` otherwise. + """ + return self.load_in_8bit or self.load_in_4bit + + def quantization_method(self): + r""" + This method returns the quantization method used for the model. If the model is not quantizable, it returns + `None`. + """ + if self.load_in_8bit: + return "llm_int8" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4": + return "fp4" + elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4": + return "nf4" + else: + return None + + @classmethod + def from_dict(cls, config_dict, return_unused_kwargs, **kwargs): + """ + Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. + return_unused_kwargs (`bool`): + Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in + `PreTrainedModel`. 
+ kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`BitsAndBytesConfig`]: The configuration object instantiated from those parameters. + """ + + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + if return_unused_kwargs: + return config, kwargs + else: + return config + + def to_json_file(self, json_file_path: Union[str, os.PathLike]): + """ + Save this instance to a JSON file. + + Args: + json_file_path (`str` or `os.PathLike`): + Path to the JSON file in which this configuration instance's parameters will be saved. + use_diff (`bool`, *optional*, defaults to `True`): + If set to `True`, only the difference between the config instance and the default + `BitsAndBytesConfig()` is serialized to JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + config_dict = self.to_dict() + json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + writer.write(json_string) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + + output = copy.deepcopy(self.__dict__) + output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1] + + return output \ No newline at end of file diff --git a/utils/lora.py b/utils/lora.py new file mode 100644 index 0000000..198f1f3 --- /dev/null +++ b/utils/lora.py @@ -0,0 +1,152 @@ +# Adapted from https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/module/lora.py + +import math +import torch +from torch import nn +import torch.nn.functional as F + +class LinearLayer_LoRA(nn.Module): + # a simple implementation of LoRA + def __init__(self, + weight, + lora_dim=0, + lora_scaling=1, + lora_droppout=0, + bias=None): + super(LinearLayer_LoRA, self).__init__() + self.weight = weight + self.bias = bias + + if lora_dim <= 0: + raise ValueError( + "You are training to use LoRA, whose reduced dim should be larger than 1" + ) + + rows, columns = weight.shape + self.lora_right_weight = nn.Parameter(torch.zeros( + columns, + lora_dim)) # apply transpose so in forward we do not need to + self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows)) + self.lora_scaling = lora_scaling / lora_dim + + if lora_droppout > 0: + self.lora_dropout = nn.Dropout(lora_droppout) + else: + self.lora_dropout = nn.Identity() + + self.reset_parameters() + # disable the original weight gradient + self.weight.requires_grad = False + # fuse LoRA to the original weight + self.fuse_lora = False + + def eval(self): + self.lora_dropout.eval() + + # self.fuse_lora_weight() + + def train(self, mode=True): + self.lora_dropout.train(mode) + # self.unfuse_lora_weight() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5)) + nn.init.zeros_(self.lora_left_weight) + + def fuse_lora_weight(self): + if not self.fuse_lora: + self.weight.data += self.lora_scaling * torch.matmul( + self.lora_left_weight.t(), self.lora_right_weight.t()) + self.fuse_lora = True + + def unfuse_lora_weight(self): + if self.fuse_lora: + self.weight.data -= self.lora_scaling * torch.matmul( + self.lora_left_weight.t(), self.lora_right_weight.t()) + self.fuse_lora = False + + def forward(self, input): + 
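+        # If the LoRA update has already been fused into self.weight, a plain linear layer suffices;
+        # otherwise compute x @ W.T + b plus the low-rank term (dropout(x) @ lora_right_weight @ lora_left_weight) * lora_scaling.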
if self.fuse_lora: + return F.linear(input, self.weight, self.bias) + else: + return F.linear( + input, self.weight, + self.bias) + (self.lora_dropout(input) @ self.lora_right_weight + @ self.lora_left_weight) * self.lora_scaling + + +def recursive_getattr(model, module_name): + """ + From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py + Recursively get the attribute of a module. + Args: + model (`torch.nn.Module`) + The model to get the attribute from. + module_name (`str`) + The name of the module to get the attribute from. + """ + split_list = module_name.split('.') + output = model + for name in split_list: + output = getattr(output, name) + return output + + +def recursive_setattr(model, module_name, module): + """ + From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py + Recursively set the attribute of a module. + Args: + model (`torch.nn.Module`) + The model to set the attribute in. + module_name (`str`) + The name of the module to set the attribute in. + module (`torch.nn.Module`) + The module to set the attribute to. + """ + split_list = module_name.split('.') + output = model + for name in split_list[:-1]: + output = getattr(output, name) + output.__setattr__(split_list[-1], module) + + +# convert the linear layer to LoRA +def convert_linear_layer_to_lora(model, + part_module_name, + lora_dim=0, + lora_scaling=1, + lora_droppout=0): + repalce_name = [] + for name, module in model.named_modules(): + if isinstance(module, nn.Linear) and part_module_name in name: + repalce_name.append(name) + for name in repalce_name: + module = recursive_getattr(model, name) + tmp = LinearLayer_LoRA( + module.weight, lora_dim, lora_scaling, lora_droppout, + module.bias).to(module.weight.device).to(module.weight.dtype) + recursive_setattr(model, name, tmp) + return model + + +# convert the LoRA layer to linear layer +def convert_lora_to_linear_layer(model): + repalce_name = [] + for name, module in model.named_modules(): + if isinstance(module, LinearLayer_LoRA): + repalce_name.append(name) + for name in repalce_name: + module = recursive_getattr(model, name) + module.fuse_lora_weight() + return model + + +def only_optimize_lora_parameters(model): + # turn off the gradient of all the parameters except the LoRA parameters + for name, param in model.named_parameters(): + if "lora_right_weight" in name or "lora_left_weight" in name: + param.requires_grad = True + else: + param.requires_grad = False + return model \ No newline at end of file
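
The `labels=` argument added to `GPT.forward` and the helpers in `utils/lora.py` are the pieces that make LoRA fine-tuning of a Bark GPT possible. The sketch below is only an illustration of how they could be wired together, not part of the patch: the randomly initialised `GPT(GPTConfig())` stands in for a pretrained checkpoint, and the `"attn"` module filter, `lora_dim`, learning rate, and dummy batch are assumptions to adapt to a real setup.

```python
import torch
from bark.model import GPT, GPTConfig
from utils.lora import (convert_linear_layer_to_lora, convert_lora_to_linear_layer,
                        only_optimize_lora_parameters)

# Stand-in model; in practice load a pretrained Bark checkpoint instead.
model = GPT(GPTConfig())

# Wrap every nn.Linear whose module name contains "attn" with a rank-8 LoRA adapter,
# then freeze everything except the LoRA weights.
model = convert_linear_layer_to_lora(model, "attn", lora_dim=8)
model = only_optimize_lora_parameters(model)
optimizer = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

batches = [torch.randint(0, 10_000, (1, 256))]  # dummy (B, T) token batches
for tokens in batches:
    logits, loss = model(tokens, labels=tokens)  # new labels= path returns (logits, loss)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

# Fold the LoRA updates back into the base Linear weights before saving.
model = convert_lora_to_linear_layer(model)
torch.save(model.state_dict(), "finetuned.pt")
```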