update infer script

This commit is contained in:
Yuwei Guo
2023-09-25 11:38:29 +08:00
parent 4498f762c5
commit d09e5cfa53

View File

@@ -15,14 +15,12 @@ from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
from animatediff.utils.util import load_weights
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
import csv, pdb, glob
from safetensors import safe_open
import math
from pathlib import Path
@@ -60,50 +58,16 @@ def main(args):
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
).to("cuda")
# 1. unet ckpt
# 1.1 motion module
motion_module_state_dict = torch.load(motion_module, map_location="cpu")
if "global_step" in motion_module_state_dict: func_args.update({"global_step": motion_module_state_dict["global_step"]})
missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
assert len(unexpected) == 0
# 1.2 T2I
if model_config.path != "":
if model_config.path.endswith(".ckpt"):
state_dict = torch.load(model_config.path)
pipeline.unet.load_state_dict(state_dict)
elif model_config.path.endswith(".safetensors"):
state_dict = {}
with safe_open(model_config.path, framework="pt", device="cpu") as f:
for key in f.keys():
state_dict[key] = f.get_tensor(key)
is_lora = all("lora" in k for k in state_dict.keys())
if not is_lora:
base_state_dict = state_dict
else:
base_state_dict = {}
with safe_open(model_config.base, framework="pt", device="cpu") as f:
for key in f.keys():
base_state_dict[key] = f.get_tensor(key)
# vae
converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
pipeline.vae.load_state_dict(converted_vae_checkpoint)
# unet
converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
# text_model
pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)
# import pdb
# pdb.set_trace()
if is_lora:
pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)
pipeline.to("cuda")
### <<< create validation pipeline <<< ###
pipeline = load_weights(
pipeline,
# motion module
motion_module_path = motion_module,
motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
# image layers
dreambooth_model_path = model_config.get("dreambooth_path", ""),
lora_model_path = model_config.get("lora_model_path", ""),
lora_alpha = model_config.get("lora_alpha", 0.8),
).to("cuda")
prompts = model_config.prompt
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt