Files
bark-with-voice-clone/utils/lora.py
2023-06-29 21:48:18 -06:00

152 lines
5.1 KiB
Python

# Adapted from https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/module/lora.py
import math
import torch
from torch import nn
import torch.nn.functional as F
class LinearLayer_LoRA(nn.Module):
    """A simple LoRA wrapper around the weight (and optional bias) of a linear layer.

    The wrapped ``weight`` is frozen (``requires_grad`` is set to ``False``) and
    two low-rank factors are trained instead.  While the update is not fused,
    ``forward`` returns::

        F.linear(input, weight, bias)
            + (dropout(input) @ lora_right_weight @ lora_left_weight) * lora_scaling

    Args:
        weight: 2-D weight of the layer being wrapped; its shape is unpacked as
            ``(rows, columns)``.
        lora_dim: Rank of the low-rank update. Must be positive.
        lora_scaling: Numerator of the scaling factor; the stored factor is
            ``lora_scaling / lora_dim``.
        lora_dropout: Dropout probability applied to the input of the LoRA
            branch only. Values ``<= 0`` disable dropout.
        bias: Optional bias passed through to ``F.linear``.

    Raises:
        ValueError: If ``lora_dim`` is not positive.
    """

    def __init__(self,
                 weight,
                 lora_dim=0,
                 lora_scaling=1,
                 lora_dropout=0,
                 bias=None):
        super().__init__()
        if lora_dim <= 0:
            raise ValueError(
                "You are trying to use LoRA, whose reduced dim should be at least 1"
            )
        self.weight = weight
        self.bias = bias

        rows, columns = weight.shape
        # Stored already transposed so that forward does not need to
        # transpose again: input @ right @ left.
        self.lora_right_weight = nn.Parameter(torch.zeros(columns, lora_dim))
        self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_dropout > 0:
            self.lora_dropout = nn.Dropout(lora_dropout)
        else:
            self.lora_dropout = nn.Identity()

        self.reset_parameters()
        # disable the original weight gradient
        self.weight.requires_grad = False
        # whether the LoRA update is currently folded into self.weight
        self.fuse_lora = False

    def eval(self):
        """Put the layer in evaluation mode and return ``self``.

        The LoRA update is not fused automatically; call
        ``fuse_lora_weight()`` explicitly if that is wanted.
        """
        return self.train(False)

    def train(self, mode=True):
        """Set training/evaluation mode on the dropout branch and return ``self``.

        Returning ``self`` and updating ``self.training`` keeps this override
        consistent with ``nn.Module.train``. The LoRA update is not unfused
        automatically; call ``unfuse_lora_weight()`` explicitly if needed.
        """
        self.lora_dropout.train(mode)
        self.training = mode
        return self

    def reset_parameters(self):
        """Re-initialise the LoRA factors.

        The left factor is zeroed, so the product of the two factors (and
        therefore the LoRA contribution) is zero right after initialisation.
        """
        nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_left_weight)

    def _lora_delta(self):
        # (rows, lora_dim) @ (lora_dim, columns) -> same shape as self.weight
        return self.lora_scaling * torch.matmul(
            self.lora_left_weight.t(), self.lora_right_weight.t())

    def fuse_lora_weight(self):
        """Add the scaled low-rank update into ``weight.data`` in place (no-op if already fused)."""
        if not self.fuse_lora:
            self.weight.data += self._lora_delta()
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        """Subtract the scaled low-rank update from ``weight.data`` in place (no-op if not fused)."""
        if self.fuse_lora:
            self.weight.data -= self._lora_delta()
        self.fuse_lora = False

    def forward(self, input):
        """Apply the linear map, adding the LoRA branch unless it is already fused.

        When fused, the dropout branch is skipped entirely because the update
        already lives inside ``weight``.
        """
        base = F.linear(input, self.weight, self.bias)
        if self.fuse_lora:
            return base
        return base + (self.lora_dropout(input) @ self.lora_right_weight
                       @ self.lora_left_weight) * self.lora_scaling
