mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2025-12-16 08:27:51 +01:00
update infer script
This commit is contained in:
@@ -15,14 +15,12 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from animatediff.models.unet import UNet3DConditionModel
|
||||
from animatediff.pipelines.pipeline_animation import AnimationPipeline
|
||||
from animatediff.utils.util import save_videos_grid
|
||||
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
|
||||
from animatediff.utils.util import load_weights
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import csv, pdb, glob
|
||||
from safetensors import safe_open
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
@@ -60,50 +58,16 @@ def main(args):
|
||||
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
|
||||
).to("cuda")
|
||||
|
||||
# 1. unet ckpt
|
||||
# 1.1 motion module
|
||||
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
|
||||
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
|
||||
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
|
||||
assert len(unexpected) == 0
|
||||
|
||||
# 1.2 T2I
|
||||
if model_config.path != "":
|
||||
if model_config.path.endswith(".ckpt"):
|
||||
state_dict = torch.load(model_config.path)
|
||||
pipeline.unet.load_state_dict(state_dict)
|
||||
|
||||
elif model_config.path.endswith(".safetensors"):
|
||||
state_dict = {}
|
||||
with safe_open(model_config.path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
state_dict[key] = f.get_tensor(key)
|
||||
|
||||
is_lora = all("lora" in k for k in state_dict.keys())
|
||||
if not is_lora:
|
||||
base_state_dict = state_dict
|
||||
else:
|
||||
base_state_dict = {}
|
||||
with safe_open(model_config.base, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
base_state_dict[key] = f.get_tensor(key)
|
||||
|
||||
# vae
|
||||
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
|
||||
pipeline.vae.load_state_dict(converted_vae_checkpoint)
|
||||
# unet
|
||||
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
|
||||
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
|
||||
# text_model
|
||||
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
|
||||
|
||||
# import pdb
|
||||
# pdb.set_trace()
|
||||
if is_lora:
|
||||
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
|
||||
|
||||
pipeline.to("cuda")
|
||||
### <<< create validation pipeline <<< ###
|
||||
pipeline = load_weights(
|
||||
pipeline,
|
||||
# motion module
|
||||
motion_module_path = motion_module,
|
||||
motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
|
||||
# image layers
|
||||
dreambooth_model_path = model_config.get("dreambooth_path", ""),
|
||||
lora_model_path = model_config.get("lora_model_path", ""),
|
||||
lora_alpha = model_config.get("lora_alpha", 0.8),
|
||||
).to("cuda")
|
||||
|
||||
prompts = model_config.prompt
|
||||
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
|
||||
|
||||
Reference in New Issue
Block a user