diff --git a/animatediff/utils/convert_lora_safetensor_to_diffusers.py b/animatediff/utils/convert_lora_safetensor_to_diffusers.py
index 0a7a429..7490e38 100644
--- a/animatediff/utils/convert_lora_safetensor_to_diffusers.py
+++ b/animatediff/utils/convert_lora_safetensor_to_diffusers.py
@@ -23,6 +23,32 @@
 from safetensors.torch import load_file
 from diffusers import StableDiffusionPipeline
 import pdb
+
+
+def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
+    # directly update weight in diffusers model
+    for key in state_dict:
+        # only process lora down key
+        if "up." in key: continue
+
+        up_key = key.replace(".down.", ".up.")
+        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
+        model_key = model_key.replace("to_out.", "to_out.0.")
+        layer_infos = model_key.split(".")[:-1]
+
+        curr_layer = pipeline.unet
+        while len(layer_infos) > 0:
+            temp_name = layer_infos.pop(0)
+            curr_layer = curr_layer.__getattr__(temp_name)
+
+        weight_down = state_dict[key]
+        weight_up = state_dict[up_key]
+        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
+
+    return pipeline
+
+
+
 def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
     # load base model
     # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py
index ee2dd2b..5393385 100644
--- a/animatediff/utils/util.py
+++ b/animatediff/utils/util.py
@@ -7,8 +7,11 @@ import torch
 import torchvision
 import torch.distributed as dist
 
+from safetensors import safe_open
 from tqdm import tqdm
 from einops import rearrange
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
 
 
 def zero_rank_print(s):
@@ -87,3 +90,68 @@ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
 def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
     ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
     return ddim_latents
+
+def load_weights(
+    animation_pipeline,
+    # motion module
+    motion_module_path = "",
+    motion_module_lora_configs = [],
+    # image layers
+    dreambooth_model_path = "",
+    lora_model_path = "",
+    lora_alpha = 0.8,
+):
+    # 1.1 motion module
+    unet_state_dict = {}
+    if motion_module_path != "":
+        print(f"load motion module from {motion_module_path}")
+        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
+        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
+        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
+
+    missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
+    assert len(unexpected) == 0
+    del unet_state_dict
+
+    if dreambooth_model_path != "":
+        print(f"load dreambooth model from {dreambooth_model_path}")
+        if dreambooth_model_path.endswith(".safetensors"):
+            dreambooth_state_dict = {}
+            with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    dreambooth_state_dict[key] = f.get_tensor(key)
+        elif dreambooth_model_path.endswith(".ckpt"):
+            dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
+
+        # 1. vae
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
+        animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
+        # 2. unet
+        converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
+        animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+        # 3. text_model
+        animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
+        del dreambooth_state_dict
+
+    if lora_model_path != "":
+        print(f"load lora model from {lora_model_path}")
+        assert lora_model_path.endswith(".safetensors")
+        lora_state_dict = {}
+        with safe_open(lora_model_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                lora_state_dict[key] = f.get_tensor(key)
+
+        animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
+        del lora_state_dict
+
+
+    for motion_module_lora_config in motion_module_lora_configs:
+        path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
+        print(f"load motion LoRA from {path}")
+
+        motion_lora_state_dict = torch.load(path, map_location="cpu")
+        motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
+
+        animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
+
+    return animation_pipeline
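
Usage sketch (not part of the patch): a minimal example of how the new load_weights helper might be called. It assumes pipeline is an already-assembled AnimateDiff AnimationPipeline, and every checkpoint path and alpha value below is a hypothetical placeholder.

    from animatediff.utils.util import load_weights

    # `pipeline` is assumed to be a fully constructed AnimateDiff AnimationPipeline
    # (vae, text_encoder, tokenizer, unet, scheduler). All paths are placeholders.
    pipeline = load_weights(
        pipeline,
        # motion module checkpoint
        motion_module_path="models/Motion_Module/mm_sd_v15.ckpt",
        # optional motion LoRAs, each merged into the motion modules with its own alpha
        motion_module_lora_configs=[
            {"path": "models/MotionLoRA/zoom_in.ckpt", "alpha": 1.0},
        ],
        # optional personalized image backbone (DreamBooth) and image LoRA
        dreambooth_model_path="models/DreamBooth_LoRA/toonyou_beta3.safetensors",
        lora_model_path="",
        lora_alpha=0.8,
    )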