Add better voice clones and prepare for finetuning

Francis LaBounty
2023-05-25 16:24:41 -06:00
parent 0b16a49fe2
commit 40afeec9c0
12 changed files with 1050 additions and 37 deletions

3
.gitignore vendored

@@ -1,4 +1,5 @@
__pycache__/
*.wav
_temp/
models/
models/
output.npz

README.md

@@ -8,6 +8,9 @@ You will get the best results by making generations with your cloned voice until
- [BARK text to speech @ SERP AI](https://serp.ai/tools/bark-text-to-speech-ai-voice-clone-app/)
# Shoutouts
- Huge shoutout to [gitmylo](https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer/) for the semantic token generation solution that enables better voice clones and fine-tunes
-------------------------------------------------------------------
# Original README.md

bark/generation.py

@@ -60,12 +60,12 @@ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "serp", "bark_v0")
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
USE_SMALL_MODELS = os.environ.get("SERP_USE_SMALL_MODELS", False)
GLOBAL_ENABLE_MPS = os.environ.get("SERP_ENABLE_MPS", False)
OFFLOAD_CPU = os.environ.get("SERP_OFFLOAD_CPU", False)
REMOTE_MODEL_PATHS = {
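These flags are plain `os.environ.get(..., False)` lookups evaluated at import time, so they must be set before the package is imported, and any non-empty string value enables them. A minimal sketch of using the renamed variables, assuming the upstream `bark` package layout is unchanged in this fork:

import os

# Set before importing bark, because generation.py evaluates these at import time.
os.environ["SERP_USE_SMALL_MODELS"] = "1"  # load the smaller text/coarse/fine models
os.environ["SERP_ENABLE_MPS"] = "1"        # allow the Apple-silicon MPS backend
os.environ["SERP_OFFLOAD_CPU"] = "1"       # keep idle models off the GPU

from bark.generation import preload_models  # assumed unchanged from upstream bark
preload_models()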

bark/model.py

@@ -8,6 +8,8 @@ from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat, reduce
SEMANTIC_PAD_TOKEN = 10_000
class LayerNorm(nn.Module):
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
@@ -165,7 +167,7 @@ class GPT(nn.Module):
n_params -= self.transformer.wpe.weight.numel()
return n_params
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
device = idx.device
b, t = idx.size()
if past_kv is not None:
@@ -212,6 +214,21 @@ class GPT(nn.Module):
x = self.transformer.ln_f(x)
if labels is not None:
logits = self.lm_head(x)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = nn.CrossEntropyLoss()
shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
shift_labels = shift_labels.view(-1)
# Enable model parallelism
shift_labels = shift_labels.to(shift_logits.device)
loss = loss_fct(shift_logits, shift_labels)
return logits, loss
# inference-time mini-optimization: only forward the lm_head on the very last position
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
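With `labels` supplied, the forward pass now returns `(logits, loss)` over the full sequence instead of only the last-position logits, which is the hook the fine-tuning work needs. A minimal sketch of a training step, assuming `model` is a loaded `GPT` and `inputs`/`optimizer` are hypothetical (data pipeline omitted):

# Hypothetical fine-tuning step; `model`, `optimizer`, and `inputs` (a (B, T) LongTensor of
# token ids padded with SEMANTIC_PAD_TOKEN) are assumed to exist.
labels = inputs.clone()
labels[labels == SEMANTIC_PAD_TOKEN] = -100   # -100 is nn.CrossEntropyLoss's default ignore_index

logits, loss = model(inputs, labels=labels)   # next-token cross-entropy over the shifted sequence
loss.backward()
optimizer.step()
optimizer.zero_grad()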


@@ -12,7 +12,39 @@
"import torchaudio\n",
"import torch\n",
"\n",
"model = load_codec_model(use_gpu=True)"
"device = 'cuda' # or 'cpu'\n",
"model = load_codec_model(use_gpu=True if device == 'cuda' else False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
"from hubert.hubert_manager import HuBERTManager\n",
"hubert_manager = HuBERTManager()\n",
"hubert_manager.make_sure_hubert_installed()\n",
"hubert_manager.make_sure_tokenizer_installed()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer \n",
"# Load HuBERT for semantic tokens\n",
"from hubert.pre_kmeans_hubert import CustomHubert\n",
"from hubert.customtokenizer import CustomTokenizer\n",
"\n",
"# Load the HuBERT model\n",
"hubert_model = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(device)\n",
"\n",
"# Load the CustomTokenizer model\n",
"tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(device) # Automatically uses the right layers"
]
},
{
@@ -22,11 +54,20 @@
"outputs": [],
"source": [
"# Load and pre-process the audio waveform\n",
"audio_filepath = 'audio.wav' # the audio you want to clone (will get truncated so 5-10 seconds is probably fine, existing samples that I checked are around 7 seconds)\n",
"device = 'cuda' # or 'cpu'\n",
"audio_filepath = 'audio.wav' # the audio you want to clone (under 13 seconds)\n",
"wav, sr = torchaudio.load(audio_filepath)\n",
"wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n",
"wav = wav.unsqueeze(0).to(device)"
"wav = wav.to(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)\n",
"semantic_tokens = tokenizer.get_token(semantic_vectors)"
]
},
{
@@ -37,31 +78,10 @@
"source": [
"# Extract discrete codes from EnCodec\n",
"with torch.no_grad():\n",
" encoded_frames = model.encode(wav)\n",
" encoded_frames = model.encode(wav.unsqueeze(0))\n",
"codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"text = \"Transcription of the audio you are cloning\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get seconds of audio\n",
"seconds = wav.shape[-1] / model.sample_rate\n",
"# generate semantic tokens\n",
"semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) # not 100% sure on this part"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -69,7 +89,9 @@
"outputs": [],
"source": [
"# move codes to cpu\n",
"codes = codes.cpu().numpy()"
"codes = codes.cpu().numpy()\n",
"# move semantic tokens to cpu\n",
"semantic_tokens = semantic_tokens.cpu().numpy()"
]
},
{
@@ -121,10 +143,7 @@
"\n",
"# Enter your prompt and speaker here\n",
"text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n",
"voice_name = \"speaker_0\" # use your custom voice name here if you have one\n",
"\n",
"# load the tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")"
"voice_name = \"output\" # use your custom voice name here if you have one"
]
},
{
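The notebook excerpt ends before the speaker file is written. Packaging `semantic_tokens` and `codes` into the `.npz` history prompt presumably follows the standard Bark prompt layout; a sketch under that assumption (the key names and the two-codebook coarse split come from upstream Bark, not from this diff):

import numpy as np

# codes has shape [n_q, T] from EnCodec; Bark's coarse stage uses the first two codebooks.
np.savez(
    "output.npz",                      # matches voice_name = "output" and the new .gitignore entry
    semantic_prompt=semantic_tokens,   # from the custom HuBERT tokenizer above
    coarse_prompt=codes[:2, :],
    fine_prompt=codes,
)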

0
hubert/__init__.py Normal file

184
hubert/customtokenizer.py Normal file

@@ -0,0 +1,184 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
import json
import os.path
from zipfile import ZipFile
import numpy
import torch
from torch import nn, optim
from torch.serialization import MAP_LOCATION
class CustomTokenizer(nn.Module):
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
super(CustomTokenizer, self).__init__()
next_size = input_size
if version == 0:
self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
next_size = hidden_size
if version == 1:
self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
self.intermediate = nn.Linear(hidden_size, 4096)
next_size = 4096
self.fc = nn.Linear(next_size, output_size)
self.softmax = nn.LogSoftmax(dim=1)
self.optimizer: optim.Optimizer = None
self.lossfunc = nn.CrossEntropyLoss()
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.version = version
def forward(self, x):
x, _ = self.lstm(x)
if self.version == 1:
x = self.intermediate(x)
x = self.fc(x)
x = self.softmax(x)
return x
@torch.no_grad()
def get_token(self, x):
"""
Used to get the predicted token for each feature vector in `x`.
:param x: An array with shape (N, input_size) where N is a whole number greater than or equal to 1, and input_size is the input size used when creating the model.
:return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model.
"""
return torch.argmax(self(x), dim=1)
def prepare_training(self):
self.optimizer = optim.Adam(self.parameters(), 0.001)
def train_step(self, x_train, y_train, log_loss=False):
# y_train = y_train[:-1]
# y_train = y_train[1:]
optimizer = self.optimizer
lossfunc = self.lossfunc
# Zero the gradients
self.zero_grad()
# Forward pass
y_pred = self(x_train)
y_train_len = len(y_train)
y_pred_len = y_pred.shape[0]
if y_train_len > y_pred_len:
diff = y_train_len - y_pred_len
y_train = y_train[diff:]
elif y_train_len < y_pred_len:
diff = y_pred_len - y_train_len
y_pred = y_pred[:-diff, :]
y_train_hot = torch.zeros(len(y_train), self.output_size)
y_train_hot[range(len(y_train)), y_train] = 1
y_train_hot = y_train_hot.to('cuda')
# Calculate the loss
loss = lossfunc(y_pred, y_train_hot)
# Print loss
if log_loss:
print('Loss', loss.item())
# Backward pass
loss.backward()
# Update the weights
optimizer.step()
def save(self, path):
info_path = os.path.basename(path) + '/.info'
torch.save(self.state_dict(), path)
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
with ZipFile(path, 'a') as model_zip:
model_zip.writestr(info_path, data_from_model.save())
model_zip.close()
@staticmethod
def load_from_checkpoint(path, map_location: MAP_LOCATION = None):
old = True
with ZipFile(path) as model_zip:
filesMatch = [file for file in model_zip.namelist() if file.endswith('/.info')]
file = filesMatch[0] if filesMatch else None
if file:
old = False
data_from_model = Data.load(model_zip.read(file).decode('utf-8'))
model_zip.close()
if old:
model = CustomTokenizer()
else:
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
model.load_state_dict(torch.load(path, map_location))
return model
class Data:
input_size: int
hidden_size: int
output_size: int
version: int
def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0):
self.input_size = input_size
self.hidden_size = hidden_size
self.output_size = output_size
self.version = version
@staticmethod
def load(string):
data = json.loads(string)
return Data(data['input_size'], data['hidden_size'], data['output_size'], data['version'])
def save(self):
data = {
'input_size': self.input_size,
'hidden_size': self.hidden_size,
'output_size': self.output_size,
'version': self.version,
}
return json.dumps(data)
def auto_train(data_path, save_path='model.pth', load_model: str | None = None, save_epochs=1):
data_x, data_y = [], []
if load_model and os.path.isfile(load_model):
print('Loading model from', load_model)
model_training = CustomTokenizer.load_from_checkpoint(load_model, 'cuda')
else:
print('Creating new model.')
model_training = CustomTokenizer(version=1).to('cuda')  # version 1 adds an intermediate linear layer after the LSTM
save_path = os.path.join(data_path, save_path)
base_save_path = '.'.join(save_path.split('.')[:-1])
sem_string = '_semantic.npy'
feat_string = '_semantic_features.npy'
ready = os.path.join(data_path, 'ready')
for input_file in os.listdir(ready):
full_path = os.path.join(ready, input_file)
if input_file.endswith(sem_string):
data_y.append(numpy.load(full_path))
elif input_file.endswith(feat_string):
data_x.append(numpy.load(full_path))
model_training.prepare_training()
epoch = 1
while 1:
for i in range(save_epochs):
j = 0
for x, y in zip(data_x, data_y):
model_training.train_step(torch.tensor(x).to('cuda'), torch.tensor(y).to('cuda'), j % 50 == 0) # Print loss every 50 steps
j += 1
save_p = save_path
save_p_2 = f'{base_save_path}_epoch_{epoch}.pth'
model_training.save(save_p)
model_training.save(save_p_2)
print(f'Epoch {epoch} completed')
epoch += 1
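`auto_train` expects a `ready/` folder inside `data_path` holding paired `<name>_semantic_features.npy` (HuBERT features) and `<name>_semantic.npy` (target token ids) files, trains indefinitely on CUDA (the device is hard-coded), and writes a checkpoint plus a per-epoch copy after every `save_epochs` passes over the data. A hypothetical invocation with illustrative paths:

from hubert.customtokenizer import auto_train

auto_train(
    data_path="data/train",       # must contain data/train/ready/ with the .npy pairs
    save_path="quantifier.pth",   # saved as data/train/quantifier.pth plus *_epoch_N.pth copies
    load_model=None,              # or a checkpoint path to resume training
    save_epochs=1,
)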

35
hubert/hubert_manager.py Normal file

@@ -0,0 +1,35 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
import os.path
import shutil
import urllib.request
import huggingface_hub
class HuBERTManager:
@staticmethod
def make_sure_hubert_installed(download_url: str = 'https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt', file_name: str = 'hubert.pt'):
install_dir = os.path.join('data', 'models', 'hubert')
if not os.path.isdir(install_dir):
os.makedirs(install_dir, exist_ok=True)
install_file = os.path.join(install_dir, file_name)
if not os.path.isfile(install_file):
print('Downloading HuBERT base model')
urllib.request.urlretrieve(download_url, install_file)
print('Downloaded HuBERT')
return install_file
@staticmethod
def make_sure_tokenizer_installed(model: str = 'quantifier_hubert_base_ls960_14.pth', repo: str = 'GitMylo/bark-voice-cloning', local_file: str = 'tokenizer.pth'):
install_dir = os.path.join('data', 'models', 'hubert')
if not os.path.isdir(install_dir):
os.makedirs(install_dir, exist_ok=True)
install_file = os.path.join(install_dir, local_file)
if not os.path.isfile(install_file):
print('Downloading HuBERT custom tokenizer')
huggingface_hub.hf_hub_download(repo, model, local_dir=install_dir, local_dir_use_symlinks=False)
shutil.move(os.path.join(install_dir, model), install_file)
print('Downloaded tokenizer')
return install_file
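Both helpers are idempotent: they create `data/models/hubert/`, download only if the file is missing, and return the local path, so they can be called unconditionally before loading the models (as the notebook above does). For reference, a short sketch wiring the returned paths into the loaders:

from hubert.hubert_manager import HuBERTManager
from hubert.pre_kmeans_hubert import CustomHubert
from hubert.customtokenizer import CustomTokenizer

hubert_path = HuBERTManager.make_sure_hubert_installed()        # data/models/hubert/hubert.pt
tokenizer_path = HuBERTManager.make_sure_tokenizer_installed()  # data/models/hubert/tokenizer.pth

hubert_model = CustomHubert(checkpoint_path=hubert_path)
tokenizer = CustomTokenizer.load_from_checkpoint(tokenizer_path)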

94
hubert/pre_kmeans_hubert.py Normal file

@@ -0,0 +1,94 @@
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
from pathlib import Path
import torch
from torch import nn
from einops import pack, unpack
import joblib
import fairseq
from torchaudio.functional import resample
from audiolm_pytorch.utils import curtail_to_multiple
import logging
logging.root.setLevel(logging.ERROR)
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
class CustomHubert(nn.Module):
"""
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
or you can train your own
"""
def __init__(
self,
checkpoint_path,
target_sample_hz=16000,
seq_len_multiple_of=None,
output_layer=9
):
super().__init__()
self.target_sample_hz = target_sample_hz
self.seq_len_multiple_of = seq_len_multiple_of
self.output_layer = output_layer
model_path = Path(checkpoint_path)
assert model_path.exists(), f'path {checkpoint_path} does not exist'
checkpoint = torch.load(checkpoint_path)
load_model_input = {checkpoint_path: checkpoint}
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
self.model = model[0]
self.model.eval()
@property
def groups(self):
return 1
@torch.no_grad()
def forward(
self,
wav_input,
flatten=True,
input_sample_hz=None
):
device = wav_input.device
if exists(input_sample_hz):
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
if exists(self.seq_len_multiple_of):
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
embed = self.model(
wav_input,
features_only=True,
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
output_layer=self.output_layer
)
embed, packed_shape = pack([embed['x']], '* d')
# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()
if flatten:
return codebook_indices
codebook_indices, = unpack(codebook_indices, packed_shape, '*')
return codebook_indices
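With the k-means step commented out, `forward` returns the raw layer-9 HuBERT features, one 768-dimensional vector per frame, rather than discrete cluster ids; the custom tokenizer above is what maps these features onto Bark's semantic vocabulary. A small shape-check sketch (file names are illustrative):

import torchaudio
from hubert.pre_kmeans_hubert import CustomHubert

hubert = CustomHubert(checkpoint_path="data/models/hubert/hubert.pt")
wav, sr = torchaudio.load("audio.wav")
wav = wav.mean(dim=0, keepdim=True)                  # mono, shape (1, samples)
features = hubert.forward(wav, input_sample_hz=sr)   # resampled to 16 kHz internally
print(features.shape)                                # expected: (n_frames, 768)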

0
utils/__init__.py Normal file

508
utils/bitsandbytes.py Normal file

@@ -0,0 +1,508 @@
# From https://github.com/huggingface/transformers/blob/e45e756d22206ca8fa9fb057c8c3d8fa79bf81c6/src/transformers/utils/bitsandbytes.py
import warnings
import sys
import importlib.util
from copy import deepcopy
import copy
import json
import os
from dataclasses import dataclass
from typing import Any, Tuple, Union, Dict
from packaging import version
if sys.version_info < (3, 8):
import importlib_metadata
else:
import importlib.metadata as importlib_metadata
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
package_exists = importlib.util.find_spec(pkg_name) is not None
package_version = "N/A"
if package_exists:
try:
package_version = importlib_metadata.version(pkg_name)
package_exists = True
except importlib_metadata.PackageNotFoundError:
package_exists = False
if return_version:
return package_exists, package_version
else:
return package_exists
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
_bitsandbytes_available = _is_package_available("bitsandbytes")
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
def is_accelerate_available(check_partial_state=False):
if check_partial_state:
return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.19.0")
return _accelerate_available
def is_bitsandbytes_available():
return _bitsandbytes_available
def is_torch_available():
return _torch_available
if is_bitsandbytes_available():
import bitsandbytes as bnb
import torch
import torch.nn as nn
if is_accelerate_available():
from accelerate import init_empty_weights
from accelerate.utils import find_tied_parameters
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
"""
A helper function to set a given tensor (parameter or buffer) of a module on a specific device (note that doing
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
class `Int8Params` from `bitsandbytes`.
Args:
module (`torch.nn.Module`):
The module in which the tensor we want to move lives.
tensor_name (`str`):
The full name of the parameter/buffer.
device (`int`, `str` or `torch.device`):
The device on which to set the tensor.
value (`torch.Tensor`, *optional*):
The value of the tensor (useful when going from the meta device to any other device).
fp16_statistics (`torch.HalfTensor`, *optional*):
The list of fp16 statistics to set on the module, used for serialization.
"""
# Recurse if needed
if "." in tensor_name:
splits = tensor_name.split(".")
for split in splits[:-1]:
new_module = getattr(module, split)
if new_module is None:
raise ValueError(f"{module} has no attribute {split}.")
module = new_module
tensor_name = splits[-1]
if tensor_name not in module._parameters and tensor_name not in module._buffers:
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
is_buffer = tensor_name in module._buffers
old_value = getattr(module, tensor_name)
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
is_4bit = False
is_8bit = False
if is_buffer or not is_bitsandbytes_available():
is_8bit = False
is_4bit = False
else:
is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
if is_8bit or is_4bit:
param = module._parameters[tensor_name]
if param.device.type != "cuda":
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to("cpu")
if value.dtype == torch.int8:
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse(
"0.37.2"
)
if not is_8bit_serializable:
raise ValueError(
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
)
else:
new_value = torch.tensor(value, device="cpu")
kwargs = old_value.__dict__
if is_8bit:
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
elif is_4bit:
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
module._parameters[tensor_name] = new_value
if fp16_statistics is not None:
setattr(module.weight, "SCB", fp16_statistics.to(device))
else:
if value is None:
new_value = old_value.to(device)
elif isinstance(value, torch.Tensor):
new_value = value.to(device)
else:
new_value = torch.tensor(value, device=device)
if is_buffer:
module._buffers[tensor_name] = new_value
else:
new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
module._parameters[tensor_name] = new_value
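This helper is what lets a quantized checkpoint be materialized tensor by tensor: parameters that live on the `meta` device (e.g. created under `init_empty_weights`) are replaced with `Int8Params`/`Params4bit` values that get quantized as they move to the GPU. A hedged sketch of a loading loop (the names are illustrative):

# Hypothetical: `model` was built under init_empty_weights() and converted with
# replace_with_bnb_linear below; `state_dict` holds fp32/fp16 weights on the CPU.
for name, tensor in state_dict.items():
    set_module_quantized_tensor_to_device(model, name, device="cuda:0", value=tensor)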
def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
"""
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
bitsandbytes`
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
matrix multiplication into two streams: (1) a systematic feature outlier stream that is matrix multiplied in fp16
(0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
predictive degradation is possible for very large models (>=176B parameters).
Parameters:
model (`torch.nn.Module`):
Input model or `torch.nn.Module` as the function is run recursively.
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
for numerical stability reasons.
current_key_name (`List[`str`]`, *optional*):
An array to track the current key of the recursion. This is used to check whether the current key (part of
it) is not in the list of modules to not convert (for instance, modules that are offloaded to `cpu` or
`disk`).
"""
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
for name, module in model.named_children():
if current_key_name is None:
current_key_name = []
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
# Check if the current key is not in the `modules_to_not_convert`
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
with init_empty_weights():
if quantization_config.quantization_method() == "llm_int8":
model._modules[name] = bnb.nn.Linear8bitLt(
module.in_features,
module.out_features,
module.bias is not None,
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
threshold=quantization_config.llm_int8_threshold,
)
else:
if (
quantization_config.llm_int8_skip_modules is not None
and name in quantization_config.llm_int8_skip_modules
):
pass
else:
model._modules[name] = bnb.nn.Linear4bit(
module.in_features,
module.out_features,
module.bias is not None,
quantization_config.bnb_4bit_compute_dtype,
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
quant_type=quantization_config.bnb_4bit_quant_type,
)
# Force requires grad to False to avoid unexpected errors
model._modules[name].requires_grad_(False)
# Remove the last key for recursion
if len(list(module.children())) > 0:
replace_with_bnb_linear(
module,
modules_to_not_convert,
current_key_name,
quantization_config,
)
return model
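A hedged sketch of what the conversion does to a module tree; `BitsAndBytesConfig` is defined later in this file, and the toy model exists only to show which layers get swapped:

import torch.nn as nn

class TinyHead(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(512, 512)
        self.lm_head = nn.Linear(512, 10_000)

# quantization_method() must return "llm_int8" for the Linear8bitLt branch to be taken.
bnb_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_threshold=6.0)
model = replace_with_bnb_linear(TinyHead(), quantization_config=bnb_config)
# model.proj is now bnb.nn.Linear8bitLt (with empty weights, since the swap happens under
# init_empty_weights); model.lm_head stays nn.Linear because "lm_head" is excluded by default.
# Real weights would then be loaded with set_module_quantized_tensor_to_device as sketched above.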
# For backward compatibility
def replace_8bit_linear(*args, **kwargs):
warnings.warn(
"`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
FutureWarning,
)
return replace_with_bnb_linear(*args, **kwargs)
# For backward compatibility
def set_module_8bit_tensor_to_device(*args, **kwargs):
warnings.warn(
"`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
FutureWarning,
)
return set_module_quantized_tensor_to_device(*args, **kwargs)
def get_keys_to_not_convert(model):
r"""
A utility function to get the keys of the modules to keep in full precision, if any. For example, for CausalLM modules
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
int8.
Parameters:
model (`torch.nn.Module`):
Input model
"""
# Create a copy of the model and tie the weights, then
# check if it contains tied weights
tied_model = deepcopy(model) # this has 0 cost since it is done inside the `init_empty_weights` context manager
tied_model.tie_weights()
tied_params = find_tied_parameters(tied_model)
# For compatibility with Accelerate < 0.18
if isinstance(tied_params, dict):
tied_keys = list(tied_params.values())
else:
tied_keys = sum([x[1:] for x in tied_params], [])
has_tied_params = len(tied_keys) > 0
# Check if it is a base model
is_base_model = not hasattr(model, model.base_model_prefix)
# Ignore this for base models (BertModel, GPT2Model, etc.)
if (not has_tied_params) and is_base_model:
return []
# otherwise they have an attached head
list_modules = list(model.named_parameters())
list_last_module = [list_modules[-1][0]]
# add last module together with tied weights
intersection = set(list_last_module) - set(tied_keys)
list_untouched = tied_keys + list(intersection)
# remove ".weight" from the keys
names_to_remove = [".weight", ".bias"]
filtered_module_names = []
for name in list_untouched:
for name_to_remove in names_to_remove:
if name_to_remove in name:
name = name.replace(name_to_remove, "")
filtered_module_names.append(name)
return filtered_module_names
#!/usr/bin/env python
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
if is_torch_available():
import torch
@dataclass
class BitsAndBytesConfig:
"""
This is a wrapper class for all the attributes and features that you can play with on a model that has been
loaded using `bitsandbytes`.
This replaces `load_in_8bit` or `load_in_4bit`; the two options are therefore mutually exclusive.
Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
then more arguments will be added to this class.
Args:
load_in_8bit (`bool`, *optional*, defaults to `False`):
This flag is used to enable 8-bit quantization with LLM.int8().
load_in_4bit (`bool`, *optional*, defaults to `False`):
This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
`bitsandbytes`.
llm_int8_threshold (`float`, *optional*, defaults to 6):
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
that is above this threshold will be considered an outlier and the operation on those values will be done
in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
there are some exceptional systematic outliers that are very differently distributed for large models.
These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
but a lower threshold might be needed for more unstable models (small models, fine-tuning).
llm_int8_skip_modules (`List[str]`, *optional*):
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as
Jukebox that has several heads in different places and not necessarily at the last position. For example
for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
This flag is used for advanced use cases and users that are aware of this feature. If you want to split
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
operations will not be run on CPU.
llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
have to be converted back and forth for the backward pass.
bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
This sets the computational type, which might be different from the input type. For example, inputs might be
fp32, but computation can be set to bf16 for speedups.
bnb_4bit_quant_type (`str`, {fp4, nf4}, defaults to `fp4`):
This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
which are specified by `fp4` or `nf4`.
bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
This flag is used for nested quantization where the quantization constants from the first quantization are
quantized again.
kwargs (`Dict[str, Any]`, *optional*):
Additional parameters from which to initialize the configuration object.
"""
def __init__(
self,
load_in_8bit=False,
load_in_4bit=False,
llm_int8_threshold=6.0,
llm_int8_skip_modules=None,
llm_int8_enable_fp32_cpu_offload=False,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=None,
bnb_4bit_quant_type="fp4",
bnb_4bit_use_double_quant=False,
**kwargs,
):
self.load_in_8bit = load_in_8bit
self.load_in_4bit = load_in_4bit
self.llm_int8_threshold = llm_int8_threshold
self.llm_int8_skip_modules = llm_int8_skip_modules
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
self.bnb_4bit_quant_type = bnb_4bit_quant_type
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
if bnb_4bit_compute_dtype is None:
self.bnb_4bit_compute_dtype = torch.float32
elif isinstance(bnb_4bit_compute_dtype, str):
self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
else:
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
self.post_init()
def post_init(self):
r"""
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
"""
if not isinstance(self.llm_int8_threshold, float):
raise ValueError("llm_int8_threshold must be a float")
if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
raise ValueError("llm_int8_skip_modules must be a list of strings")
if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")
if not isinstance(self.llm_int8_has_fp16_weight, bool):
raise ValueError("llm_int8_has_fp16_weight must be a boolean")
if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
raise ValueError("bnb_4bit_compute_dtype must be torch.dtype")
if not isinstance(self.bnb_4bit_quant_type, str):
raise ValueError("bnb_4bit_quant_type must be a string")
if not isinstance(self.bnb_4bit_use_double_quant, bool):
raise ValueError("bnb_4bit_use_double_quant must be a boolean")
if self.load_in_4bit and not version.parse(importlib_metadata.version("bitsandbytes")) >= version.parse(
"0.39.0"
):
raise ValueError(
"4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
)
def is_quantizable(self):
r"""
Returns `True` if the model is quantizable, `False` otherwise.
"""
return self.load_in_8bit or self.load_in_4bit
def quantization_method(self):
r"""
This method returns the quantization method used for the model. If the model is not quantizable, it returns
`None`.
"""
if self.load_in_8bit:
return "llm_int8"
elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
return "fp4"
elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
return "nf4"
else:
return None
@classmethod
def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
"""
Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters.
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object.
return_unused_kwargs (`bool`):
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
`PreTrainedModel`.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.
Returns:
[`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
"""
config = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
if return_unused_kwargs:
return config, kwargs
else:
return config
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default
`BitsAndBytesConfig()` is serialized to JSON file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
config_dict = self.to_dict()
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
writer.write(json_string)
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary. Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
return output
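For the 4-bit path, the config selects NF4 vs FP4 and the compute dtype, and `quantization_method()` is what `replace_with_bnb_linear` branches on. A small illustrative construction (values are examples, not recommendations from this repo; a recent `bitsandbytes` must be installed for the 4-bit version check to pass):

import torch

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,   # the string "bfloat16" is also accepted
    bnb_4bit_use_double_quant=True,
)
assert nf4_config.quantization_method() == "nf4"
print(nf4_config.to_dict())                  # compute dtype is serialized as the string "bfloat16"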

152
utils/lora.py Normal file

@@ -0,0 +1,152 @@
# Adapted from https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/module/lora.py
import math
import torch
from torch import nn
import torch.nn.functional as F
class LinearLayer_LoRA(nn.Module):
# a simple implementation of LoRA
def __init__(self,
weight,
lora_dim=0,
lora_scaling=1,
lora_droppout=0,
bias=None):
super(LinearLayer_LoRA, self).__init__()
self.weight = weight
self.bias = bias
if lora_dim <= 0:
raise ValueError(
"You are training to use LoRA, whose reduced dim should be larger than 1"
)
rows, columns = weight.shape
self.lora_right_weight = nn.Parameter(torch.zeros(
columns,
lora_dim)) # stored transposed so the forward pass does not need an extra transpose
self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
self.lora_scaling = lora_scaling / lora_dim
if lora_droppout > 0:
self.lora_dropout = nn.Dropout(lora_droppout)
else:
self.lora_dropout = nn.Identity()
self.reset_parameters()
# disable the original weight gradient
self.weight.requires_grad = False
# fuse LoRA to the original weight
self.fuse_lora = False
def eval(self):
self.lora_dropout.eval()
# self.fuse_lora_weight()
def train(self, mode=True):
self.lora_dropout.train(mode)
# self.unfuse_lora_weight()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
nn.init.zeros_(self.lora_left_weight)
def fuse_lora_weight(self):
if not self.fuse_lora:
self.weight.data += self.lora_scaling * torch.matmul(
self.lora_left_weight.t(), self.lora_right_weight.t())
self.fuse_lora = True
def unfuse_lora_weight(self):
if self.fuse_lora:
self.weight.data -= self.lora_scaling * torch.matmul(
self.lora_left_weight.t(), self.lora_right_weight.t())
self.fuse_lora = False
def forward(self, input):
if self.fuse_lora:
return F.linear(input, self.weight, self.bias)
else:
return F.linear(
input, self.weight,
self.bias) + (self.lora_dropout(input) @ self.lora_right_weight
@ self.lora_left_weight) * self.lora_scaling
def recursive_getattr(model, module_name):
"""
From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
Recursively get the attribute of a module.
Args:
model (`torch.nn.Module`)
The model to get the attribute from.
module_name (`str`)
The name of the module to get the attribute from.
"""
split_list = module_name.split('.')
output = model
for name in split_list:
output = getattr(output, name)
return output
def recursive_setattr(model, module_name, module):
"""
From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
Recursively set the attribute of a module.
Args:
model (`torch.nn.Module`)
The model to set the attribute in.
module_name (`str`)
The name of the module to set the attribute in.
module (`torch.nn.Module`)
The module to set the attribute to.
"""
split_list = module_name.split('.')
output = model
for name in split_list[:-1]:
output = getattr(output, name)
output.__setattr__(split_list[-1], module)
# convert the linear layer to LoRA
def convert_linear_layer_to_lora(model,
part_module_name,
lora_dim=0,
lora_scaling=1,
lora_droppout=0):
repalce_name = []
for name, module in model.named_modules():
if isinstance(module, nn.Linear) and part_module_name in name:
repalce_name.append(name)
for name in repalce_name:
module = recursive_getattr(model, name)
tmp = LinearLayer_LoRA(
module.weight, lora_dim, lora_scaling, lora_droppout,
module.bias).to(module.weight.device).to(module.weight.dtype)
recursive_setattr(model, name, tmp)
return model
# convert the LoRA layer to linear layer
def convert_lora_to_linear_layer(model):
repalce_name = []
for name, module in model.named_modules():
if isinstance(module, LinearLayer_LoRA):
repalce_name.append(name)
for name in repalce_name:
module = recursive_getattr(model, name)
module.fuse_lora_weight()
return model
def only_optimize_lora_parameters(model):
# turn off the gradient of all the parameters except the LoRA parameters
for name, param in model.named_parameters():
if "lora_right_weight" in name or "lora_left_weight" in name:
param.requires_grad = True
else:
param.requires_grad = False
return model
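Putting the helpers together, the expected fine-tuning flow is: wrap the targeted `nn.Linear` layers, freeze everything except the LoRA factors, train, then fuse the low-rank update back into the base weights for inference. A hedged sketch on a stand-in model (the name filter, rank, and learning rate are illustrative; nothing here is prescribed by the repo yet):

import torch
from torch import nn
from utils.lora import (convert_linear_layer_to_lora, convert_lora_to_linear_layer,
                        only_optimize_lora_parameters)

model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))

# An empty part_module_name matches every Linear; a real run would target e.g. attention projections.
model = convert_linear_layer_to_lora(model, part_module_name="", lora_dim=4)
model = only_optimize_lora_parameters(model)

opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)
x, y = torch.randn(8, 64), torch.randint(0, 10, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()

# After training, fold the LoRA deltas back into the original weights for plain inference.
model = convert_lora_to_linear_layer(model)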