2023-07-09 21:32:22 +08:00
|
|
|
import argparse
|
|
|
|
|
import datetime
|
|
|
|
|
import inspect
|
|
|
|
|
import os
|
|
|
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
|
import diffusers
|
2023-11-10 12:37:52 +08:00
|
|
|
from diffusers import AutoencoderKL, EulerDiscreteScheduler
|
2023-07-09 21:32:22 +08:00
|
|
|
|
|
|
|
|
from tqdm.auto import tqdm
|
2023-11-10 12:37:52 +08:00
|
|
|
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
|
2023-07-09 21:32:22 +08:00
|
|
|
|
|
|
|
|
from animatediff.models.unet import UNet3DConditionModel
|
|
|
|
|
from animatediff.pipelines.pipeline_animation import AnimationPipeline
|
2023-11-10 12:37:52 +08:00
|
|
|
from animatediff.utils.util import save_videos_grid, load_weights
|
|
|
|
|
|
2023-07-12 16:41:08 +08:00
|
|
|
from diffusers.utils.import_utils import is_xformers_available
|
2023-07-09 21:32:22 +08:00
|
|
|
|
|
|
|
|
from einops import rearrange, repeat
|
|
|
|
|
|
|
|
|
|
import csv, pdb, glob
|
2023-11-10 12:37:52 +08:00
|
|
|
from safetensors import safe_open
|
2023-07-09 21:32:22 +08:00
|
|
|
import math
|
|
|
|
|
from pathlib import Path
|
2023-11-10 12:37:52 +08:00
|
|
|
import torchvision
|
|
|
|
|
import torchvision.transforms as transforms
|
|
|
|
|
|
|
|
|
|
from PIL import Image
|
|
|
|
|
import numpy as np
|
2023-07-09 21:32:22 +08:00
|
|
|
|
|
|
|
|
|
2023-11-10 12:37:52 +08:00
|
|
|
@torch.no_grad()
|
2023-07-09 21:32:22 +08:00
|
|
|
def main(args):
|
2023-11-10 12:37:52 +08:00
|
|
|
*_, func_args = inspect.getargvalues(inspect.currentframe())
|
|
|
|
|
func_args = dict(func_args)
|
|
|
|
|
|
|
|
|
|
time_str = datetime.datetime.now().strftime("%Y-%m-%d")
|
|
|
|
|
|
|
|
|
|
savedir = f"sample/{Path(args.exp_config).stem}_{args.H}_{args.W}-{time_str}"
|
|
|
|
|
os.makedirs(savedir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
# Load Config
|
|
|
|
|
exp_config = OmegaConf.load(args.exp_config)
|
|
|
|
|
config = OmegaConf.load(args.base_config)
|
|
|
|
|
config = OmegaConf.merge(config, exp_config)
|
|
|
|
|
|
|
|
|
|
if config.get('base_model_path', '') != '':
|
|
|
|
|
args.pretrained_model_path = config.base_model_path
|
|
|
|
|
|
|
|
|
|
# Load Component
|
|
|
|
|
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
|
|
|
|
|
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
|
|
|
|
|
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
|
|
|
|
|
tokenizer_two = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer_2")
|
|
|
|
|
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_path, subfolder="text_encoder_2")
|
|
|
|
|
|
|
|
|
|
# init unet model
|
|
|
|
|
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))
|
|
|
|
|
|
|
|
|
|
# Enable memory efficient attention
|
|
|
|
|
if is_xformers_available() and args.xformers:
|
|
|
|
|
unet.enable_xformers_memory_efficient_attention()
|
|
|
|
|
|
|
|
|
|
scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1, **config.noise_scheduler_kwargs)
|
|
|
|
|
|
|
|
|
|
pipeline = AnimationPipeline(
|
|
|
|
|
unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=scheduler,
|
|
|
|
|
text_encoder_2 = text_encoder_two, tokenizer_2=tokenizer_two
|
|
|
|
|
).to("cuda")
|
|
|
|
|
|
|
|
|
|
# Load model weights
|
|
|
|
|
pipeline = load_weights(
|
|
|
|
|
pipeline = pipeline,
|
|
|
|
|
motion_module_path = config.get("motion_module_path", ""),
|
|
|
|
|
ckpt_path = config.get("ckpt_path", ""),
|
|
|
|
|
lora_path = config.get("lora_path", ""),
|
|
|
|
|
lora_alpha = config.get("lora_alpha", 0.8)
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
pipeline.unet = pipeline.unet.half()
|
|
|
|
|
pipeline.text_encoder = pipeline.text_encoder.half()
|
|
|
|
|
pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
|
|
|
|
|
pipeline.enable_model_cpu_offload()
|
|
|
|
|
pipeline.enable_vae_slicing()
|
|
|
|
|
|
|
|
|
|
prompts = config.prompt
|
|
|
|
|
n_prompts = config.n_prompt
|
|
|
|
|
|
|
|
|
|
random_seeds = config.get("seed", [-1])
|
|
|
|
|
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
|
|
|
|
|
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
|
|
|
|
|
seeds = []
|
|
|
|
|
samples = []
|
|
|
|
|
|
|
|
|
|
with torch.inference_mode():
|
|
|
|
|
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
|
|
|
|
|
# manually set random seed for reproduction
|
|
|
|
|
if random_seed != -1: torch.manual_seed(random_seed)
|
|
|
|
|
else: torch.seed()
|
|
|
|
|
seeds.append(torch.initial_seed())
|
|
|
|
|
print(f"current seed: {torch.initial_seed()}")
|
|
|
|
|
print(f"sampling {prompt} ...")
|
|
|
|
|
sample = pipeline(
|
|
|
|
|
prompt,
|
|
|
|
|
negative_prompt = n_prompt,
|
|
|
|
|
num_inference_steps = config.get('steps', 100),
|
|
|
|
|
guidance_scale = config.get('guidance_scale', 10),
|
|
|
|
|
width = args.W,
|
|
|
|
|
height = args.H,
|
|
|
|
|
single_model_length = args.L,
|
|
|
|
|
).videos
|
|
|
|
|
samples.append(sample)
|
|
|
|
|
prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
|
|
|
|
|
prompt = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
|
|
|
|
|
|
|
|
|
|
# save video
|
|
|
|
|
save_videos_grid(sample, f"{savedir}/sample/{prompt}.mp4")
|
|
|
|
|
print(f"save to {savedir}/sample/{prompt}.mp4")
|
|
|
|
|
|
|
|
|
|
samples = torch.concat(samples)
|
|
|
|
|
save_videos_grid(samples, f"{savedir}/sample-{datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}.mp4", n_rows=4)
|
|
|
|
|
config.seed = seeds
|
|
|
|
|
OmegaConf.save(config, f"{savedir}/config.yaml")
|
2023-07-09 21:32:22 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
2023-11-10 12:37:52 +08:00
|
|
|
parser = argparse.ArgumentParser()
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-xl-base-1.0",)
|
|
|
|
|
parser.add_argument("--base_config", type=str, default="configs/inference/inference.yaml")
|
|
|
|
|
parser.add_argument("--exp_config", type=str, required=True)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--L", type=int, default=16 )
|
|
|
|
|
parser.add_argument("--W", type=int, default=1024)
|
|
|
|
|
parser.add_argument("--H", type=int, default=1024)
|
|
|
|
|
|
|
|
|
|
parser.add_argument("--xformers", action="store_true")
|
|
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
|
|
|
main(args)
|