mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 16:38:01 +01:00
update infer script
This commit is contained in:
@@ -23,6 +23,32 @@ from safetensors.torch import load_file
|
|||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
import pdb
|
import pdb
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
|
||||||
|
# directly update weight in diffusers model
|
||||||
|
for key in state_dict:
|
||||||
|
# only process lora down key
|
||||||
|
if "up." in key: continue
|
||||||
|
|
||||||
|
up_key = key.replace(".down.", ".up.")
|
||||||
|
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
|
||||||
|
model_key = model_key.replace("to_out.", "to_out.0.")
|
||||||
|
layer_infos = model_key.split(".")[:-1]
|
||||||
|
|
||||||
|
curr_layer = pipeline.unet
|
||||||
|
while len(layer_infos) > 0:
|
||||||
|
temp_name = layer_infos.pop(0)
|
||||||
|
curr_layer = curr_layer.__getattr__(temp_name)
|
||||||
|
|
||||||
|
weight_down = state_dict[key]
|
||||||
|
weight_up = state_dict[up_key]
|
||||||
|
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
||||||
|
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
||||||
# load base model
|
# load base model
|
||||||
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
||||||
|
|||||||
@@ -7,8 +7,11 @@ import torch
|
|||||||
import torchvision
|
import torchvision
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
|
from safetensors import safe_open
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
|
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
||||||
|
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
|
||||||
|
|
||||||
|
|
||||||
def zero_rank_print(s):
|
def zero_rank_print(s):
|
||||||
@@ -87,3 +90,68 @@ def ddim_loop(pipeline, ddim_scheduler, latent, num_inv_steps, prompt):
|
|||||||
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt=""):
|
||||||
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
|
||||||
return ddim_latents
|
return ddim_latents
|
||||||
|
|
||||||
|
def load_weights(
|
||||||
|
animation_pipeline,
|
||||||
|
# motion module
|
||||||
|
motion_module_path = "",
|
||||||
|
motion_module_lora_configs = [],
|
||||||
|
# image layers
|
||||||
|
dreambooth_model_path = "",
|
||||||
|
lora_model_path = "",
|
||||||
|
lora_alpha = 0.8,
|
||||||
|
):
|
||||||
|
# 1.1 motion module
|
||||||
|
unet_state_dict = {}
|
||||||
|
if motion_module_path != "":
|
||||||
|
print(f"load motion module from {motion_module_path}")
|
||||||
|
motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
|
||||||
|
motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
|
||||||
|
unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name})
|
||||||
|
|
||||||
|
missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False)
|
||||||
|
assert len(unexpected) == 0
|
||||||
|
del unet_state_dict
|
||||||
|
|
||||||
|
if dreambooth_model_path != "":
|
||||||
|
print(f"load dreambooth model from {dreambooth_model_path}")
|
||||||
|
if dreambooth_model_path.endswith(".safetensors"):
|
||||||
|
dreambooth_state_dict = {}
|
||||||
|
with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
dreambooth_state_dict[key] = f.get_tensor(key)
|
||||||
|
elif dreambooth_model_path.endswith(".ckpt"):
|
||||||
|
dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu")
|
||||||
|
|
||||||
|
# 1. vae
|
||||||
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config)
|
||||||
|
animation_pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
||||||
|
# 2. unet
|
||||||
|
converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config)
|
||||||
|
animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
||||||
|
# 3. text_model
|
||||||
|
animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict)
|
||||||
|
del dreambooth_state_dict
|
||||||
|
|
||||||
|
if lora_model_path != "":
|
||||||
|
print(f"load lora model from {lora_model_path}")
|
||||||
|
assert lora_model_path.endswith(".safetensors")
|
||||||
|
lora_state_dict = {}
|
||||||
|
with safe_open(lora_model_path, framework="pt", device="cpu") as f:
|
||||||
|
for key in f.keys():
|
||||||
|
lora_state_dict[key] = f.get_tensor(key)
|
||||||
|
|
||||||
|
animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha)
|
||||||
|
del lora_state_dict
|
||||||
|
|
||||||
|
|
||||||
|
for motion_module_lora_config in motion_module_lora_configs:
|
||||||
|
path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"]
|
||||||
|
print(f"load motion LoRA from {path}")
|
||||||
|
|
||||||
|
motion_lora_state_dict = torch.load(path, map_location="cpu")
|
||||||
|
motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict
|
||||||
|
|
||||||
|
animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha)
|
||||||
|
|
||||||
|
return animation_pipeline
|
||||||
|
|||||||
Reference in New Issue
Block a user