mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 16:38:01 +01:00
add dummy key
This commit is contained in:
@@ -111,6 +111,7 @@ def load_weights(
|
||||
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
||||
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
||||
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
||||
unet_state_dict.pop("animatediff_config", "")
|
||||
|
||||
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
||||
assert len(unexpected) == 0
|
||||
@@ -154,6 +155,7 @@ def load_weights(
|
||||
print(f"load domain lora from {adapter_lora_path}")
|
||||
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
||||
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
||||
domain_lora_state_dict.pop("animatediff_config", "")
|
||||
|
||||
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
|
||||
|
||||
@@ -163,6 +165,7 @@ def load_weights(
|
||||
print(f"load motion LoRA from {path}")
|
||||
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
||||
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
||||
motion_lora_state_dict.pop("animatediff_config", "")
|
||||
|
||||
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
|
||||
|
||||
|
||||
@@ -69,6 +69,7 @@ def main(args):
|
||||
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
|
||||
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
|
||||
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
|
||||
controlnet_state_dict.pop("animatediff_config", "")
|
||||
controlnet.load_state_dict(controlnet_state_dict)
|
||||
controlnet.cuda()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user