mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-14 18:57:56 +01:00
Add better voice clones and prepare for finetuning
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -1,4 +1,5 @@
|
||||
__pycache__/
|
||||
*.wav
|
||||
_temp/
|
||||
models/
|
||||
models/
|
||||
output.npz
|
||||
@@ -8,6 +8,9 @@ You will get the best results by making generations with your cloned voice until
|
||||
|
||||
- [BARK text to speech @ SERP AI](https://serp.ai/tools/bark-text-to-speech-ai-voice-clone-app/)
|
||||
|
||||
# Shoutouts
|
||||
- Huge shoutout to [gitmylo](https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer/) for the solution to the semantic token generation for better voice clones and finetunes
|
||||
|
||||
-------------------------------------------------------------------
|
||||
# Original README.md
|
||||
|
||||
|
||||
@@ -60,12 +60,12 @@ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
|
||||
default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
|
||||
CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "serp", "bark_v0")
|
||||
|
||||
|
||||
USE_SMALL_MODELS = os.environ.get("SUNO_USE_SMALL_MODELS", False)
|
||||
GLOBAL_ENABLE_MPS = os.environ.get("SUNO_ENABLE_MPS", False)
|
||||
OFFLOAD_CPU = os.environ.get("SUNO_OFFLOAD_CPU", False)
|
||||
USE_SMALL_MODELS = os.environ.get("SERP_USE_SMALL_MODELS", False)
|
||||
GLOBAL_ENABLE_MPS = os.environ.get("SERP_ENABLE_MPS", False)
|
||||
OFFLOAD_CPU = os.environ.get("SERP_OFFLOAD_CPU", False)
|
||||
|
||||
|
||||
REMOTE_MODEL_PATHS = {
|
||||
|
||||
@@ -8,6 +8,8 @@ from dataclasses import dataclass
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from einops import rearrange, repeat, reduce
|
||||
SEMANTIC_PAD_TOKEN = 10_000
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
|
||||
@@ -165,7 +167,7 @@ class GPT(nn.Module):
|
||||
n_params -= self.transformer.wpe.weight.numel()
|
||||
return n_params
|
||||
|
||||
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
|
||||
def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False, labels=None):
|
||||
device = idx.device
|
||||
b, t = idx.size()
|
||||
if past_kv is not None:
|
||||
@@ -212,6 +214,21 @@ class GPT(nn.Module):
|
||||
|
||||
x = self.transformer.ln_f(x)
|
||||
|
||||
|
||||
if labels is not None:
|
||||
logits = self.lm_head(x)
|
||||
# Shift so that tokens < n predict n
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
loss_fct = nn.CrossEntropyLoss()
|
||||
shift_logits = shift_logits.view(-1, self.config.output_vocab_size)
|
||||
shift_labels = shift_labels.view(-1)
|
||||
# Enable model parallelism
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
loss = loss_fct(shift_logits, shift_labels)
|
||||
return logits, loss
|
||||
|
||||
# inference-time mini-optimization: only forward the lm_head on the very last position
|
||||
logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
|
||||
|
||||
|
||||
@@ -12,7 +12,39 @@
|
||||
"import torchaudio\n",
|
||||
"import torch\n",
|
||||
"\n",
|
||||
"model = load_codec_model(use_gpu=True)"
|
||||
"device = 'cuda' # or 'cpu'\n",
|
||||
"model = load_codec_model(use_gpu=True if device == 'cuda' else False)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
|
||||
"from hubert.hubert_manager import HuBERTManager\n",
|
||||
"hubert_manager = HuBERTManager()\n",
|
||||
"hubert_manager.make_sure_hubert_installed()\n",
|
||||
"hubert_manager.make_sure_tokenizer_installed()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer \n",
|
||||
"# Load HuBERT for semantic tokens\n",
|
||||
"from hubert.pre_kmeans_hubert import CustomHubert\n",
|
||||
"from hubert.customtokenizer import CustomTokenizer\n",
|
||||
"\n",
|
||||
"# Load the HuBERT model\n",
|
||||
"hubert_model = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(device)\n",
|
||||
"\n",
|
||||
"# Load the CustomTokenizer model\n",
|
||||
"tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(device) # Automatically uses the right layers"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -22,11 +54,20 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Load and pre-process the audio waveform\n",
|
||||
"audio_filepath = 'audio.wav' # the audio you want to clone (will get truncated so 5-10 seconds is probably fine, existing samples that I checked are around 7 seconds)\n",
|
||||
"device = 'cuda' # or 'cpu'\n",
|
||||
"audio_filepath = 'audio.wav' # the audio you want to clone (under 13 seconds)\n",
|
||||
"wav, sr = torchaudio.load(audio_filepath)\n",
|
||||
"wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n",
|
||||
"wav = wav.unsqueeze(0).to(device)"
|
||||
"wav = wav.to(device)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)\n",
|
||||
"semantic_tokens = tokenizer.get_token(semantic_vectors)"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -37,31 +78,10 @@
|
||||
"source": [
|
||||
"# Extract discrete codes from EnCodec\n",
|
||||
"with torch.no_grad():\n",
|
||||
" encoded_frames = model.encode(wav)\n",
|
||||
" encoded_frames = model.encode(wav.unsqueeze(0))\n",
|
||||
"codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"text = \"Transcription of the audio you are cloning\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# get seconds of audio\n",
|
||||
"seconds = wav.shape[-1] / model.sample_rate\n",
|
||||
"# generate semantic tokens\n",
|
||||
"semantic_tokens = generate_text_semantic(text, max_gen_duration_s=seconds, top_k=50, top_p=.95, temp=0.7) # not 100% sure on this part"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
@@ -69,7 +89,9 @@
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# move codes to cpu\n",
|
||||
"codes = codes.cpu().numpy()"
|
||||
"codes = codes.cpu().numpy()\n",
|
||||
"# move semantic tokens to cpu\n",
|
||||
"semantic_tokens = semantic_tokens.cpu().numpy()"
|
||||
]
|
||||
},
|
||||
{
|
||||
@@ -121,10 +143,7 @@
|
||||
"\n",
|
||||
"# Enter your prompt and speaker here\n",
|
||||
"text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n",
|
||||
"voice_name = \"speaker_0\" # use your custom voice name here if you have one\n",
|
||||
"\n",
|
||||
"# load the tokenizer\n",
|
||||
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")"
|
||||
"voice_name = \"output\" # use your custom voice name here if you have one"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
||||
0
hubert/__init__.py
Normal file
0
hubert/__init__.py
Normal file
184
hubert/customtokenizer.py
Normal file
184
hubert/customtokenizer.py
Normal file
@@ -0,0 +1,184 @@
|
||||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
|
||||
|
||||
import json
|
||||
import os.path
|
||||
from zipfile import ZipFile
|
||||
|
||||
import numpy
|
||||
import torch
|
||||
from torch import nn, optim
|
||||
from torch.serialization import MAP_LOCATION
|
||||
|
||||
|
||||
class CustomTokenizer(nn.Module):
|
||||
def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
|
||||
super(CustomTokenizer, self).__init__()
|
||||
next_size = input_size
|
||||
if version == 0:
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
|
||||
next_size = hidden_size
|
||||
if version == 1:
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
|
||||
self.intermediate = nn.Linear(hidden_size, 4096)
|
||||
next_size = 4096
|
||||
|
||||
self.fc = nn.Linear(next_size, output_size)
|
||||
self.softmax = nn.LogSoftmax(dim=1)
|
||||
self.optimizer: optim.Optimizer = None
|
||||
self.lossfunc = nn.CrossEntropyLoss()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.output_size = output_size
|
||||
self.version = version
|
||||
|
||||
def forward(self, x):
|
||||
x, _ = self.lstm(x)
|
||||
if self.version == 1:
|
||||
x = self.intermediate(x)
|
||||
x = self.fc(x)
|
||||
x = self.softmax(x)
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
def get_token(self, x):
|
||||
"""
|
||||
Used to get the token for the first
|
||||
:param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model.
|
||||
:return: An array with shape (N,) where N is the same as N from the input. Every number in the array is a whole number in range 0...output_size - 1 where output_size is the output size used when creating the model.
|
||||
"""
|
||||
return torch.argmax(self(x), dim=1)
|
||||
|
||||
def prepare_training(self):
|
||||
self.optimizer = optim.Adam(self.parameters(), 0.001)
|
||||
|
||||
def train_step(self, x_train, y_train, log_loss=False):
|
||||
# y_train = y_train[:-1]
|
||||
# y_train = y_train[1:]
|
||||
|
||||
optimizer = self.optimizer
|
||||
lossfunc = self.lossfunc
|
||||
# Zero the gradients
|
||||
self.zero_grad()
|
||||
|
||||
# Forward pass
|
||||
y_pred = self(x_train)
|
||||
|
||||
y_train_len = len(y_train)
|
||||
y_pred_len = y_pred.shape[0]
|
||||
|
||||
if y_train_len > y_pred_len:
|
||||
diff = y_train_len - y_pred_len
|
||||
y_train = y_train[diff:]
|
||||
elif y_train_len < y_pred_len:
|
||||
diff = y_pred_len - y_train_len
|
||||
y_pred = y_pred[:-diff, :]
|
||||
|
||||
y_train_hot = torch.zeros(len(y_train), self.output_size)
|
||||
y_train_hot[range(len(y_train)), y_train] = 1
|
||||
y_train_hot = y_train_hot.to('cuda')
|
||||
|
||||
# Calculate the loss
|
||||
loss = lossfunc(y_pred, y_train_hot)
|
||||
|
||||
# Print loss
|
||||
if log_loss:
|
||||
print('Loss', loss.item())
|
||||
|
||||
# Backward pass
|
||||
loss.backward()
|
||||
|
||||
# Update the weights
|
||||
optimizer.step()
|
||||
|
||||
def save(self, path):
|
||||
info_path = os.path.basename(path) + '/.info'
|
||||
torch.save(self.state_dict(), path)
|
||||
data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
|
||||
with ZipFile(path, 'a') as model_zip:
|
||||
model_zip.writestr(info_path, data_from_model.save())
|
||||
model_zip.close()
|
||||
|
||||
@staticmethod
|
||||
def load_from_checkpoint(path, map_location: MAP_LOCATION = None):
|
||||
old = True
|
||||
with ZipFile(path) as model_zip:
|
||||
filesMatch = [file for file in model_zip.namelist() if file.endswith('/.info')]
|
||||
file = filesMatch[0] if filesMatch else None
|
||||
if file:
|
||||
old = False
|
||||
data_from_model = Data.load(model_zip.read(file).decode('utf-8'))
|
||||
model_zip.close()
|
||||
if old:
|
||||
model = CustomTokenizer()
|
||||
else:
|
||||
model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
|
||||
model.load_state_dict(torch.load(path, map_location))
|
||||
return model
|
||||
|
||||
|
||||
|
||||
class Data:
|
||||
input_size: int
|
||||
hidden_size: int
|
||||
output_size: int
|
||||
version: int
|
||||
|
||||
def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0):
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.output_size = output_size
|
||||
self.version = version
|
||||
|
||||
@staticmethod
|
||||
def load(string):
|
||||
data = json.loads(string)
|
||||
return Data(data['input_size'], data['hidden_size'], data['output_size'], data['version'])
|
||||
|
||||
def save(self):
|
||||
data = {
|
||||
'input_size': self.input_size,
|
||||
'hidden_size': self.hidden_size,
|
||||
'output_size': self.output_size,
|
||||
'version': self.version,
|
||||
}
|
||||
return json.dumps(data)
|
||||
|
||||
|
||||
def auto_train(data_path, save_path='model.pth', load_model: str | None = None, save_epochs=1):
|
||||
data_x, data_y = [], []
|
||||
|
||||
if load_model and os.path.isfile(load_model):
|
||||
print('Loading model from', load_model)
|
||||
model_training = CustomTokenizer.load_from_checkpoint(load_model, 'cuda')
|
||||
else:
|
||||
print('Creating new model.')
|
||||
model_training = CustomTokenizer(version=1).to('cuda') # Settings for the model to run without lstm
|
||||
save_path = os.path.join(data_path, save_path)
|
||||
base_save_path = '.'.join(save_path.split('.')[:-1])
|
||||
|
||||
sem_string = '_semantic.npy'
|
||||
feat_string = '_semantic_features.npy'
|
||||
|
||||
ready = os.path.join(data_path, 'ready')
|
||||
for input_file in os.listdir(ready):
|
||||
full_path = os.path.join(ready, input_file)
|
||||
if input_file.endswith(sem_string):
|
||||
data_y.append(numpy.load(full_path))
|
||||
elif input_file.endswith(feat_string):
|
||||
data_x.append(numpy.load(full_path))
|
||||
model_training.prepare_training()
|
||||
|
||||
epoch = 1
|
||||
|
||||
while 1:
|
||||
for i in range(save_epochs):
|
||||
j = 0
|
||||
for x, y in zip(data_x, data_y):
|
||||
model_training.train_step(torch.tensor(x).to('cuda'), torch.tensor(y).to('cuda'), j % 50 == 0) # Print loss every 50 steps
|
||||
j += 1
|
||||
save_p = save_path
|
||||
save_p_2 = f'{base_save_path}_epoch_{epoch}.pth'
|
||||
model_training.save(save_p)
|
||||
model_training.save(save_p_2)
|
||||
print(f'Epoch {epoch} completed')
|
||||
epoch += 1
|
||||
35
hubert/hubert_manager.py
Normal file
35
hubert/hubert_manager.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
|
||||
|
||||
import os.path
|
||||
import shutil
|
||||
import urllib.request
|
||||
|
||||
import huggingface_hub
|
||||
|
||||
|
||||
class HuBERTManager:
|
||||
@staticmethod
|
||||
def make_sure_hubert_installed(download_url: str = 'https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt', file_name: str = 'hubert.pt'):
|
||||
install_dir = os.path.join('data', 'models', 'hubert')
|
||||
if not os.path.isdir(install_dir):
|
||||
os.makedirs(install_dir, exist_ok=True)
|
||||
install_file = os.path.join(install_dir, file_name)
|
||||
if not os.path.isfile(install_file):
|
||||
print('Downloading HuBERT base model')
|
||||
urllib.request.urlretrieve(download_url, install_file)
|
||||
print('Downloaded HuBERT')
|
||||
return install_file
|
||||
|
||||
|
||||
@staticmethod
|
||||
def make_sure_tokenizer_installed(model: str = 'quantifier_hubert_base_ls960_14.pth', repo: str = 'GitMylo/bark-voice-cloning', local_file: str = 'tokenizer.pth'):
|
||||
install_dir = os.path.join('data', 'models', 'hubert')
|
||||
if not os.path.isdir(install_dir):
|
||||
os.makedirs(install_dir, exist_ok=True)
|
||||
install_file = os.path.join(install_dir, local_file)
|
||||
if not os.path.isfile(install_file):
|
||||
print('Downloading HuBERT custom tokenizer')
|
||||
huggingface_hub.hf_hub_download(repo, model, local_dir=install_dir, local_dir_use_symlinks=False)
|
||||
shutil.move(os.path.join(install_dir, model), install_file)
|
||||
print('Downloaded tokenizer')
|
||||
return install_file
|
||||
94
hubert/pre_kmeans_hubert.py
Normal file
94
hubert/pre_kmeans_hubert.py
Normal file
@@ -0,0 +1,94 @@
|
||||
# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import pack, unpack
|
||||
|
||||
import joblib
|
||||
|
||||
import fairseq
|
||||
|
||||
from torchaudio.functional import resample
|
||||
|
||||
from audiolm_pytorch.utils import curtail_to_multiple
|
||||
|
||||
import logging
|
||||
logging.root.setLevel(logging.ERROR)
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
return val if exists(val) else d
|
||||
|
||||
|
||||
class CustomHubert(nn.Module):
|
||||
"""
|
||||
checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
|
||||
or you can train your own
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
checkpoint_path,
|
||||
target_sample_hz=16000,
|
||||
seq_len_multiple_of=None,
|
||||
output_layer=9
|
||||
):
|
||||
super().__init__()
|
||||
self.target_sample_hz = target_sample_hz
|
||||
self.seq_len_multiple_of = seq_len_multiple_of
|
||||
self.output_layer = output_layer
|
||||
|
||||
model_path = Path(checkpoint_path)
|
||||
|
||||
assert model_path.exists(), f'path {checkpoint_path} does not exist'
|
||||
|
||||
checkpoint = torch.load(checkpoint_path)
|
||||
load_model_input = {checkpoint_path: checkpoint}
|
||||
model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)
|
||||
|
||||
self.model = model[0]
|
||||
self.model.eval()
|
||||
|
||||
@property
|
||||
def groups(self):
|
||||
return 1
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
wav_input,
|
||||
flatten=True,
|
||||
input_sample_hz=None
|
||||
):
|
||||
device = wav_input.device
|
||||
|
||||
if exists(input_sample_hz):
|
||||
wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)
|
||||
|
||||
if exists(self.seq_len_multiple_of):
|
||||
wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)
|
||||
|
||||
embed = self.model(
|
||||
wav_input,
|
||||
features_only=True,
|
||||
mask=False, # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
|
||||
output_layer=self.output_layer
|
||||
)
|
||||
|
||||
embed, packed_shape = pack([embed['x']], '* d')
|
||||
|
||||
# codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())
|
||||
|
||||
codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device) # .long()
|
||||
|
||||
if flatten:
|
||||
return codebook_indices
|
||||
|
||||
codebook_indices, = unpack(codebook_indices, packed_shape, '*')
|
||||
return codebook_indices
|
||||
0
utils/__init__.py
Normal file
0
utils/__init__.py
Normal file
508
utils/bitsandbytes.py
Normal file
508
utils/bitsandbytes.py
Normal file
@@ -0,0 +1,508 @@
|
||||
# From https://github.com/huggingface/transformers/blob/e45e756d22206ca8fa9fb057c8c3d8fa79bf81c6/src/transformers/utils/bitsandbytes.py
|
||||
|
||||
import warnings
|
||||
import sys
|
||||
import importlib.util
|
||||
from copy import deepcopy
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
|
||||
from typing import Any, Tuple, Union, Dict
|
||||
|
||||
from packaging import version
|
||||
|
||||
if sys.version_info < (3, 8):
|
||||
import importlib_metadata
|
||||
else:
|
||||
import importlib.metadata as importlib_metadata
|
||||
|
||||
|
||||
def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[Tuple[bool, str], bool]:
|
||||
# Check we're not importing a "pkg_name" directory somewhere but the actual library by trying to grab the version
|
||||
package_exists = importlib.util.find_spec(pkg_name) is not None
|
||||
package_version = "N/A"
|
||||
if package_exists:
|
||||
try:
|
||||
package_version = importlib_metadata.version(pkg_name)
|
||||
package_exists = True
|
||||
except importlib_metadata.PackageNotFoundError:
|
||||
package_exists = False
|
||||
if return_version:
|
||||
return package_exists, package_version
|
||||
else:
|
||||
return package_exists
|
||||
|
||||
_accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
|
||||
_bitsandbytes_available = _is_package_available("bitsandbytes")
|
||||
_torch_available, _torch_version = _is_package_available("torch", return_version=True)
|
||||
|
||||
def is_accelerate_available(check_partial_state=False):
|
||||
if check_partial_state:
|
||||
return _accelerate_available and version.parse(_accelerate_version) >= version.parse("0.19.0")
|
||||
return _accelerate_available
|
||||
|
||||
def is_bitsandbytes_available():
|
||||
return _bitsandbytes_available
|
||||
|
||||
def is_torch_available():
|
||||
return _torch_available
|
||||
|
||||
if is_bitsandbytes_available():
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
if is_accelerate_available():
|
||||
from accelerate import init_empty_weights
|
||||
from accelerate.utils import find_tied_parameters
|
||||
|
||||
|
||||
def set_module_quantized_tensor_to_device(module, tensor_name, device, value=None, fp16_statistics=None):
|
||||
"""
|
||||
A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing
|
||||
`param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). The
|
||||
function is adapted from `set_module_tensor_to_device` function from accelerate that is adapted to support the
|
||||
class `Int8Params` from `bitsandbytes`.
|
||||
|
||||
Args:
|
||||
module (`torch.nn.Module`):
|
||||
The module in which the tensor we want to move lives.
|
||||
tensor_name (`str`):
|
||||
The full name of the parameter/buffer.
|
||||
device (`int`, `str` or `torch.device`):
|
||||
The device on which to set the tensor.
|
||||
value (`torch.Tensor`, *optional*):
|
||||
The value of the tensor (useful when going from the meta device to any other device).
|
||||
fp16_statistics (`torch.HalfTensor`, *optional*):
|
||||
The list of fp16 statistics to set on the module, used for serialization.
|
||||
"""
|
||||
# Recurse if needed
|
||||
if "." in tensor_name:
|
||||
splits = tensor_name.split(".")
|
||||
for split in splits[:-1]:
|
||||
new_module = getattr(module, split)
|
||||
if new_module is None:
|
||||
raise ValueError(f"{module} has no attribute {split}.")
|
||||
module = new_module
|
||||
tensor_name = splits[-1]
|
||||
|
||||
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
||||
raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
|
||||
is_buffer = tensor_name in module._buffers
|
||||
old_value = getattr(module, tensor_name)
|
||||
|
||||
if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
|
||||
raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")
|
||||
|
||||
is_4bit = False
|
||||
is_8bit = False
|
||||
if is_buffer or not is_bitsandbytes_available():
|
||||
is_8bit = False
|
||||
is_4bit = False
|
||||
else:
|
||||
is_4bit = hasattr(bnb.nn, "Params4bit") and isinstance(module._parameters[tensor_name], bnb.nn.Params4bit)
|
||||
is_8bit = isinstance(module._parameters[tensor_name], bnb.nn.Int8Params)
|
||||
|
||||
if is_8bit or is_4bit:
|
||||
param = module._parameters[tensor_name]
|
||||
if param.device.type != "cuda":
|
||||
if value is None:
|
||||
new_value = old_value.to(device)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
new_value = value.to("cpu")
|
||||
if value.dtype == torch.int8:
|
||||
is_8bit_serializable = version.parse(importlib_metadata.version("bitsandbytes")) > version.parse(
|
||||
"0.37.2"
|
||||
)
|
||||
if not is_8bit_serializable:
|
||||
raise ValueError(
|
||||
"Detected int8 weights but the version of bitsandbytes is not compatible with int8 serialization. "
|
||||
"Make sure to download the latest `bitsandbytes` version. `pip install --upgrade bitsandbytes`."
|
||||
)
|
||||
else:
|
||||
new_value = torch.tensor(value, device="cpu")
|
||||
|
||||
kwargs = old_value.__dict__
|
||||
if is_8bit:
|
||||
new_value = bnb.nn.Int8Params(new_value, requires_grad=False, **kwargs).to(device)
|
||||
elif is_4bit:
|
||||
new_value = bnb.nn.Params4bit(new_value, requires_grad=False, **kwargs).to(device)
|
||||
|
||||
module._parameters[tensor_name] = new_value
|
||||
if fp16_statistics is not None:
|
||||
setattr(module.weight, "SCB", fp16_statistics.to(device))
|
||||
|
||||
else:
|
||||
if value is None:
|
||||
new_value = old_value.to(device)
|
||||
elif isinstance(value, torch.Tensor):
|
||||
new_value = value.to(device)
|
||||
else:
|
||||
new_value = torch.tensor(value, device=device)
|
||||
|
||||
if is_buffer:
|
||||
module._buffers[tensor_name] = new_value
|
||||
else:
|
||||
new_value = nn.Parameter(new_value, requires_grad=old_value.requires_grad)
|
||||
module._parameters[tensor_name] = new_value
|
||||
|
||||
|
||||
def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name=None, quantization_config=None):
|
||||
"""
|
||||
A helper function to replace all `torch.nn.Linear` modules by `bnb.nn.Linear8bit` modules from the `bitsandbytes`
|
||||
library. This will enable running your models using mixed int8 precision as described by the paper `LLM.int8():
|
||||
8-bit Matrix Multiplication for Transformers at Scale`. Make sure `bitsandbytes` compiled with the correct CUDA
|
||||
version of your hardware is installed before running this function. `pip install -i https://test.pypi.org/simple/
|
||||
bitsandbytes`
|
||||
|
||||
The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
|
||||
be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
|
||||
CPU/GPU memory is required to run this function. Int8 mixed-precision matrix decomposition works by separating a
|
||||
matrix multiplication into two streams: (1) and systematic feature outlier stream matrix multiplied in fp16
|
||||
(0.01%), (2) a regular stream of int8 matrix multiplication (99.9%). With this method, int8 inference with no
|
||||
predictive degradation is possible for very large models (>=176B parameters).
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model or `torch.nn.Module` as the function is run recursively.
|
||||
modules_to_not_convert (`List[`str`]`, *optional*, defaults to `["lm_head"]`):
|
||||
Names of the modules to not convert in `Linear8bitLt`. In practice we keep the `lm_head` in full precision
|
||||
for numerical stability reasons.
|
||||
current_key_name (`List[`str`]`, *optional*):
|
||||
An array to track the current key of the recursion. This is used to check whether the current key (part of
|
||||
it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
|
||||
`disk`).
|
||||
"""
|
||||
modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
|
||||
for name, module in model.named_children():
|
||||
if current_key_name is None:
|
||||
current_key_name = []
|
||||
|
||||
if isinstance(module, nn.Linear) and name not in modules_to_not_convert:
|
||||
# Check if the current key is not in the `modules_to_not_convert`
|
||||
if not any(key in ".".join(current_key_name) for key in modules_to_not_convert):
|
||||
with init_empty_weights():
|
||||
if quantization_config.quantization_method() == "llm_int8":
|
||||
model._modules[name] = bnb.nn.Linear8bitLt(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
has_fp16_weights=quantization_config.llm_int8_has_fp16_weight,
|
||||
threshold=quantization_config.llm_int8_threshold,
|
||||
)
|
||||
else:
|
||||
if (
|
||||
quantization_config.llm_int8_skip_modules is not None
|
||||
and name in quantization_config.llm_int8_skip_modules
|
||||
):
|
||||
pass
|
||||
else:
|
||||
model._modules[name] = bnb.nn.Linear4bit(
|
||||
module.in_features,
|
||||
module.out_features,
|
||||
module.bias is not None,
|
||||
quantization_config.bnb_4bit_compute_dtype,
|
||||
compress_statistics=quantization_config.bnb_4bit_use_double_quant,
|
||||
quant_type=quantization_config.bnb_4bit_quant_type,
|
||||
)
|
||||
# Force requires grad to False to avoid unexpected errors
|
||||
model._modules[name].requires_grad_(False)
|
||||
# Remove the last key for recursion
|
||||
if len(list(module.children())) > 0:
|
||||
replace_with_bnb_linear(
|
||||
module,
|
||||
modules_to_not_convert,
|
||||
current_key_name,
|
||||
quantization_config,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
# For backward compatibility
|
||||
def replace_8bit_linear(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"`replace_8bit_linear` will be deprecated in a future version, please use `replace_with_bnb_linear` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
return replace_with_bnb_linear(*args, **kwargs)
|
||||
|
||||
|
||||
# For backward compatiblity
|
||||
def set_module_8bit_tensor_to_device(*args, **kwargs):
|
||||
warnings.warn(
|
||||
"`set_module_8bit_tensor_to_device` will be deprecated in a future version, please use `set_module_quantized_tensor_to_device` instead",
|
||||
FutureWarning,
|
||||
)
|
||||
return set_module_quantized_tensor_to_device(*args, **kwargs)
|
||||
|
||||
|
||||
def get_keys_to_not_convert(model):
|
||||
r"""
|
||||
An utility function to get the key of the module to keep in full precision if any For example for CausalLM modules
|
||||
we may want to keep the lm_head in full precision for numerical stability reasons. For other architectures, we want
|
||||
to keep the tied weights of the model. The function will return a list of the keys of the modules to not convert in
|
||||
int8.
|
||||
|
||||
Parameters:
|
||||
model (`torch.nn.Module`):
|
||||
Input model
|
||||
"""
|
||||
# Create a copy of the model and tie the weights, then
|
||||
# check if it contains tied weights
|
||||
tied_model = deepcopy(model) # this has 0 cost since it is done inside `init_empty_weights` context manager`
|
||||
tied_model.tie_weights()
|
||||
|
||||
tied_params = find_tied_parameters(tied_model)
|
||||
# For compatibility with Accelerate < 0.18
|
||||
if isinstance(tied_params, dict):
|
||||
tied_keys = list(tied_params.values())
|
||||
else:
|
||||
tied_keys = sum([x[1:] for x in tied_params], [])
|
||||
has_tied_params = len(tied_keys) > 0
|
||||
|
||||
# Check if it is a base model
|
||||
is_base_model = not hasattr(model, model.base_model_prefix)
|
||||
|
||||
# Ignore this for base models (BertModel, GPT2Model, etc.)
|
||||
if (not has_tied_params) and is_base_model:
|
||||
return []
|
||||
|
||||
# otherwise they have an attached head
|
||||
list_modules = list(model.named_parameters())
|
||||
list_last_module = [list_modules[-1][0]]
|
||||
|
||||
# add last module together with tied weights
|
||||
intersection = set(list_last_module) - set(tied_keys)
|
||||
list_untouched = tied_keys + list(intersection)
|
||||
|
||||
# remove ".weight" from the keys
|
||||
names_to_remove = [".weight", ".bias"]
|
||||
filtered_module_names = []
|
||||
for name in list_untouched:
|
||||
for name_to_remove in names_to_remove:
|
||||
if name_to_remove in name:
|
||||
name = name.replace(name_to_remove, "")
|
||||
filtered_module_names.append(name)
|
||||
|
||||
return filtered_module_names
|
||||
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
class BitsAndBytesConfig:
|
||||
"""
|
||||
This is a wrapper class about all possible attributes and features that you can play with a model that has been
|
||||
loaded using `bitsandbytes`.
|
||||
|
||||
This replaces `load_in_8bit` or `load_in_4bit`therefore both options are mutually exclusive.
|
||||
|
||||
Currently only supports `LLM.int8()`, `FP4`, and `NF4` quantization. If more methods are added to `bitsandbytes`,
|
||||
then more arguments will be added to this class.
|
||||
|
||||
Args:
|
||||
load_in_8bit (`bool`, *optional*, defaults to `False`):
|
||||
This flag is used to enable 8-bit quantization with LLM.int8().
|
||||
load_in_4bit (`bool`, *optional*, defaults to `False`):
|
||||
This flag is used to enable 4-bit quantization by replacing the Linear layers with FP4/NF4 layers from
|
||||
`bitsandbytes`.
|
||||
llm_int8_threshold (`float`, *optional*, defaults to 6):
|
||||
This corresponds to the outlier threshold for outlier detection as described in `LLM.int8() : 8-bit Matrix
|
||||
Multiplication for Transformers at Scale` paper: https://arxiv.org/abs/2208.07339 Any hidden states value
|
||||
that is above this threshold will be considered an outlier and the operation on those values will be done
|
||||
in fp16. Values are usually normally distributed, that is, most values are in the range [-3.5, 3.5], but
|
||||
there are some exceptional systematic outliers that are very differently distributed for large models.
|
||||
These outliers are often in the interval [-60, -6] or [6, 60]. Int8 quantization works well for values of
|
||||
magnitude ~5, but beyond that, there is a significant performance penalty. A good default threshold is 6,
|
||||
but a lower threshold might be needed for more unstable models (small models, fine-tuning).
|
||||
llm_int8_skip_modules (`List[str]`, *optional*):
|
||||
An explicit list of the modules that we do not want to convert in 8-bit. This is useful for models such as
|
||||
Jukebox that has several heads in different places and not necessarily at the last position. For example
|
||||
for `CausalLM` models, the last `lm_head` is kept in its original `dtype`.
|
||||
llm_int8_enable_fp32_cpu_offload (`bool`, *optional*, defaults to `False`):
|
||||
This flag is used for advanced use cases and users that are aware of this feature. If you want to split
|
||||
your model in different parts and run some parts in int8 on GPU and some parts in fp32 on CPU, you can use
|
||||
this flag. This is useful for offloading large models such as `google/flan-t5-xxl`. Note that the int8
|
||||
operations will not be run on CPU.
|
||||
llm_int8_has_fp16_weight (`bool`, *optional*, defaults to `False`):
|
||||
This flag runs LLM.int8() with 16-bit main weights. This is useful for fine-tuning as the weights do not
|
||||
have to be converted back and forth for the backward pass.
|
||||
bnb_4bit_compute_dtype (`torch.dtype` or str, *optional*, defaults to `torch.float32`):
|
||||
This sets the computational type which might be different than the input time. For example, inputs might be
|
||||
fp32, but computation can be set to bf16 for speedups.
|
||||
bnb_4bit_quant_type (`str`, {fp4, fn4}, defaults to `fp4`):
|
||||
This sets the quantization data type in the bnb.nn.Linear4Bit layers. Options are FP4 and NF4 data types
|
||||
which are specified by `fp4` or `fn4`.
|
||||
bnb_4bit_use_double_quant (`bool`, *optional*, defaults to `False`):
|
||||
This flag is used for nested quantization where the quantization constants from the first quantization are
|
||||
quantized again.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
Additional parameters from which to initialize the configuration object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
load_in_8bit=False,
|
||||
load_in_4bit=False,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_skip_modules=None,
|
||||
llm_int8_enable_fp32_cpu_offload=False,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=None,
|
||||
bnb_4bit_quant_type="fp4",
|
||||
bnb_4bit_use_double_quant=False,
|
||||
**kwargs,
|
||||
):
|
||||
self.load_in_8bit = load_in_8bit
|
||||
self.load_in_4bit = load_in_4bit
|
||||
self.llm_int8_threshold = llm_int8_threshold
|
||||
self.llm_int8_skip_modules = llm_int8_skip_modules
|
||||
self.llm_int8_enable_fp32_cpu_offload = llm_int8_enable_fp32_cpu_offload
|
||||
self.llm_int8_has_fp16_weight = llm_int8_has_fp16_weight
|
||||
self.bnb_4bit_quant_type = bnb_4bit_quant_type
|
||||
self.bnb_4bit_use_double_quant = bnb_4bit_use_double_quant
|
||||
|
||||
if bnb_4bit_compute_dtype is None:
|
||||
self.bnb_4bit_compute_dtype = torch.float32
|
||||
elif isinstance(bnb_4bit_compute_dtype, str):
|
||||
self.bnb_4bit_compute_dtype = getattr(torch, bnb_4bit_compute_dtype)
|
||||
elif isinstance(bnb_4bit_compute_dtype, torch.dtype):
|
||||
self.bnb_4bit_compute_dtype = bnb_4bit_compute_dtype
|
||||
else:
|
||||
raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
|
||||
|
||||
self.post_init()
|
||||
|
||||
def post_init(self):
|
||||
r"""
|
||||
Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
|
||||
"""
|
||||
if not isinstance(self.llm_int8_threshold, float):
|
||||
raise ValueError("llm_int8_threshold must be a float")
|
||||
|
||||
if self.llm_int8_skip_modules is not None and not isinstance(self.llm_int8_skip_modules, list):
|
||||
raise ValueError("llm_int8_skip_modules must be a list of strings")
|
||||
if not isinstance(self.llm_int8_enable_fp32_cpu_offload, bool):
|
||||
raise ValueError("llm_int8_enable_fp32_cpu_offload must be a boolean")
|
||||
|
||||
if not isinstance(self.llm_int8_has_fp16_weight, bool):
|
||||
raise ValueError("llm_int8_has_fp16_weight must be a boolean")
|
||||
|
||||
if self.bnb_4bit_compute_dtype is not None and not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
|
||||
raise ValueError("bnb_4bit_compute_dtype must be torch.dtype")
|
||||
|
||||
if not isinstance(self.bnb_4bit_quant_type, str):
|
||||
raise ValueError("bnb_4bit_quant_type must be a string")
|
||||
|
||||
if not isinstance(self.bnb_4bit_use_double_quant, bool):
|
||||
raise ValueError("bnb_4bit_use_double_quant must be a boolean")
|
||||
|
||||
if self.load_in_4bit and not version.parse(importlib_metadata.version("bitsandbytes")) >= version.parse(
|
||||
"0.39.0"
|
||||
):
|
||||
raise ValueError(
|
||||
"4 bit quantization requires bitsandbytes>=0.39.0 - please upgrade your bitsandbytes version"
|
||||
)
|
||||
|
||||
def is_quantizable(self):
|
||||
r"""
|
||||
Returns `True` if the model is quantizable, `False` otherwise.
|
||||
"""
|
||||
return self.load_in_8bit or self.load_in_4bit
|
||||
|
||||
def quantization_method(self):
|
||||
r"""
|
||||
This method returns the quantization method used for the model. If the model is not quantizable, it returns
|
||||
`None`.
|
||||
"""
|
||||
if self.load_in_8bit:
|
||||
return "llm_int8"
|
||||
elif self.load_in_4bit and self.bnb_4bit_quant_type == "fp4":
|
||||
return "fp4"
|
||||
elif self.load_in_4bit and self.bnb_4bit_quant_type == "nf4":
|
||||
return "nf4"
|
||||
else:
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict, return_unused_kwargs, **kwargs):
|
||||
"""
|
||||
Instantiates a [`BitsAndBytesConfig`] from a Python dictionary of parameters.
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`):
|
||||
Dictionary that will be used to instantiate the configuration object.
|
||||
return_unused_kwargs (`bool`):
|
||||
Whether or not to return a list of unused keyword arguments. Used for `from_pretrained` method in
|
||||
`PreTrainedModel`.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
Additional parameters from which to initialize the configuration object.
|
||||
|
||||
Returns:
|
||||
[`BitsAndBytesConfig`]: The configuration object instantiated from those parameters.
|
||||
"""
|
||||
|
||||
config = cls(**config_dict)
|
||||
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
if return_unused_kwargs:
|
||||
return config, kwargs
|
||||
else:
|
||||
return config
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (`str` or `os.PathLike`):
|
||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default
|
||||
`BitsAndBytesConfig()` is serialized to JSON file.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
config_dict = self.to_dict()
|
||||
json_string = json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
writer.write(json_string)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["bnb_4bit_compute_dtype"] = str(output["bnb_4bit_compute_dtype"]).split(".")[1]
|
||||
|
||||
return output
|
||||
152
utils/lora.py
Normal file
152
utils/lora.py
Normal file
@@ -0,0 +1,152 @@
|
||||
# Adapted from https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/module/lora.py
|
||||
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
class LinearLayer_LoRA(nn.Module):
|
||||
# a simple implementation of LoRA
|
||||
def __init__(self,
|
||||
weight,
|
||||
lora_dim=0,
|
||||
lora_scaling=1,
|
||||
lora_droppout=0,
|
||||
bias=None):
|
||||
super(LinearLayer_LoRA, self).__init__()
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
|
||||
if lora_dim <= 0:
|
||||
raise ValueError(
|
||||
"You are training to use LoRA, whose reduced dim should be larger than 1"
|
||||
)
|
||||
|
||||
rows, columns = weight.shape
|
||||
self.lora_right_weight = nn.Parameter(torch.zeros(
|
||||
columns,
|
||||
lora_dim)) # apply transpose so in forward we do not need to
|
||||
self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
|
||||
self.lora_scaling = lora_scaling / lora_dim
|
||||
|
||||
if lora_droppout > 0:
|
||||
self.lora_dropout = nn.Dropout(lora_droppout)
|
||||
else:
|
||||
self.lora_dropout = nn.Identity()
|
||||
|
||||
self.reset_parameters()
|
||||
# disable the original weight gradient
|
||||
self.weight.requires_grad = False
|
||||
# fuse LoRA to the original weight
|
||||
self.fuse_lora = False
|
||||
|
||||
def eval(self):
|
||||
self.lora_dropout.eval()
|
||||
|
||||
# self.fuse_lora_weight()
|
||||
|
||||
def train(self, mode=True):
|
||||
self.lora_dropout.train(mode)
|
||||
# self.unfuse_lora_weight()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_left_weight)
|
||||
|
||||
def fuse_lora_weight(self):
|
||||
if not self.fuse_lora:
|
||||
self.weight.data += self.lora_scaling * torch.matmul(
|
||||
self.lora_left_weight.t(), self.lora_right_weight.t())
|
||||
self.fuse_lora = True
|
||||
|
||||
def unfuse_lora_weight(self):
|
||||
if self.fuse_lora:
|
||||
self.weight.data -= self.lora_scaling * torch.matmul(
|
||||
self.lora_left_weight.t(), self.lora_right_weight.t())
|
||||
self.fuse_lora = False
|
||||
|
||||
def forward(self, input):
|
||||
if self.fuse_lora:
|
||||
return F.linear(input, self.weight, self.bias)
|
||||
else:
|
||||
return F.linear(
|
||||
input, self.weight,
|
||||
self.bias) + (self.lora_dropout(input) @ self.lora_right_weight
|
||||
@ self.lora_left_weight) * self.lora_scaling
|
||||
|
||||
|
||||
def recursive_getattr(model, module_name):
|
||||
"""
|
||||
From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
|
||||
Recursively get the attribute of a module.
|
||||
Args:
|
||||
model (`torch.nn.Module`)
|
||||
The model to get the attribute from.
|
||||
module_name (`str`)
|
||||
The name of the module to get the attribute from.
|
||||
"""
|
||||
split_list = module_name.split('.')
|
||||
output = model
|
||||
for name in split_list:
|
||||
output = getattr(output, name)
|
||||
return output
|
||||
|
||||
|
||||
def recursive_setattr(model, module_name, module):
|
||||
"""
|
||||
From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
|
||||
Recursively set the attribute of a module.
|
||||
Args:
|
||||
model (`torch.nn.Module`)
|
||||
The model to set the attribute in.
|
||||
module_name (`str`)
|
||||
The name of the module to set the attribute in.
|
||||
module (`torch.nn.Module`)
|
||||
The module to set the attribute to.
|
||||
"""
|
||||
split_list = module_name.split('.')
|
||||
output = model
|
||||
for name in split_list[:-1]:
|
||||
output = getattr(output, name)
|
||||
output.__setattr__(split_list[-1], module)
|
||||
|
||||
|
||||
# convert the linear layer to LoRA
|
||||
def convert_linear_layer_to_lora(model,
|
||||
part_module_name,
|
||||
lora_dim=0,
|
||||
lora_scaling=1,
|
||||
lora_droppout=0):
|
||||
repalce_name = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear) and part_module_name in name:
|
||||
repalce_name.append(name)
|
||||
for name in repalce_name:
|
||||
module = recursive_getattr(model, name)
|
||||
tmp = LinearLayer_LoRA(
|
||||
module.weight, lora_dim, lora_scaling, lora_droppout,
|
||||
module.bias).to(module.weight.device).to(module.weight.dtype)
|
||||
recursive_setattr(model, name, tmp)
|
||||
return model
|
||||
|
||||
|
||||
# convert the LoRA layer to linear layer
|
||||
def convert_lora_to_linear_layer(model):
|
||||
repalce_name = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LinearLayer_LoRA):
|
||||
repalce_name.append(name)
|
||||
for name in repalce_name:
|
||||
module = recursive_getattr(model, name)
|
||||
module.fuse_lora_weight()
|
||||
return model
|
||||
|
||||
|
||||
def only_optimize_lora_parameters(model):
|
||||
# turn off the gradient of all the parameters except the LoRA parameters
|
||||
for name, param in model.named_parameters():
|
||||
if "lora_right_weight" in name or "lora_left_weight" in name:
|
||||
param.requires_grad = True
|
||||
else:
|
||||
param.requires_grad = False
|
||||
return model
|
||||
Reference in New Issue
Block a user