# Adapted from https://github.com/microsoft/DeepSpeedExamples/blob/master/applications/DeepSpeed-Chat/training/utils/module/lora.py
import math

import torch
from torch import nn
import torch.nn.functional as F

class LinearLayer_LoRA(nn.Module):
    # A simple implementation of LoRA (low-rank adaptation) for a linear layer.
    def __init__(self,
                 weight,
                 lora_dim=0,
                 lora_scaling=1,
                 lora_droppout=0,
                 bias=None):
        super(LinearLayer_LoRA, self).__init__()
        self.weight = weight
        self.bias = bias

        if lora_dim <= 0:
            raise ValueError(
                "You are trying to use LoRA, whose reduced dim should be at least 1"
            )

        rows, columns = weight.shape
        # Store the right factor pre-transposed so the forward pass can use a
        # plain matrix multiplication without an extra transpose.
        self.lora_right_weight = nn.Parameter(torch.zeros(columns, lora_dim))
        self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
        self.lora_scaling = lora_scaling / lora_dim

        if lora_droppout > 0:
            self.lora_dropout = nn.Dropout(lora_droppout)
        else:
            self.lora_dropout = nn.Identity()

        self.reset_parameters()
        # Disable the original weight's gradient; only the LoRA factors train.
        self.weight.requires_grad = False
        # Tracks whether the LoRA update is currently fused into self.weight.
        self.fuse_lora = False

    def eval(self):
        # Only the dropout module needs to switch modes; the frozen weight and
        # the LoRA factors are unaffected.
        self.lora_dropout.eval()
        # self.fuse_lora_weight()

    def train(self, mode=True):
        self.lora_dropout.train(mode)
        # self.unfuse_lora_weight()

    def reset_parameters(self):
        # Initialize the right factor randomly and the left factor to zero so
        # the LoRA update starts out as a no-op.
        nn.init.kaiming_uniform_(self.lora_right_weight, a=math.sqrt(5))
        nn.init.zeros_(self.lora_left_weight)

    def fuse_lora_weight(self):
        # Fold the low-rank update into the original weight so inference can
        # run through a single linear layer.
        if not self.fuse_lora:
            self.weight.data += self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = True

    def unfuse_lora_weight(self):
        # Remove the previously fused low-rank update from the original weight.
        if self.fuse_lora:
            self.weight.data -= self.lora_scaling * torch.matmul(
                self.lora_left_weight.t(), self.lora_right_weight.t())
        self.fuse_lora = False

    def forward(self, input):
        if self.fuse_lora:
            # The LoRA update is already folded into self.weight.
            return F.linear(input, self.weight, self.bias)
        else:
            # y = W x + b + scaling * (dropout(x) @ A @ B)
            return F.linear(
                input, self.weight,
                self.bias) + (self.lora_dropout(input) @ self.lora_right_weight
                              @ self.lora_left_weight) * self.lora_scaling

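# Usage sketch (illustrative, not part of the original module): wrapping a
# single nn.Linear with LinearLayer_LoRA by hand. The layer size and LoRA
# hyperparameters below are placeholder values.
#
#   linear = nn.Linear(512, 512)
#   lora_layer = LinearLayer_LoRA(linear.weight,
#                                 lora_dim=8,
#                                 lora_scaling=1,
#                                 lora_droppout=0.1,
#                                 bias=linear.bias)
#   x = torch.randn(2, 512)
#   y = lora_layer(x)  # same shape as linear(x); the LoRA update starts at zero
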
def recursive_getattr(model, module_name):
    """
    From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
    Recursively get the attribute of a module.
    Args:
        model (`torch.nn.Module`)
            The model to get the attribute from.
        module_name (`str`)
            The name of the module to get the attribute from.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list:
        output = getattr(output, name)
    return output

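# Example (illustrative): fetch a nested submodule by its dotted name. The
# module path below is hypothetical.
#
#   proj = recursive_getattr(model, "transformer.h.0.attn.c_attn")
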
def recursive_setattr(model, module_name, module):
    """
    From https://github.com/microsoft/DeepSpeed/blob/master/deepspeed/compression/helper.py
    Recursively set the attribute of a module.
    Args:
        model (`torch.nn.Module`)
            The model to set the attribute in.
        module_name (`str`)
            The name of the module to set the attribute in.
        module (`torch.nn.Module`)
            The module to set the attribute to.
    """
    split_list = module_name.split('.')
    output = model
    for name in split_list[:-1]:
        output = getattr(output, name)
    setattr(output, split_list[-1], module)

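# Example (illustrative): swap a nested submodule in place. The module path
# and layer sizes are hypothetical.
#
#   recursive_setattr(model, "transformer.h.0.attn.c_attn", nn.Linear(768, 768))
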
# Convert the matching nn.Linear layers of a model to LoRA layers.
def convert_linear_layer_to_lora(model,
                                 part_module_name,
                                 lora_dim=0,
                                 lora_scaling=1,
                                 lora_droppout=0):
    replace_name = []
    for name, module in model.named_modules():
        # Only wrap linear layers whose qualified name contains part_module_name.
        if isinstance(module, nn.Linear) and part_module_name in name:
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        tmp = LinearLayer_LoRA(
            module.weight, lora_dim, lora_scaling, lora_droppout,
            module.bias).to(module.weight.device).to(module.weight.dtype)
        recursive_setattr(model, name, tmp)
    return model

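# Usage sketch (illustrative): injecting LoRA into every linear layer whose
# name contains a given substring. The substring "attn" and the hyperparameters
# are placeholders; pick them to match the module names in your network.
#
#   model = convert_linear_layer_to_lora(model, "attn",
#                                        lora_dim=8,
#                                        lora_scaling=1,
#                                        lora_droppout=0.1)
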
# Fuse every LoRA layer's update back into its original linear weight.
def convert_lora_to_linear_layer(model):
    replace_name = []
    for name, module in model.named_modules():
        if isinstance(module, LinearLayer_LoRA):
            replace_name.append(name)
    for name in replace_name:
        module = recursive_getattr(model, name)
        module.fuse_lora_weight()
    return model

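# Usage sketch (illustrative): after training, fold the LoRA updates back into
# the original weights so inference needs no extra matrix multiplications.
#
#   model = convert_lora_to_linear_layer(model)
#   model.eval()
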
def only_optimize_lora_parameters(model):
    # Turn off the gradient of all the parameters except the LoRA parameters.
    for name, param in model.named_parameters():
        if "lora_right_weight" in name or "lora_left_weight" in name:
            param.requires_grad = True
        else:
            param.requires_grad = False
    return model
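

# Usage sketch (illustrative): freeze everything except the LoRA factors and
# build an optimizer over the remaining trainable parameters. The learning
# rate is a placeholder.
#
#   model = only_optimize_lora_parameters(model)
#   optimizer = torch.optim.AdamW(
#       (p for p in model.parameters() if p.requires_grad), lr=1e-4)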