mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2026-04-03 09:46:36 +02:00
support sparsectrl
This commit is contained in:
@@ -11,7 +11,7 @@ from safetensors import safe_open
|
||||
from tqdm import tqdm
|
||||
from einops import rearrange
|
||||
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, load_diffusers_lora
|
||||
|
||||
|
||||
def zero_rank_print(s):
|
||||
@@ -96,12 +96,15 @@ def load_weights(
|
||||
# motion module
|
||||
motion_module_path = "",
|
||||
motion_module_lora_configs = [],
|
||||
# domain adapter
|
||||
adapter_lora_path = "",
|
||||
adapter_lora_scale = 1.0,
|
||||
# image layers
|
||||
dreambooth_model_path = "",
|
||||
lora_model_path = "",
|
||||
lora_alpha = 0.8,
|
||||
dreambooth_model_path = "",
|
||||
lora_model_path = "",
|
||||
lora_alpha = 0.8,
|
||||
):
|
||||
# 1.1 motion module
|
||||
# motion module
|
||||
unet_state_dict = {}
|
||||
if motion_module_path != "":
|
||||
print(f"load motion module from {motion_module_path}")
|
||||
@@ -113,6 +116,7 @@ def load_weights(
|
||||
assert len(unexpected) == 0
|
||||
del unet_state_dict
|
||||
|
||||
# base model
|
||||
if dreambooth_model_path != "":
|
||||
print(f"load dreambooth model from {dreambooth_model_path}")
|
||||
if dreambooth_model_path.endswith(".safetensors"):
|
||||
@@ -133,6 +137,7 @@ def load_weights(
|
||||
animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
|
||||
del dreambooth_state_dict
|
||||
|
||||
# lora layers
|
||||
if lora_model_path != "":
|
||||
print(f"load lora model from {lora_model_path}")
|
||||
assert lora_model_path.endswith(".safetensors")
|
||||
@@ -144,14 +149,21 @@ def load_weights(
|
||||
animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
|
||||
del lora_state_dict
|
||||
|
||||
# domain adapter lora
|
||||
if adapter_lora_path != "":
|
||||
print(f"load domain lora from {adapter_lora_path}")
|
||||
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
||||
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
||||
|
||||
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
|
||||
|
||||
# motion module lora
|
||||
for motion_module_lora_config in motion_module_lora_configs:
|
||||
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
||||
print(f"load motion LoRA from {path}")
|
||||
|
||||
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
||||
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
||||
|
||||
animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
|
||||
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
|
||||
|
||||
return animation_pipeline
|
||||
|
||||
Reference in New Issue
Block a user