mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 08:27:51 +01:00
add dummy key
This commit is contained in:
@@ -111,6 +111,7 @@ def load_weights(
|
|||||||
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
||||||
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
||||||
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
||||||
|
unet_state_dict.pop("animatediff_config", "")
|
||||||
|
|
||||||
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
||||||
assert len(unexpected) == 0
|
assert len(unexpected) == 0
|
||||||
@@ -154,6 +155,7 @@ def load_weights(
|
|||||||
print(f"load domain lora from {adapter_lora_path}")
|
print(f"load domain lora from {adapter_lora_path}")
|
||||||
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
|
||||||
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
|
||||||
|
domain_lora_state_dict.pop("animatediff_config", "")
|
||||||
|
|
||||||
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
|
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
|
||||||
|
|
||||||
@@ -163,6 +165,7 @@ def load_weights(
|
|||||||
print(f"load motion LoRA from {path}")
|
print(f"load motion LoRA from {path}")
|
||||||
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
||||||
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
||||||
|
motion_lora_state_dict.pop("animatediff_config", "")
|
||||||
|
|
||||||
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
|
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)
|
||||||
|
|
||||||
|
|||||||
@@ -69,6 +69,7 @@ def main(args):
|
|||||||
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
|
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
|
||||||
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
|
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
|
||||||
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
|
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
|
||||||
|
controlnet_state_dict.pop("animatediff_config", "")
|
||||||
controlnet.load_state_dict(controlnet_state_dict)
|
controlnet.load_state_dict(controlnet_state_dict)
|
||||||
controlnet.cuda()
|
controlnet.cuda()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user