mirror of
https://github.com/guoyww/AnimateDiff.git
synced 2025-12-19 09:49:34 +01:00
add code
This commit is contained in:
146
scripts/animate.py
Normal file
146
scripts/animate.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import argparse
|
||||
import datetime
|
||||
import inspect
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
|
||||
import torch
|
||||
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDIMScheduler
|
||||
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from animatediff.models.unet import UNet3DConditionModel
|
||||
from animatediff.pipelines.pipeline_animation import AnimationPipeline
|
||||
from animatediff.utils.util import save_videos_grid
|
||||
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
|
||||
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
|
||||
|
||||
from einops import rearrange, repeat
|
||||
|
||||
import csv, pdb, glob
|
||||
from safetensors import safe_open
|
||||
import math
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def _read_safetensors(path):
    """Read every tensor stored in the safetensors file at ``path`` onto the CPU.

    Returns a plain dict mapping each key found in the file to its tensor.
    """
    with safe_open(path, framework="pt", device="cpu") as f:
        return {key: f.get_tensor(key) for key in f.keys()}


def main(args):
    """Sample animations for every model entry of a prompt config and save them.

    Reads these attributes from ``args``:
        config: path of the YAML file whose top-level entries each describe
            one model to sample from.
        inference_config: path of the YAML file providing
            ``unet_additional_kwargs`` and ``noise_scheduler_kwargs``.
        pretrained_model_path: directory holding the ``tokenizer``,
            ``text_encoder``, ``vae`` and ``unet`` subfolders.
        W, H, L: width, height and video length passed to the pipeline.

    Each config entry is read for ``motion_module``, ``path``, ``prompt``,
    ``n_prompt``, ``steps``, ``guidance_scale``, an optional ``seed``, and,
    when ``path`` points at a LoRA file, ``base`` and ``lora_alpha``.

    Outputs go to ``samples/<config stem>-<timestamp>/``: one GIF per prompt
    under ``sample/``, a combined ``sample.gif``, and ``config.yaml`` with the
    seeds that were actually used recorded under ``random_seed``.

    The pipeline is moved to ``"cuda"``, so a CUDA device is required.
    """
    # Snapshot of the call arguments. Within this function it is only extended
    # with "global_step" below and is not read anywhere else.
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)
    inference_config = OmegaConf.load(args.inference_config)

    config = OmegaConf.load(args.config)
    samples = []
    # Iterate over a snapshot of the entries: the loop body pops "seed" from
    # and assigns "random_seed" to the entries of `config`.
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):

        ### >>> create validation pipeline >>> ###
        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
        unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))

        pipeline = AnimationPipeline(
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
            scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
        ).to("cuda")

        # 1. unet ckpt
        # 1.1 motion module
        # NOTE(review): torch.load unpickles its input; only load checkpoints
        # from a trusted source.
        motion_module_state_dict = torch.load(model_config.motion_module, map_location="cpu")
        if "global_step" in motion_module_state_dict:
            func_args.update({"global_step": motion_module_state_dict["global_step"]})
        # strict=False: the checkpoint may cover only part of the UNet, but it
        # must not contain keys the UNet does not know about.
        _missing, unexpected = pipeline.unet.load_state_dict(motion_module_state_dict, strict=False)
        assert len(unexpected) == 0, f"unexpected keys in motion module checkpoint: {unexpected}"

        # 1.2 T2I
        if model_config.path != "":
            if model_config.path.endswith(".ckpt"):
                # Load onto the CPU, like the motion module above, so the
                # checkpoint does not depend on the device it was saved from.
                state_dict = torch.load(model_config.path, map_location="cpu")
                pipeline.unet.load_state_dict(state_dict)

            elif model_config.path.endswith(".safetensors"):
                state_dict = _read_safetensors(model_config.path)

                # The file is treated as a LoRA when every key mentions "lora";
                # in that case the base weights come from `model_config.base`.
                is_lora = all("lora" in k for k in state_dict.keys())
                if not is_lora:
                    base_state_dict = state_dict
                else:
                    base_state_dict = _read_safetensors(model_config.base)

                # vae
                converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_state_dict, pipeline.vae.config)
                pipeline.vae.load_state_dict(converted_vae_checkpoint)
                # unet
                converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_state_dict, pipeline.unet.config)
                pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False)
                # text_model
                pipeline.text_encoder = convert_ldm_clip_checkpoint(base_state_dict)

                if is_lora:
                    pipeline = convert_lora(pipeline, state_dict, alpha=model_config.lora_alpha)

        # Components may have been replaced above, so move everything to the GPU again.
        pipeline.to("cuda")
        ### <<< create validation pipeline <<< ###

        prompts = model_config.prompt
        # A single negative prompt is shared by all prompts.
        n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

        # "seed" may be absent, a single int, or a list; a single value is
        # shared by all prompts. It is removed from the config entry here and
        # replaced by the "random_seed" list filled in below.
        random_seeds = model_config.pop("seed", [-1])
        random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
        random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

        config[config_key].random_seed = []
        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

            # manually set random seed for reproduction; -1 asks for a fresh seed
            if random_seed != -1:
                torch.manual_seed(random_seed)
            else:
                torch.seed()
            config[config_key].random_seed.append(torch.initial_seed())

            print(f"current seed: {torch.initial_seed()}")
            print(f"sampling {prompt} ...")
            sample = pipeline(
                prompt,
                negative_prompt=n_prompt,
                num_inference_steps=model_config.steps,
                guidance_scale=model_config.guidance_scale,
                width=args.W,
                height=args.H,
                video_length=args.L,
            ).videos
            samples.append(sample)

            # File-name stem: drop "/" and keep the first ten space-separated words.
            prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
            sample_path = f"{savedir}/sample/{model_idx}-{prompt_idx}-{prompt}.gif"
            save_videos_grid(sample, sample_path)
            # Report the path that was actually written.
            print(f"save to {sample_path}")

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line entry point: collect the paths and the output geometry,
    # then hand the parsed namespace to main().
    cli = argparse.ArgumentParser()

    # Paths: base model directory, inference settings, and the (mandatory)
    # prompt config.
    cli.add_argument(
        "--pretrained_model_path",
        type=str,
        default="models/StableDiffusion/stable-diffusion-v1-5",
    )
    cli.add_argument(
        "--inference_config",
        type=str,
        default="configs/inference/inference.yaml",
    )
    cli.add_argument("--config", type=str, required=True)

    # Integer options: video length, width and height.
    for flag, default_value in (("--L", 16), ("--W", 512), ("--H", 512)):
        cli.add_argument(flag, type=int, default=default_value)

    main(cli.parse_args())
|
||||
Reference in New Issue
Block a user