mirror of
https://github.com/serp-ai/bark-with-voice-clone.git
synced 2025-12-16 03:38:01 +01:00
Add v1 finetune support
This commit is contained in:
@@ -11,7 +11,7 @@ class LinearLayer_LoRA(nn.Module):
|
||||
weight,
|
||||
lora_dim=0,
|
||||
lora_scaling=1,
|
||||
lora_droppout=0,
|
||||
lora_dropout=0,
|
||||
bias=None):
|
||||
super(LinearLayer_LoRA, self).__init__()
|
||||
self.weight = weight
|
||||
@@ -29,8 +29,8 @@ class LinearLayer_LoRA(nn.Module):
|
||||
self.lora_left_weight = nn.Parameter(torch.zeros(lora_dim, rows))
|
||||
self.lora_scaling = lora_scaling / lora_dim
|
||||
|
||||
if lora_droppout > 0:
|
||||
self.lora_dropout = nn.Dropout(lora_droppout)
|
||||
if lora_dropout > 0:
|
||||
self.lora_dropout = nn.Dropout(lora_dropout)
|
||||
else:
|
||||
self.lora_dropout = nn.Identity()
|
||||
|
||||
@@ -116,15 +116,15 @@ def convert_linear_layer_to_lora(model,
|
||||
part_module_name,
|
||||
lora_dim=0,
|
||||
lora_scaling=1,
|
||||
lora_droppout=0):
|
||||
repalce_name = []
|
||||
lora_dropout=0):
|
||||
replace_name = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear) and part_module_name in name:
|
||||
repalce_name.append(name)
|
||||
for name in repalce_name:
|
||||
replace_name.append(name)
|
||||
for name in replace_name:
|
||||
module = recursive_getattr(model, name)
|
||||
tmp = LinearLayer_LoRA(
|
||||
module.weight, lora_dim, lora_scaling, lora_droppout,
|
||||
module.weight, lora_dim, lora_scaling, lora_dropout,
|
||||
module.bias).to(module.weight.device).to(module.weight.dtype)
|
||||
recursive_setattr(model, name, tmp)
|
||||
return model
|
||||
@@ -132,11 +132,11 @@ def convert_linear_layer_to_lora(model,
|
||||
|
||||
# convert the LoRA layer to linear layer
|
||||
def convert_lora_to_linear_layer(model):
|
||||
repalce_name = []
|
||||
replace_name = []
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LinearLayer_LoRA):
|
||||
repalce_name.append(name)
|
||||
for name in repalce_name:
|
||||
replace_name.append(name)
|
||||
for name in replace_name:
|
||||
module = recursive_getattr(model, name)
|
||||
module.fuse_lora_weight()
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user