update file

lzy
2023-11-10 12:37:52 +08:00
parent d6f459dbd6
commit 397ec30d15
4 changed files with 105 additions and 907 deletions


@@ -7,117 +7,136 @@ from omegaconf import OmegaConf
 import torch
 import diffusers
-from diffusers import AutoencoderKL, DDIMScheduler
+from diffusers import AutoencoderKL, EulerDiscreteScheduler
 from tqdm.auto import tqdm
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
 from animatediff.models.unet import UNet3DConditionModel
 from animatediff.pipelines.pipeline_animation import AnimationPipeline
-from animatediff.utils.util import save_videos_grid
-from animatediff.utils.util import load_weights
+from animatediff.utils.util import save_videos_grid, load_weights
 from diffusers.utils.import_utils import is_xformers_available
 from einops import rearrange, repeat
 import csv, pdb, glob
 from safetensors import safe_open
 import math
 from pathlib import Path
+import torchvision
+import torchvision.transforms as transforms
+from PIL import Image
+import numpy as np
 
+@torch.no_grad()
 def main(args):
-    *_, func_args = inspect.getargvalues(inspect.currentframe())
-    func_args = dict(func_args)
-
-    time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    savedir = f"samples/{Path(args.config).stem}-{time_str}"
-    os.makedirs(savedir)
-
-    config = OmegaConf.load(args.config)
-    samples = []
-
-    sample_idx = 0
-    for model_idx, (config_key, model_config) in enumerate(list(config.items())):
-
-        motion_modules = model_config.motion_module
-        motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
-        for motion_module in motion_modules:
-            inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
-
-            ### >>> create validation pipeline >>> ###
-            tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
-            text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
-            vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
-            unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
-
-            if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
-            else: assert False
-
-            pipeline = AnimationPipeline(
-                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
-                scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
-            ).to("cuda")
-
-            pipeline = load_weights(
-                pipeline,
-                # motion module
-                motion_module_path = motion_module,
-                motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
-                # image layers
-                dreambooth_model_path = model_config.get("dreambooth_path", ""),
-                lora_model_path = model_config.get("lora_model_path", ""),
-                lora_alpha = model_config.get("lora_alpha", 0.8),
-            ).to("cuda")
-
-            prompts = model_config.prompt
-            n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt
-
-            random_seeds = model_config.get("seed", [-1])
-            random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
-            random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
-
-            config[config_key].random_seed = []
-            for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
-
-                # manually set random seed for reproduction
-                if random_seed != -1: torch.manual_seed(random_seed)
-                else: torch.seed()
-                config[config_key].random_seed.append(torch.initial_seed())
-
-                print(f"current seed: {torch.initial_seed()}")
-                print(f"sampling {prompt} ...")
-                sample = pipeline(
-                    prompt,
-                    negative_prompt = n_prompt,
-                    num_inference_steps = model_config.steps,
-                    guidance_scale = model_config.guidance_scale,
-                    width = args.W,
-                    height = args.H,
-                    video_length = args.L,
-                ).videos
-                samples.append(sample)
-
-                prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
-                save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
-                print(f"save to {savedir}/sample/{prompt}.gif")
-
-                sample_idx += 1
-
-    samples = torch.concat(samples)
-    save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
-
-    OmegaConf.save(config, f"{savedir}/config.yaml")
+    *_, func_args = inspect.getargvalues(inspect.currentframe())
+    func_args = dict(func_args)
+
+    time_str = datetime.datetime.now().strftime("%Y-%m-%d")
+    savedir = f"sample/{Path(args.exp_config).stem}_{args.H}_{args.W}-{time_str}"
+    os.makedirs(savedir, exist_ok=True)
+
+    # Load Config
+    exp_config = OmegaConf.load(args.exp_config)
+    config = OmegaConf.load(args.base_config)
+    config = OmegaConf.merge(config, exp_config)
+
+    if config.get('base_model_path', '') != '':
+        args.pretrained_model_path = config.base_model_path
+
+    # Load Component
+    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
+    tokenizer_two = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer_2")
+    text_encoder_two = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_path, subfolder="text_encoder_2")
+
+    # init unet model
+    unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))
+
+    # Enable memory efficient attention
+    if is_xformers_available() and args.xformers:
+        unet.enable_xformers_memory_efficient_attention()
+
+    scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1, **config.noise_scheduler_kwargs)
+    pipeline = AnimationPipeline(
+        unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=scheduler,
+        text_encoder_2=text_encoder_two, tokenizer_2=tokenizer_two
+    ).to("cuda")
+
+    # Load model weights
+    pipeline = load_weights(
+        pipeline = pipeline,
+        motion_module_path = config.get("motion_module_path", ""),
+        ckpt_path = config.get("ckpt_path", ""),
+        lora_path = config.get("lora_path", ""),
+        lora_alpha = config.get("lora_alpha", 0.8)
+    )
+
+    pipeline.unet = pipeline.unet.half()
+    pipeline.text_encoder = pipeline.text_encoder.half()
+    pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
+    pipeline.enable_model_cpu_offload()
+    pipeline.enable_vae_slicing()
+
+    prompts = config.prompt
+    n_prompts = config.n_prompt
+
+    random_seeds = config.get("seed", [-1])
+    random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
+    random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
+
+    seeds = []
+    samples = []
+    with torch.inference_mode():
+        for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
+            # manually set random seed for reproduction
+            if random_seed != -1: torch.manual_seed(random_seed)
+            else: torch.seed()
+            seeds.append(torch.initial_seed())
+
+            print(f"current seed: {torch.initial_seed()}")
+            print(f"sampling {prompt} ...")
+            sample = pipeline(
+                prompt,
+                negative_prompt = n_prompt,
+                num_inference_steps = config.get('steps', 100),
+                guidance_scale = config.get('guidance_scale', 10),
+                width = args.W,
+                height = args.H,
+                single_model_length = args.L,
+            ).videos
+            samples.append(sample)
+
+            prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
+            prompt = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+            # save video
+            save_videos_grid(sample, f"{savedir}/sample/{prompt}.mp4")
+            print(f"save to {savedir}/sample/{prompt}.mp4")
+
+    samples = torch.concat(samples)
+    save_videos_grid(samples, f"{savedir}/sample-{datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}.mp4", n_rows=4)
+
+    config.seed = seeds
+    OmegaConf.save(config, f"{savedir}/config.yaml")
 
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
-    parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v1.yaml")
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--L", type=int, default=16)
-    parser.add_argument("--W", type=int, default=512)
-    parser.add_argument("--H", type=int, default=512)
-    args = parser.parse_args()
-    main(args)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-xl-base-1.0",)
+    parser.add_argument("--base_config", type=str, default="configs/inference/inference.yaml")
+    parser.add_argument("--exp_config", type=str, required=True)
+    parser.add_argument("--L", type=int, default=16)
+    parser.add_argument("--W", type=int, default=1024)
+    parser.add_argument("--H", type=int, default=1024)
+    parser.add_argument("--xformers", action="store_true")
+    args = parser.parse_args()
+    main(args)
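
The new script replaces the single --config flag with a base config plus a per-experiment config, merged via OmegaConf.merge before any components are loaded, so experiment files only need to state what they change. A minimal sketch of that layering, with hypothetical keys and values standing in for configs/inference/inference.yaml and an --exp_config file:

from omegaconf import OmegaConf

# Hypothetical stand-ins for the base config and an experiment config.
base = OmegaConf.create({"steps": 100, "guidance_scale": 10, "motion_module_path": ""})
exp = OmegaConf.create({"guidance_scale": 7.5, "prompt": ["a corgi running on the beach"]})

# Later arguments to merge() take precedence, so experiment values
# override the base defaults while unset keys are inherited.
config = OmegaConf.merge(base, exp)
assert config.guidance_scale == 7.5  # overridden by the experiment config
assert config.steps == 100           # inherited from the base config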
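The second tokenizer and text encoder appear because SDXL conditions on two text encoders: CLIP ViT-L (text_encoder) and OpenCLIP ViT-bigG with a projection head (text_encoder_2). A sketch of loading both, using the upstream Hugging Face repo that corresponds to the script's new default local path; the script itself loads from models/StableDiffusion/stable-diffusion-xl-base-1.0 instead:

from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

# Upstream repo matching the script's default --pretrained_model_path.
repo = "stabilityai/stable-diffusion-xl-base-1.0"

tokenizer = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")
tokenizer_2 = CLIPTokenizer.from_pretrained(repo, subfolder="tokenizer_2")
# The second encoder exposes pooled text projections, which SDXL conditioning uses.
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(repo, subfolder="text_encoder_2")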
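Seed handling also changed: instead of writing seeds back into per-model config entries, the loop records the effective seed for every prompt and stores the list in config.seed before saving config.yaml, so a run can be replayed even when -1 requested a fresh seed. The pattern in isolation, with a hypothetical seed list:

import torch

requested = [42, -1]  # -1 means "draw a fresh nondeterministic seed"
used = []
for seed in requested:
    if seed != -1:
        torch.manual_seed(seed)  # deterministic: reuse the requested seed
    else:
        torch.seed()             # reseed the default generator randomly
    used.append(torch.initial_seed())  # record whatever seed took effect

# `used` is what the script writes into the saved config for reproduction.
print(used)  # e.g. [42, 13632470292264596650]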