diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py index 924bbb7..c094483 100644 --- a/animatediff/utils/util.py +++ b/animatediff/utils/util.py @@ -111,6 +111,7 @@ def load_weights( motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) + unet_state_dict.pop("animatediff_config", "") missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) assert len(unexpected) == 0 @@ -154,6 +155,7 @@ def load_weights( print(f"load domain lora from {adapter_lora_path}") domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu") domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict + domain_lora_state_dict.pop("animatediff_config", "") animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale) @@ -163,6 +165,7 @@ def load_weights( print(f"load motion LoRA from {path}") motion_lora_state_dict = torch.load(path, map_location="cpu") motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict + motion_lora_state_dict.pop("animatediff_config", "") animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha) diff --git a/scripts/animate.py b/scripts/animate.py index 3a84dc1..22f4c2a 100644 --- a/scripts/animate.py +++ b/scripts/animate.py @@ -69,6 +69,7 @@ def main(args): print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...") controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu") controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict + controlnet_state_dict.pop("animatediff_config", "") controlnet.load_state_dict(controlnet_state_dict) controlnet.cuda()