def recursive_getattr(model, module_name):
    """
    From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
    Resolve a dotted attribute path starting from `model`.
    Args:
        model (`torch.nn.Module`)
            The object the lookup starts from.
        module_name (`str`)
            Dot-separated attribute path, e.g. ``"encoder.layer.0"``.
    Returns:
        The object found after following every path component with `getattr`.
    """
    current = model
    remaining = module_name
    while True:
        # Peel off one path component per iteration.
        head, dot, remaining = remaining.partition('.')
        current = getattr(current, head)
        if not dot:
            return current
def recursive_setattr(model, module_name, module):
    """
    From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
    Assign `module` at a dotted attribute path below `model`.
    Args:
        model (`torch.nn.Module`)
            The object the path is resolved against.
        module_name (`str`)
            Dot-separated attribute path; the last component is the
            attribute that gets assigned.
        module (`torch.nn.Module`)
            The value stored at that path.
    """
    parent_path, dot, leaf_name = module_name.rpartition('.')
    owner = model
    if dot:
        # Walk down to the object that owns the final attribute.
        for part in parent_path.split('.'):
            owner = getattr(owner, part)
    setattr(owner, leaf_name, module)
# swap matching nn.Linear sub-modules for LoRA layers
def convert_linear_layer_to_lora(model,
                                 part_module_name,
                                 lora_dim=0,
                                 lora_scaling=1,
                                 lora_dropout=0):
    """Replace every matching ``nn.Linear`` inside `model` with a `LinearLayer_LoRA`.

    A sub-module matches when it is an ``nn.Linear`` and `part_module_name`
    occurs as a substring of its qualified name. The replacement reuses the
    original layer's weight and bias and is moved to that weight's device
    and dtype.

    Args:
        model: Module whose sub-modules are scanned and replaced in place.
        part_module_name: Substring that selects sub-modules by qualified name.
        lora_dim: Forwarded to `LinearLayer_LoRA`.
        lora_scaling: Forwarded to `LinearLayer_LoRA`.
        lora_dropout: Forwarded to `LinearLayer_LoRA`.

    Returns:
        The same `model` object, after modification.
    """
    # Collect the names first so the module tree is not mutated while it
    # is being traversed.
    targets = [
        qualified_name
        for qualified_name, submodule in model.named_modules()
        if isinstance(submodule, nn.Linear) and part_module_name in qualified_name
    ]
    for qualified_name in targets:
        linear = recursive_getattr(model, qualified_name)
        lora_layer = LinearLayer_LoRA(linear.weight, lora_dim, lora_scaling,
                                      lora_dropout, linear.bias)
        lora_layer = lora_layer.to(linear.weight.device)
        lora_layer = lora_layer.to(linear.weight.dtype)
        recursive_setattr(model, qualified_name, lora_layer)
    return model
# fold the LoRA update of every LoRA layer back into its base weight
def convert_lora_to_linear_layer(model):
    """Fuse the LoRA update into the base weight of every `LinearLayer_LoRA`.

    The layers are not replaced by ``nn.Linear`` instances; each one has
    ``fuse_lora_weight()`` called on it, after which its forward pass uses
    only the (updated) base weight.

    Modules are visited directly rather than being looked up again by name,
    so this also works when `model` itself is a `LinearLayer_LoRA` (whose
    qualified name is the empty string).

    Args:
        model: Module whose LoRA layers should be fused in place.

    Returns:
        The same `model` object.
    """
    # Materialise the list before calling into the layers so traversal is
    # finished before any of them is touched.
    lora_layers = [
        submodule for submodule in model.modules()
        if isinstance(submodule, LinearLayer_LoRA)
    ]
    for lora_layer in lora_layers:
        lora_layer.fuse_lora_weight()
    return model
def only_optimize_lora_parameters(model):
    """Make the LoRA factors the only trainable parameters of `model`.

    A parameter is left trainable exactly when its qualified name contains
    ``"lora_right_weight"`` or ``"lora_left_weight"``; every other parameter
    has ``requires_grad`` switched off.

    Args:
        model: Module whose parameters are updated in place.

    Returns:
        The same `model` object.
    """
    for param_name, parameter in model.named_parameters():
        is_lora_factor = ("lora_right_weight" in param_name
                          or "lora_left_weight" in param_name)
        parameter.requires_grad = is_lora_factor
    return model