add dummy key

This commit is contained in:
Yuwei Guo
2023-12-16 12:50:16 +08:00
parent 57e7d14ede
commit 0e9ad276e7
2 changed files with 4 additions and 0 deletions

View File

@@ -111,6 +111,7 @@ def load_weights(
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
unet_state_dict.pop("animatediff_config", "")
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
assert len(unexpected) == 0
@@ -154,6 +155,7 @@ def load_weights(
print(f"load domain lora from {adapter_lora_path}")
domain_lora_state_dict = torch.load(adapter_lora_path, map_location="cpu")
domain_lora_state_dict = domain_lora_state_dict["state_dict"] if "state_dict" in domain_lora_state_dict else domain_lora_state_dict
domain_lora_state_dict.pop("animatediff_config", "")
animation_pipeline = load_diffusers_lora(animation_pipeline, domain_lora_state_dict, alpha=adapter_lora_scale)
@@ -163,6 +165,7 @@ def load_weights(
print(f"load motion LoRA from {path}")
motion_lora_state_dict = torch.load(path, map_location="cpu")
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
motion_lora_state_dict.pop("animatediff_config", "")
animation_pipeline = load_diffusers_lora(animation_pipeline, motion_lora_state_dict, alpha)

View File

@@ -69,6 +69,7 @@ def main(args):
print(f"loading controlnet checkpoint from {model_config.controlnet_path} ...")
controlnet_state_dict = torch.load(model_config.controlnet_path, map_location="cpu")
controlnet_state_dict = controlnet_state_dict["controlnet"] if "controlnet" in controlnet_state_dict else controlnet_state_dict
controlnet_state_dict.pop("animatediff_config", "")
controlnet.load_state_dict(controlnet_state_dict)
controlnet.cuda()