This commit is contained in:
XDUWQ
2023-08-31 16:59:18 +08:00
parent a9de26f683
commit 7715050047

View File

@@ -114,11 +114,10 @@ class EfficientStableDiffusion(TorchModel):
'rank'] if tuner_config and 'rank' in tuner_config else 4
lora_config = LoRAConfig(
r=rank,
replace_modules=['to_q', 'to_k', 'to_v', 'to_out.0'],
target_modules=['to_q', 'to_k', 'to_v', 'to_out.0'],
merge_weights=False,
only_lora_trainable=False,
use_merged_linear=False,
pretrained_weights=pretrained_tuner)
use_merged_linear=False)
self.unet = Swift.prepare_model(self.unet, lora_config)
elif tuner_name == 'swift-adapter':
adapter_length = tuner_config[
@@ -127,9 +126,7 @@ class EfficientStableDiffusion(TorchModel):
dim=-1,
hidden_pos=0,
target_modules=r'.*ff\.net\.2$',
adapter_length=adapter_length,
only_adapter_trainable=False,
pretrained_weights=pretrained_tuner)
adapter_length=adapter_length)
self.unet = Swift.prepare_model(self.unet, adapter_config)
elif tuner_name == 'swift-prompt':
prompt_length = tuner_config[