# AnimateDiff/scripts/animate.py
import argparse
import datetime
import inspect
import os
from omegaconf import OmegaConf
import torch
import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange, repeat
import csv, pdb, glob
import math
from pathlib import Path
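
# Command-line inference script: for each model entry in the given config YAML, build an
# AnimationPipeline from a Stable Diffusion checkpoint plus an AnimateDiff motion module,
# sample every prompt, and save the results as GIFs under samples/.
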
def main(args):
    *_, func_args = inspect.getargvalues(inspect.currentframe())
    func_args = dict(func_args)

    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    savedir = f"samples/{Path(args.config).stem}-{time_str}"
    os.makedirs(savedir)

    config = OmegaConf.load(args.config)
    samples = []
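
    # Each top-level entry in the config describes one model: its motion module(s),
    # prompts, negative prompts, seeds, and sampling settings.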
    sample_idx = 0
    for model_idx, (config_key, model_config) in enumerate(list(config.items())):
        motion_modules = model_config.motion_module
        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
        for motion_module in motion_modules:
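            # Build a fresh pipeline for every motion module listed for this model.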
            inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
            ### >>> create validation pipeline >>> ###
            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
            unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
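            # xformers memory-efficient attention keeps VRAM usage manageable for the 3D UNet.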
            if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
            else: assert False, "xformers is not available; install xformers to run this script"
            pipeline = AnimationPipeline(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
            ).to("cuda")
            pipeline = load_weights(
                pipeline,
                # motion module
                motion_module_path = motion_module,
                motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
                # image layers
                dreambooth_model_path = model_config.get("dreambooth_path", ""),
                lora_model_path = model_config.get("lora_model_path", ""),
                lora_alpha = model_config.get("lora_alpha", 0.8),
            ).to("cuda")
            prompts = model_config.prompt
            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
            random_seeds = model_config.get("seed", [-1])
            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
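            # Record the seed actually used for each prompt so the run can be reproduced
            # from the config saved alongside the samples.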
            config[config_key].random_seed = []
            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

                # manually set random seed for reproduction
                if random_seed != -1: torch.manual_seed(random_seed)
                else: torch.seed()
                config[config_key].random_seed.append(torch.initial_seed())

                print(f"current seed: {torch.initial_seed()}")
                print(f"sampling {prompt} ...")
                sample = pipeline(
                    prompt,
                    negative_prompt = n_prompt,
                    num_inference_steps = model_config.steps,
                    guidance_scale = model_config.guidance_scale,
                    width = args.W,
                    height = args.H,
                    video_length = args.L,
                ).videos
                samples.append(sample)

                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
print(f"save to {savedir}/sample/{prompt}.gif")
                sample_idx += 1
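
    # Stack every generated clip into a single grid GIF and save the resolved config
    # (including the seeds actually used) next to the samples.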
    samples = torch.concat(samples)
    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)

    OmegaConf.save(config, f"{savedir}/config.yaml")

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5")
parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v1.yaml")
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--L", type=int, default=16 )
parser.add_argument("--W", type=int, default=512)
parser.add_argument("--H", type=int, default=512)
args = parser.parse_args()
main(args)