# AnimateDiff/scripts/animate.py
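
"""Sampling script for AnimateDiff on an SDXL base model.

Loads a base/experiment OmegaConf pair, builds an AnimationPipeline around an
inflated 3D UNet, and renders one video per prompt into ./sample/.

Example invocation (the config path is illustrative):

    python -m scripts.animate --exp_config configs/inference/my_experiment.yaml --xformers
"""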

import argparse
import datetime
import inspect
import os
from pathlib import Path

import torch
from omegaconf import OmegaConf
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, load_weights

@torch.no_grad()
def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%d")
    savedir = f"sample/{Path(args.exp_config).stem}_{args.H}_{args.W}-{time_str}"
    os.makedirs(savedir, exist_ok=True)

    # Load config: values in the experiment config override the base config
    exp_config = OmegaConf.load(args.exp_config)
    config = OmegaConf.load(args.base_config)
    config = OmegaConf.merge(config, exp_config)
    if config.get("base_model_path", "") != "":
        args.pretrained_model_path = config.base_model_path
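
    # An experiment config consumed below might look like this (the keys are
    # the ones read in this script; the values are illustrative only):
    #
    #   base_model_path: ""          # optional override of --pretrained_model_path
    #   motion_module_path: "models/Motion_Module/motion_module.ckpt"
    #   prompt:
    #     - "a panda surfing on a wave"
    #   n_prompt:
    #     - "low quality, blurry"
    #   seed: [42]
    #   steps: 25
    #   guidance_scale: 8.0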

    # Load components (SDXL-style: two tokenizers and two text encoders)
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
    tokenizer_two = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer_2")
    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_path, subfolder="text_encoder_2")

    # Init UNet: inflate the pretrained 2D UNet into a 3D UNet, adding the
    # extra modules configured in `unet_additional_kwargs`
    unet = UNet3DConditionModel.from_pretrained_2d(
        args.pretrained_model_path,
        subfolder="unet",
        unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs),
    )

    # Enable memory-efficient attention when xformers is installed and requested
    if is_xformers_available() and args.xformers:
        unet.enable_xformers_memory_efficient_attention()

    scheduler = EulerDiscreteScheduler(timestep_spacing="leading", steps_offset=1, **config.noise_scheduler_kwargs)

    pipeline = AnimationPipeline(
        unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=scheduler,
        text_encoder_2=text_encoder_two, tokenizer_2=tokenizer_two,
    ).to("cuda")

    # Load model weights
    pipeline = load_weights(
        pipeline=pipeline,
        motion_module_path=config.get("motion_module_path", ""),
        ckpt_path=config.get("ckpt_path", ""),
        lora_path=config.get("lora_path", ""),
        lora_alpha=config.get("lora_alpha", 0.8),
    )
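
    # Cast to fp16 and reduce peak VRAM: enable_model_cpu_offload() keeps each
    # submodule on the GPU only while it runs, and enable_vae_slicing() decodes
    # the latents in slices rather than all at once; both trade speed for memory.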
    pipeline.unet = pipeline.unet.half()
    pipeline.text_encoder = pipeline.text_encoder.half()
    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
    pipeline.enable_model_cpu_offload()
    pipeline.enable_vae_slicing()

    prompts = config.prompt
    n_prompts = config.n_prompt

    # Normalize seeds: accept a single int or a list; broadcast one seed to all prompts
    random_seeds = config.get("seed", [-1])
    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
    random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

    seeds = []
    samples = []
    with torch.inference_mode():
        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
            # Manually set the random seed for reproducibility; -1 means draw a fresh one
            if random_seed != -1:
                torch.manual_seed(random_seed)
            else:
                torch.seed()
            seeds.append(torch.initial_seed())
            print(f"current seed: {torch.initial_seed()}")

            print(f"sampling {prompt} ...")
            sample = pipeline(
                prompt,
                negative_prompt=n_prompt,
                num_inference_steps=config.get("steps", 100),
                guidance_scale=config.get("guidance_scale", 10),
                width=args.W,
                height=args.H,
                single_model_length=args.L,
            ).videos
            samples.append(sample)
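
            # `sample` holds the decoded frames as a (batch, channel, frame,
            # height, width) tensor, the layout save_videos_grid expects below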

            # Name the output after the prompt plus a timestamp so repeated runs
            # (and duplicate prompts) do not overwrite each other
            prompt_tag = "-".join(prompt.replace("/", "").split(" ")[:10])
            time_tag = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
            save_videos_grid(sample, f"{savedir}/sample/{time_tag}-{prompt_tag}.mp4")
            print(f"save to {savedir}/sample/{time_tag}-{prompt_tag}.mp4")

    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample-{datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}.mp4", n_rows=4)

    config.seed = seeds
    OmegaConf.save(config, f"{savedir}/config.yaml")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-xl-base-1.0")
    parser.add_argument("--base_config", type=str, default="configs/inference/inference.yaml")
    parser.add_argument("--exp_config", type=str, required=True)
    parser.add_argument("--L", type=int, default=16)
    parser.add_argument("--W", type=int, default=1024)
    parser.add_argument("--H", type=int, default=1024)
    parser.add_argument("--xformers", action="store_true")

    args = parser.parse_args()
    main(args)