Mirror of https://github.com/guoyww/AnimateDiff.git, synced 2025-12-16 16:38:01 +01:00
Commit: update file
BIN animatediff/.DS_Store (vendored)
Binary file not shown.
328 app.py
@@ -1,328 +0,0 @@
import os
import json
import torch
import random

import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open

from diffusers import AutoencoderKL
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora


sample_idx = 0
scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "PNDM": PNDMScheduler,
    "DDIM": DDIMScheduler,
}

css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

class AnimateController:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")
        os.makedirs(self.savedir, exist_ok=True)

        self.stable_diffusion_list = []
        self.motion_module_list = []
        self.personalized_model_list = []

        self.refresh_stable_diffusion()
        self.refresh_motion_module()
        self.refresh_personalized_model()

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_model_state_dict = {}

        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")

    def refresh_stable_diffusion(self):
        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))

    def refresh_motion_module(self):
        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]

    def update_stable_diffusion(self, stable_diffusion_dropdown):
        self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
        self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
        self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
        return gr.Dropdown.update()

    def update_motion_module(self, motion_module_dropdown):
        if self.unet is None:
            gr.Info(f"Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
            motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
            assert len(unexpected) == 0
            return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        if self.unet is None:
            gr.Info(f"Please select a pretrained model path.")
            return gr.Dropdown.update(value=None)
        else:
            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
            with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    base_model_state_dict[key] = f.get_tensor(key)

            converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
            self.vae.load_state_dict(converted_vae_checkpoint)

            converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
            self.unet.load_state_dict(converted_unet_checkpoint, strict=False)

            self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
            return gr.Dropdown.update()

    def update_lora_model(self, lora_model_dropdown):
        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        self.lora_model_state_dict = {}
        if lora_model_dropdown == "none": pass
        else:
            with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
                for key in f.keys():
                    self.lora_model_state_dict[key] = f.get_tensor(key)
        return gr.Dropdown.update()

    def animate(
        self,
        stable_diffusion_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        width_slider,
        length_slider,
        height_slider,
        cfg_scale_slider,
        seed_textbox
    ):
        if self.unet is None:
            raise gr.Error(f"Please select a pretrained model path.")
        if motion_module_dropdown == "":
            raise gr.Error(f"Please select a motion module.")
        if base_model_dropdown == "":
            raise gr.Error(f"Please select a base DreamBooth model.")

        if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()

        pipeline = AnimationPipeline(
            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
            scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        ).to("cuda")

        if self.lora_model_state_dict != {}:
            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)

        pipeline.to("cuda")

        if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
        else: torch.seed()
        seed = torch.initial_seed()

        sample = pipeline(
            prompt_textbox,
            negative_prompt = negative_prompt_textbox,
            num_inference_steps = sample_step_slider,
            guidance_scale = cfg_scale_slider,
            width = width_slider,
            height = height_slider,
            video_length = length_slider,
        ).videos

        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
        save_videos_grid(sample, save_sample_path)

        sample_config = {
            "prompt": prompt_textbox,
            "n_prompt": negative_prompt_textbox,
            "sampler": sampler_dropdown,
            "num_inference_steps": sample_step_slider,
            "guidance_scale": cfg_scale_slider,
            "width": width_slider,
            "height": height_slider,
            "video_length": length_slider,
            "seed": seed
        }
        json_str = json.dumps(sample_config, indent=4)
        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
            f.write(json_str)
            f.write("\n\n")

        return gr.Video.update(value=save_sample_path)


controller = AnimateController()


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (select pretrained model path first).
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=controller.stable_diffusion_list,
                    interactive=True,
                )
                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_stable_diffusion():
                    controller.refresh_stable_diffusion()
                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module",
                    choices=controller.motion_module_list,
                    interactive=True,
                )
                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])

                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_motion_module():
                    controller.refresh_motion_module()
                    return gr.Dropdown.update(choices=controller.motion_module_list)
                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])

            base_model_dropdown = gr.Dropdown(
                label="Select base Dreambooth model (required)",
                choices=controller.personalized_model_list,
                interactive=True,
            )
            base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])

            lora_model_dropdown = gr.Dropdown(
                label="Select LoRA model (optional)",
                choices=["none"] + controller.personalized_model_list,
                value="none",
                interactive=True,
            )
            lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])

            lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)

            personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
            def update_personalized_model():
                controller.refresh_personalized_model()
                return [
                    gr.Dropdown.update(choices=controller.personalized_model_list),
                    gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
                ]
            personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for AnimateDiff.
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)

            with gr.Row().style(equal_height=False):
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)

                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                    length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed", value=-1)
                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])

                    generate_button = gr.Button(value="Generate", variant='primary')

                result_video = gr.Video(label="Generated Animation", interactive=False)

            generate_button.click(
                fn=controller.animate,
                inputs=[
                    stable_diffusion_dropdown,
                    motion_module_dropdown,
                    base_model_dropdown,
                    lora_alpha_slider,
                    prompt_textbox,
                    negative_prompt_textbox,
                    sampler_dropdown,
                    sample_step_slider,
                    width_slider,
                    length_slider,
                    height_slider,
                    cfg_scale_slider,
                    seed_textbox,
                ],
                outputs=[result_video]
            )

    return demo


if __name__ == "__main__":
    demo = ui()
    demo.launch(share=True)
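Before this commit, the demo above was started by running app.py directly: its __main__ block builds the Blocks UI via ui() and launches it with share=True. For anyone keeping a local copy of the removed file, the same entry point can also be driven programmatically; a minimal sketch, assuming the file is still present as app.py at the repository root and the models/ directories it scans are populated (nothing here beyond names defined in the deleted file):

# Hedged sketch: only reuses names from the deleted app.py above (controller, ui);
# the checkpoint folders are assumed to exist locally.
from app import controller, ui

controller.refresh_stable_diffusion()    # rescan models/StableDiffusion/*/
controller.refresh_motion_module()       # rescan models/Motion_Module/*.ckpt
controller.refresh_personalized_model()  # rescan models/DreamBooth_LoRA/*.safetensors

demo = ui()
demo.launch(share=True)  # same launch call as the deleted __main__ block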
@@ -7,117 +7,136 @@ from omegaconf import OmegaConf
import torch

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers import AutoencoderKL, EulerDiscreteScheduler

from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection

from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid
from animatediff.utils.util import load_weights
from animatediff.utils.util import save_videos_grid, load_weights

from diffusers.utils.import_utils import is_xformers_available

from einops import rearrange, repeat

import csv, pdb, glob
from safetensors import safe_open
import math
from pathlib import Path
import torchvision
import torchvision.transforms as transforms

from PIL import Image
import numpy as np


@torch.no_grad()
def main(args):
*_, func_args = inspect.getargvalues(inspect.currentframe())
func_args = dict(func_args)

time_str = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
savedir = f"samples/{Path(args.config).stem}-{time_str}"
os.makedirs(savedir)
*_, func_args = inspect.getargvalues(inspect.currentframe())
func_args = dict(func_args)

time_str = datetime.datetime.now().strftime("%Y-%m-%d")

savedir = f"sample/{Path(args.exp_config).stem}_{args.H}_{args.W}-{time_str}"
os.makedirs(savedir, exist_ok=True)

# Load Config
exp_config = OmegaConf.load(args.exp_config)
config = OmegaConf.load(args.base_config)
config = OmegaConf.merge(config, exp_config)

config = OmegaConf.load(args.config)
samples = []

sample_idx = 0
for model_idx, (config_key, model_config) in enumerate(list(config.items())):

motion_modules = model_config.motion_module
motion_modules = [motion_modules] if isinstance(motion_modules, str) else list(motion_modules)
for motion_module in motion_modules:
inference_config = OmegaConf.load(model_config.get("inference_config", args.inference_config))
if config.get('base_model_path', '') != '':
args.pretrained_model_path = config.base_model_path

# Load Component
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
tokenizer_two = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer_2")
text_encoder_two = CLIPTextModelWithProjection.from_pretrained(args.pretrained_model_path, subfolder="text_encoder_2")

### >>> create validation pipeline >>> ###
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_path, subfolder="vae")
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(inference_config.unet_additional_kwargs))
# init unet model
unet = UNet3DConditionModel.from_pretrained_2d(args.pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(config.unet_additional_kwargs))

if is_xformers_available(): unet.enable_xformers_memory_efficient_attention()
else: assert False
# Enable memory efficient attention
if is_xformers_available() and args.xformers:
unet.enable_xformers_memory_efficient_attention()

pipeline = AnimationPipeline(
vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet,
scheduler=DDIMScheduler(**OmegaConf.to_container(inference_config.noise_scheduler_kwargs)),
).to("cuda")
scheduler = EulerDiscreteScheduler(timestep_spacing='leading', steps_offset=1, **config.noise_scheduler_kwargs)

pipeline = load_weights(
pipeline,
# motion module
motion_module_path = motion_module,
motion_module_lora_configs = model_config.get("motion_module_lora_configs", []),
# image layers
dreambooth_model_path = model_config.get("dreambooth_path", ""),
lora_model_path = model_config.get("lora_model_path", ""),
lora_alpha = model_config.get("lora_alpha", 0.8),
).to("cuda")
pipeline = AnimationPipeline(
unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=scheduler,
text_encoder_2 = text_encoder_two, tokenizer_2=tokenizer_two
).to("cuda")

prompts = model_config.prompt
n_prompts = list(model_config.n_prompt) * len(prompts) if len(model_config.n_prompt) == 1 else model_config.n_prompt

random_seeds = model_config.get("seed", [-1])
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds

config[config_key].random_seed = []
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):

# manually set random seed for reproduction
if random_seed != -1: torch.manual_seed(random_seed)
else: torch.seed()
config[config_key].random_seed.append(torch.initial_seed())

print(f"current seed: {torch.initial_seed()}")
print(f"sampling {prompt} ...")
sample = pipeline(
prompt,
negative_prompt = n_prompt,
num_inference_steps = model_config.steps,
guidance_scale = model_config.guidance_scale,
width = args.W,
height = args.H,
video_length = args.L,
).videos
samples.append(sample)
# Load model weights
pipeline = load_weights(
pipeline = pipeline,
motion_module_path = config.get("motion_module_path", ""),
ckpt_path = config.get("ckpt_path", ""),
lora_path = config.get("lora_path", ""),
lora_alpha = config.get("lora_alpha", 0.8)
)

prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
save_videos_grid(sample, f"{savedir}/sample/{sample_idx}-{prompt}.gif")
print(f"save to {savedir}/sample/{prompt}.gif")

sample_idx += 1
pipeline.unet = pipeline.unet.half()
pipeline.text_encoder = pipeline.text_encoder.half()
pipeline.text_encoder_2 = pipeline.text_encoder_2.half()
pipeline.enable_model_cpu_offload()
pipeline.enable_vae_slicing()

samples = torch.concat(samples)
save_videos_grid(samples, f"{savedir}/sample.gif", n_rows=4)
prompts = config.prompt
n_prompts = config.n_prompt

OmegaConf.save(config, f"{savedir}/config.yaml")
random_seeds = config.get("seed", [-1])
random_seeds = [random_seeds] if isinstance(random_seeds, int) else list(random_seeds)
random_seeds = random_seeds * len(prompts) if len(random_seeds) == 1 else random_seeds
seeds = []
samples = []

with torch.inference_mode():
for prompt_idx, (prompt, n_prompt, random_seed) in enumerate(zip(prompts, n_prompts, random_seeds)):
# manually set random seed for reproduction
if random_seed != -1: torch.manual_seed(random_seed)
else: torch.seed()
seeds.append(torch.initial_seed())
print(f"current seed: {torch.initial_seed()}")
print(f"sampling {prompt} ...")
sample = pipeline(
prompt,
negative_prompt = n_prompt,
num_inference_steps = config.get('steps', 100),
guidance_scale = config.get('guidance_scale', 10),
width = args.W,
height = args.H,
single_model_length = args.L,
).videos
samples.append(sample)
prompt = "-".join((prompt.replace("/", "").split(" ")[:10]))
prompt = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

# save video
save_videos_grid(sample, f"{savedir}/sample/{prompt}.mp4")
print(f"save to {savedir}/sample/{prompt}.mp4")

samples = torch.concat(samples)
save_videos_grid(samples, f"{savedir}/sample-{datetime.datetime.now().strftime('%Y-%m-%dT%H-%M-%S')}.mp4", n_rows=4)
config.seed = seeds
OmegaConf.save(config, f"{savedir}/config.yaml")


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-v1-5",)
parser.add_argument("--inference_config", type=str, default="configs/inference/inference-v1.yaml")
parser.add_argument("--config", type=str, required=True)

parser.add_argument("--L", type=int, default=16 )
parser.add_argument("--W", type=int, default=512)
parser.add_argument("--H", type=int, default=512)
parser = argparse.ArgumentParser()

args = parser.parse_args()
main(args)
parser.add_argument("--pretrained_model_path", type=str, default="models/StableDiffusion/stable-diffusion-xl-base-1.0",)
parser.add_argument("--base_config", type=str, default="configs/inference/inference.yaml")
parser.add_argument("--exp_config", type=str, required=True)

parser.add_argument("--L", type=int, default=16 )
parser.add_argument("--W", type=int, default=1024)
parser.add_argument("--H", type=int, default=1024)

parser.add_argument("--xformers", action="store_true")

args = parser.parse_args()
main(args)
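The hunk above mixes removed and added lines of the inference script (the web diff markers were lost in this mirror): the added lines load a --base_config and a required --exp_config and merge them with OmegaConf before sampling, while the removed lines used a single --config. A minimal standalone sketch of that merge step, assuming both YAML files exist (the experiment file name below is only an illustration, not taken from the repository):

# Hedged sketch of the OmegaConf merge shown in the added lines above.
from omegaconf import OmegaConf

base_config = OmegaConf.load("configs/inference/inference.yaml")  # --base_config default above
exp_config = OmegaConf.load("configs/prompts/example.yaml")       # hypothetical --exp_config value
config = OmegaConf.merge(base_config, exp_config)                 # experiment keys override the base

print(OmegaConf.to_yaml(config))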
493 train.py
@@ -1,493 +0,0 @@
import os
import math
import wandb
import random
import logging
import inspect
import argparse
import datetime
import subprocess

from pathlib import Path
from tqdm.auto import tqdm
from einops import rearrange
from omegaconf import OmegaConf
from safetensors import safe_open
from typing import Dict, Optional, Tuple

import torch
import torchvision
import torch.nn.functional as F
import torch.distributed as dist
from torch.optim.swa_utils import AveragedModel
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP

import diffusers
from diffusers import AutoencoderKL, DDIMScheduler
from diffusers.models import UNet2DConditionModel
from diffusers.pipelines import StableDiffusionPipeline
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version
from diffusers.utils.import_utils import is_xformers_available

import transformers
from transformers import CLIPTextModel, CLIPTokenizer

from animatediff.data.dataset import WebVid10M
from animatediff.models.unet import UNet3DConditionModel
from animatediff.pipelines.pipeline_animation import AnimationPipeline
from animatediff.utils.util import save_videos_grid, zero_rank_print


def init_dist(launcher="slurm", backend='nccl', port=29500, **kwargs):
    """Initializes distributed environment."""
    if launcher == 'pytorch':
        rank = int(os.environ['RANK'])
        num_gpus = torch.cuda.device_count()
        local_rank = rank % num_gpus
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend=backend, **kwargs)

    elif launcher == 'slurm':
        proc_id = int(os.environ['SLURM_PROCID'])
        ntasks = int(os.environ['SLURM_NTASKS'])
        node_list = os.environ['SLURM_NODELIST']
        num_gpus = torch.cuda.device_count()
        local_rank = proc_id % num_gpus
        torch.cuda.set_device(local_rank)
        addr = subprocess.getoutput(
            f'scontrol show hostname {node_list} | head -n1')
        os.environ['MASTER_ADDR'] = addr
        os.environ['WORLD_SIZE'] = str(ntasks)
        os.environ['RANK'] = str(proc_id)
        port = os.environ.get('PORT', port)
        os.environ['MASTER_PORT'] = str(port)
        dist.init_process_group(backend=backend)
        zero_rank_print(f"proc_id: {proc_id}; local_rank: {local_rank}; ntasks: {ntasks}; node_list: {node_list}; num_gpus: {num_gpus}; addr: {addr}; port: {port}")

    else:
        raise NotImplementedError(f'Not implemented launcher type: `{launcher}`!')

    return local_rank


def main(
    image_finetune: bool,

    name: str,
    use_wandb: bool,
    launcher: str,

    output_dir: str,
    pretrained_model_path: str,

    train_data: Dict,
    validation_data: Dict,
    cfg_random_null_text: bool = True,
    cfg_random_null_text_ratio: float = 0.1,

    unet_checkpoint_path: str = "",
    unet_additional_kwargs: Dict = {},
    ema_decay: float = 0.9999,
    noise_scheduler_kwargs = None,

    max_train_epoch: int = -1,
    max_train_steps: int = 100,
    validation_steps: int = 100,
    validation_steps_tuple: Tuple = (-1,),

    learning_rate: float = 3e-5,
    scale_lr: bool = False,
    lr_warmup_steps: int = 0,
    lr_scheduler: str = "constant",

    trainable_modules: Tuple[str] = (None, ),
    num_workers: int = 32,
    train_batch_size: int = 1,
    adam_beta1: float = 0.9,
    adam_beta2: float = 0.999,
    adam_weight_decay: float = 1e-2,
    adam_epsilon: float = 1e-08,
    max_grad_norm: float = 1.0,
    gradient_accumulation_steps: int = 1,
    gradient_checkpointing: bool = False,
    checkpointing_epochs: int = 5,
    checkpointing_steps: int = -1,

    mixed_precision_training: bool = True,
    enable_xformers_memory_efficient_attention: bool = True,

    global_seed: int = 42,
    is_debug: bool = False,
):
    check_min_version("0.10.0.dev0")

    # Initialize distributed training
    local_rank = init_dist(launcher=launcher)
    global_rank = dist.get_rank()
    num_processes = dist.get_world_size()
    is_main_process = global_rank == 0

    seed = global_seed + global_rank
    torch.manual_seed(seed)

    # Logging folder
    folder_name = "debug" if is_debug else name + datetime.datetime.now().strftime("-%Y-%m-%dT%H-%M-%S")
    output_dir = os.path.join(output_dir, folder_name)
    if is_debug and os.path.exists(output_dir):
        os.system(f"rm -rf {output_dir}")

    *_, config = inspect.getargvalues(inspect.currentframe())

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )

    if is_main_process and (not is_debug) and use_wandb:
        run = wandb.init(project="animatediff", name=folder_name, config=config)

    # Handle the output folder creation
    if is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(f"{output_dir}/samples", exist_ok=True)
        os.makedirs(f"{output_dir}/sanity_check", exist_ok=True)
        os.makedirs(f"{output_dir}/checkpoints", exist_ok=True)
        OmegaConf.save(config, os.path.join(output_dir, 'config.yaml'))

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDIMScheduler(**OmegaConf.to_container(noise_scheduler_kwargs))

    vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder")
    if not image_finetune:
        unet = UNet3DConditionModel.from_pretrained_2d(
            pretrained_model_path, subfolder="unet",
            unet_additional_kwargs=OmegaConf.to_container(unet_additional_kwargs)
        )
    else:
        unet = UNet2DConditionModel.from_pretrained(pretrained_model_path, subfolder="unet")

    # Load pretrained unet weights
    if unet_checkpoint_path != "":
        zero_rank_print(f"from checkpoint: {unet_checkpoint_path}")
        unet_checkpoint_path = torch.load(unet_checkpoint_path, map_location="cpu")
        if "global_step" in unet_checkpoint_path: zero_rank_print(f"global_step: {unet_checkpoint_path['global_step']}")
        state_dict = unet_checkpoint_path["state_dict"] if "state_dict" in unet_checkpoint_path else unet_checkpoint_path

        m, u = unet.load_state_dict(state_dict, strict=False)
        zero_rank_print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
        assert len(u) == 0

    # Freeze vae and text_encoder
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)

    # Set unet trainable parameters
    unet.requires_grad_(False)
    for name, param in unet.named_parameters():
        for trainable_module_name in trainable_modules:
            if trainable_module_name in name:
                param.requires_grad = True
                break

    trainable_params = list(filter(lambda p: p.requires_grad, unet.parameters()))
    optimizer = torch.optim.AdamW(
        trainable_params,
        lr=learning_rate,
        betas=(adam_beta1, adam_beta2),
        weight_decay=adam_weight_decay,
        eps=adam_epsilon,
    )

    if is_main_process:
        zero_rank_print(f"trainable params number: {len(trainable_params)}")
        zero_rank_print(f"trainable params scale: {sum(p.numel() for p in trainable_params) / 1e6:.3f} M")

    # Enable xformers
    if enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # Enable gradient checkpointing
    if gradient_checkpointing:
        unet.enable_gradient_checkpointing()

    # Move models to GPU
    vae.to(local_rank)
    text_encoder.to(local_rank)

    # Get the training dataset
    train_dataset = WebVid10M(**train_data, is_image=image_finetune)
    distributed_sampler = DistributedSampler(
        train_dataset,
        num_replicas=num_processes,
        rank=global_rank,
        shuffle=True,
        seed=global_seed,
    )

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=train_batch_size,
        shuffle=False,
        sampler=distributed_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
    )

    # Get the training iteration
    if max_train_steps == -1:
        assert max_train_epoch != -1
        max_train_steps = max_train_epoch * len(train_dataloader)

    if checkpointing_steps == -1:
        assert checkpointing_epochs != -1
        checkpointing_steps = checkpointing_epochs * len(train_dataloader)

    if scale_lr:
        learning_rate = (learning_rate * gradient_accumulation_steps * train_batch_size * num_processes)

    # Scheduler
    lr_scheduler = get_scheduler(
        lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps * gradient_accumulation_steps,
        num_training_steps=max_train_steps * gradient_accumulation_steps,
    )

    # Validation pipeline
    if not image_finetune:
        validation_pipeline = AnimationPipeline(
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler,
        ).to("cuda")
    else:
        validation_pipeline = StableDiffusionPipeline.from_pretrained(
            pretrained_model_path,
            unet=unet, vae=vae, tokenizer=tokenizer, text_encoder=text_encoder, scheduler=noise_scheduler, safety_checker=None,
        )
    validation_pipeline.enable_vae_slicing()

    # DDP warpper
    unet.to(local_rank)
    unet = DDP(unet, device_ids=[local_rank], output_device=local_rank)

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    # Train!
    total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps

    if is_main_process:
        logging.info("***** Running training *****")
        logging.info(f" Num examples = {len(train_dataset)}")
        logging.info(f" Num Epochs = {num_train_epochs}")
        logging.info(f" Instantaneous batch size per device = {train_batch_size}")
        logging.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
        logging.info(f" Gradient Accumulation steps = {gradient_accumulation_steps}")
        logging.info(f" Total optimization steps = {max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, max_train_steps), disable=not is_main_process)
    progress_bar.set_description("Steps")

    # Support mixed-precision training
    scaler = torch.cuda.amp.GradScaler() if mixed_precision_training else None

    for epoch in range(first_epoch, num_train_epochs):
        train_dataloader.sampler.set_epoch(epoch)
        unet.train()

        for step, batch in enumerate(train_dataloader):
            if cfg_random_null_text:
                batch['text'] = [name if random.random() > cfg_random_null_text_ratio else "" for name in batch['text']]

            # Data batch sanity check
            if epoch == first_epoch and step == 0:
                pixel_values, texts = batch['pixel_values'].cpu(), batch['text']
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> b c f h w")
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value[None, ...]
                        save_videos_grid(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.gif", rescale=True)
                else:
                    for idx, (pixel_value, text) in enumerate(zip(pixel_values, texts)):
                        pixel_value = pixel_value / 2. + 0.5
                        torchvision.utils.save_image(pixel_value, f"{output_dir}/sanity_check/{'-'.join(text.replace('/', '').split()[:10]) if not text == '' else f'{global_rank}-{idx}'}.png")

            ### >>>> Training >>>> ###

            # Convert videos to latent space
            pixel_values = batch["pixel_values"].to(local_rank)
            video_length = pixel_values.shape[1]
            with torch.no_grad():
                if not image_finetune:
                    pixel_values = rearrange(pixel_values, "b f c h w -> (b f) c h w")
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()
                    latents = rearrange(latents, "(b f) c h w -> b c f h w", f=video_length)
                else:
                    latents = vae.encode(pixel_values).latent_dist
                    latents = latents.sample()

                latents = latents * 0.18215

            # Sample noise that we'll add to the latents
            noise = torch.randn_like(latents)
            bsz = latents.shape[0]

            # Sample a random timestep for each video
            timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
            timesteps = timesteps.long()

            # Add noise to the latents according to the noise magnitude at each timestep
            # (this is the forward diffusion process)
            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

            # Get the text embedding for conditioning
            with torch.no_grad():
                prompt_ids = tokenizer(
                    batch['text'], max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
                ).input_ids.to(latents.device)
                encoder_hidden_states = text_encoder(prompt_ids)[0]

            # Get the target for loss depending on the prediction type
            if noise_scheduler.config.prediction_type == "epsilon":
                target = noise
            elif noise_scheduler.config.prediction_type == "v_prediction":
                raise NotImplementedError
            else:
                raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

            # Predict the noise residual and compute loss
            # Mixed-precision training
            with torch.cuda.amp.autocast(enabled=mixed_precision_training):
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

            optimizer.zero_grad()

            # Backpropagate
            if mixed_precision_training:
                scaler.scale(loss).backward()
                """ >>> gradient clipping >>> """
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                """ >>> gradient clipping >>> """
                torch.nn.utils.clip_grad_norm_(unet.parameters(), max_grad_norm)
                """ <<< gradient clipping <<< """
                optimizer.step()

            lr_scheduler.step()
            progress_bar.update(1)
            global_step += 1

            ### <<<< Training <<<< ###

            # Wandb logging
            if is_main_process and (not is_debug) and use_wandb:
                wandb.log({"train_loss": loss.item()}, step=global_step)

            # Save checkpoint
            if is_main_process and (global_step % checkpointing_steps == 0 or step == len(train_dataloader) - 1):
                save_path = os.path.join(output_dir, f"checkpoints")
                state_dict = {
                    "epoch": epoch,
                    "global_step": global_step,
                    "state_dict": unet.state_dict(),
                }
                if step == len(train_dataloader) - 1:
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint-epoch-{epoch+1}.ckpt"))
                else:
                    torch.save(state_dict, os.path.join(save_path, f"checkpoint.ckpt"))
                logging.info(f"Saved state to {save_path} (global_step: {global_step})")

            # Periodically validation
            if is_main_process and (global_step % validation_steps == 0 or global_step in validation_steps_tuple):
                samples = []

                generator = torch.Generator(device=latents.device)
                generator.manual_seed(global_seed)

                height = train_data.sample_size[0] if not isinstance(train_data.sample_size, int) else train_data.sample_size
                width = train_data.sample_size[1] if not isinstance(train_data.sample_size, int) else train_data.sample_size

                prompts = validation_data.prompts[:2] if global_step < 1000 and (not image_finetune) else validation_data.prompts

                for idx, prompt in enumerate(prompts):
                    if not image_finetune:
                        sample = validation_pipeline(
                            prompt,
                            generator = generator,
                            video_length = train_data.sample_n_frames,
                            height = height,
                            width = width,
                            **validation_data,
                        ).videos
                        save_videos_grid(sample, f"{output_dir}/samples/sample-{global_step}/{idx}.gif")
                        samples.append(sample)

                    else:
                        sample = validation_pipeline(
                            prompt,
                            generator = generator,
                            height = height,
                            width = width,
                            num_inference_steps = validation_data.get("num_inference_steps", 25),
                            guidance_scale = validation_data.get("guidance_scale", 8.),
                        ).images[0]
                        sample = torchvision.transforms.functional.to_tensor(sample)
                        samples.append(sample)

                if not image_finetune:
                    samples = torch.concat(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.gif"
                    save_videos_grid(samples, save_path)

                else:
                    samples = torch.stack(samples)
                    save_path = f"{output_dir}/samples/sample-{global_step}.png"
                    torchvision.utils.save_image(samples, save_path, nrow=4)

                logging.info(f"Saved samples to {save_path}")

            logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)

            if global_step >= max_train_steps:
                break

    dist.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--launcher", type=str, choices=["pytorch", "slurm"], default="pytorch")
    parser.add_argument("--wandb", action="store_true")
    args = parser.parse_args()

    name = Path(args.config).stem
    config = OmegaConf.load(args.config)

    main(name=name, launcher=args.launcher, use_wandb=args.wandb, **config)
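The deleted trainer freezes the whole UNet and then re-enables only parameters whose names contain one of the trainable_modules substrings from the config. A standalone sketch of that selection pattern, with "motion_modules." used purely as an assumed example value rather than anything taken from a repository config:

# Hedged sketch of the freeze-then-unfreeze pattern from the deleted train.py;
# the "motion_modules." substring is an assumed example, not a config value from the repo.
import torch.nn as nn

def select_trainable_params(model: nn.Module, trainable_modules=("motion_modules.",)):
    model.requires_grad_(False)  # freeze every parameter first
    for name, param in model.named_parameters():
        if any(key in name for key in trainable_modules):
            param.requires_grad = True  # re-enable only matching submodules
    return [p for p in model.parameters() if p.requires_grad]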