diff --git a/README.md b/README.md
index 6b8067b..62e3192 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,15 @@ Then run the following commands:
python -m scripts.animate --config [path to the config file]
```
+## Gradio Demo
+We provide a Gradio demo for easier use. To launch it, run the following commands:
+```
+conda activate animatediff
+python app.py
+```
+By default, the demo is served at `localhost:7860`.
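+If you need the demo on a different host or port, you can adjust the `demo.launch(...)` call at the bottom of `app.py` using Gradio's standard `server_name`/`server_port` options, e.g.:
+```
+# example only: serve on all interfaces at port 7861 instead of the default localhost:7860
+demo.launch(server_name="0.0.0.0", server_port=7861)
+```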
+
+
## Gallery
Here we demonstrate several best results we found in our experiments.
diff --git a/__assets__/figs/gradio.jpg b/__assets__/figs/gradio.jpg
new file mode 100644
index 0000000..19aea7c
Binary files /dev/null and b/__assets__/figs/gradio.jpg differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..488daf8
--- /dev/null
+++ b/app.py
@@ -0,0 +1,337 @@
+import json
+import os
+import random
+from datetime import datetime
+from glob import glob
+
+import gradio as gr
+import torch
+from omegaconf import OmegaConf
+from safetensors import safe_open
+
+from diffusers import AutoencoderKL, DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
+from diffusers.utils.import_utils import is_xformers_available
+from transformers import CLIPTextModel, CLIPTokenizer
+
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+
+
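+# running counter used to name saved sample videos (incremented in AnimateController.animate)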
+sample_idx = 0
+
+scheduler_dict = {
+ "Euler": EulerDiscreteScheduler,
+ "PNDM": PNDMScheduler,
+ "DDIM": DDIMScheduler,
+}
+
+css = """
+.toolbutton {
+    margin-bottom: 0em;
+ max-width: 2.5em;
+ min-width: 2.5em !important;
+ height: 2.5em;
+}
+"""
+
+
+class AnimateController:
+ def __init__(self):
+
+ # config dirs
+ self.basedir = os.getcwd()
+ self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
+ self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
+ self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+ self.savedir_sample = os.path.join(self.savedir, "sample")
+        os.makedirs(self.savedir, exist_ok=True)
+        os.makedirs(self.savedir_sample, exist_ok=True)
+
+ self.stable_diffusion_list = []
+ self.motion_module_list = []
+ self.personalized_model_list = []
+
+ self.refresh_stable_diffusion()
+ self.refresh_motion_module()
+ self.refresh_personalized_model()
+
+ # config models
+ self.tokenizer = None
+ self.text_encoder = None
+ self.vae = None
+ self.unet = None
+ self.pipeline = None
+ self.lora_model_state_dict = {}
+
+ self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
+
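+    # the refresh_* helpers rescan the local model folders so newly added checkpoints show up in the dropdowns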
+ def refresh_stable_diffusion(self):
+ self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
+
+ def refresh_motion_module(self):
+ motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
+ self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
+
+ def refresh_personalized_model(self):
+ personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
+ self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
+ def update_stable_diffusion(self, stable_diffusion_dropdown):
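+        # load tokenizer, text encoder, and VAE from the selected Stable Diffusion folder;
+        # the 2D UNet weights are inflated into a 3D UNet via from_pretrained_2d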
+ self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
+ self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
+ self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
+ self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
+ return gr.Dropdown.update()
+
+ def update_motion_module(self, motion_module_dropdown):
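+        # load the AnimateDiff motion-module weights into the UNet; strict=False because
+        # the checkpoint contains only the temporal layers, not the spatial ones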
+ if self.unet is None:
+            gr.Info("Please select a pretrained model path.")
+ return gr.Dropdown.update(value=None)
+ else:
+ motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
+ motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
+ missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
+ assert len(unexpected) == 0
+ return gr.Dropdown.update()
+
+ def update_base_model(self, base_model_dropdown):
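+        # load a personalized DreamBooth checkpoint and convert its LDM-format
+        # VAE/UNet/CLIP weights into diffusers format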
+ if self.unet is None:
+            gr.Info("Please select a pretrained model path.")
+ return gr.Dropdown.update(value=None)
+ else:
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
+ base_model_state_dict = {}
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
+ for key in f.keys():
+ base_model_state_dict[key] = f.get_tensor(key)
+
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
+ self.vae.load_state_dict(converted_vae_checkpoint)
+
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
+ self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
+
+ self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
+ return gr.Dropdown.update()
+
+    def update_lora_model(self, lora_model_dropdown):
+        self.lora_model_state_dict = {}
+        # check against "none" before joining the path, otherwise the comparison can never match
+        if lora_model_dropdown != "none":
+            lora_model_path = os.path.join(self.personalized_model_dir, lora_model_dropdown)
+            with safe_open(lora_model_path, framework="pt", device="cpu") as f:
+                for key in f.keys():
+                    self.lora_model_state_dict[key] = f.get_tensor(key)
+        return gr.Dropdown.update()
+
+ def animate(
+ self,
+ stable_diffusion_dropdown,
+ motion_module_dropdown,
+ base_model_dropdown,
+ lora_alpha_slider,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ length_slider,
+ height_slider,
+ cfg_scale_slider,
+ seed_textbox
+ ):
+        if self.unet is None:
+            raise gr.Error("Please select a pretrained model path.")
+        if not motion_module_dropdown:
+            raise gr.Error("Please select a motion module.")
+        if not base_model_dropdown:
+            raise gr.Error("Please select a base DreamBooth model.")
+
+ if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
+
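+        # assemble a fresh pipeline from the currently loaded components, with the scheduler chosen in the UI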
+ pipeline = AnimationPipeline(
+ vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
+ scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+ ).to("cuda")
+
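+        # merge the cached LoRA weights into the pipeline at the user-chosen strength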
+ if self.lora_model_state_dict != {}:
+ pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
+
+ pipeline.to("cuda")
+
+        # the seed comes back from the Textbox as a string; empty or "-1" means "use a random seed"
+        if seed_textbox and str(seed_textbox) != "-1": torch.manual_seed(int(seed_textbox))
+        else: torch.seed()
+        seed = torch.initial_seed()
+
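+        # run the sampler; the returned .videos tensor has shape (batch, channels, frames, height, width)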
+ sample = pipeline(
+ prompt_textbox,
+ negative_prompt = negative_prompt_textbox,
+ num_inference_steps = sample_step_slider,
+ guidance_scale = cfg_scale_slider,
+ width = width_slider,
+ height = height_slider,
+ video_length = length_slider,
+ ).videos
+
+        global sample_idx
+        sample_idx += 1
+        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
+        save_videos_grid(sample, save_sample_path)
+
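+        # log each sample's generation settings alongside the outputs so results can be reproduced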
+ sample_config = {
+ "prompt": prompt_textbox,
+ "n_prompt": negative_prompt_textbox,
+ "sampler": sampler_dropdown,
+ "num_inference_steps": sample_step_slider,
+ "guidance_scale": cfg_scale_slider,
+ "width": width_slider,
+ "height": height_slider,
+ "video_length": length_slider,
+ "seed": seed
+ }
+ json_str = json.dumps(sample_config, indent=4)
+ with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+ f.write(json_str)
+ f.write("\n\n")
+
+ return gr.Video.update(value=save_sample_path)
+
+
+controller = AnimateController()
+
+
+def ui():
+ with gr.Blocks(css=css) as demo:
+ gr.Markdown(
+ """
+ # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
+ Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)
+        [arXiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [GitHub](https://github.com/guoyww/animatediff/)
+ """
+ )
+ with gr.Column(variant="panel"):
+ gr.Markdown(
+ """
+ ### 1. Model checkpoints (select pretrained model path first).
+ """
+ )
+ with gr.Row():
+ stable_diffusion_dropdown = gr.Dropdown(
+ label="Pretrained Model Path",
+ choices=controller.stable_diffusion_list,
+ interactive=True,
+ )
+ stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
+
+ stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def update_stable_diffusion():
+ controller.refresh_stable_diffusion()
+ return gr.Dropdown.update(choices=controller.stable_diffusion_list)
+ stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
+
+ with gr.Row():
+ motion_module_dropdown = gr.Dropdown(
+ label="Select motion module",
+ choices=controller.motion_module_list,
+ interactive=True,
+ )
+ motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
+
+ motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def update_motion_module():
+ controller.refresh_motion_module()
+ return gr.Dropdown.update(choices=controller.motion_module_list)
+ motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
+
+ base_model_dropdown = gr.Dropdown(
+ label="Select base Dreambooth model (required)",
+ choices=controller.personalized_model_list,
+ interactive=True,
+ )
+ base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
+
+ lora_model_dropdown = gr.Dropdown(
+ label="Select LoRA model (optional)",
+ choices=["none"] + controller.personalized_model_list,
+ value="none",
+ interactive=True,
+ )
+ lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
+
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
+
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+ def update_personalized_model():
+ controller.refresh_personalized_model()
+ return [
+ gr.Dropdown.update(choices=controller.personalized_model_list),
+ gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
+ ]
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
+
+ with gr.Column(variant="panel"):
+ gr.Markdown(
+ """
+ ### 2. Configs for AnimateDiff.
+ """
+ )
+
+ prompt_textbox = gr.Textbox(label="Prompt", lines=2)
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
+
+ with gr.Row().style(equal_height=False):
+ with gr.Column():
+ with gr.Row():
+ sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+ sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
+
+ width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
+ height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
+ length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
+
+ with gr.Row():
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                    seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
+
+ generate_button = gr.Button(value="Generate", variant='primary')
+
+ result_video = gr.Video(label="Generated Animation", interactive=False)
+
+ generate_button.click(
+ fn=controller.animate,
+ inputs=[
+ stable_diffusion_dropdown,
+ motion_module_dropdown,
+ base_model_dropdown,
+ lora_alpha_slider,
+ prompt_textbox,
+ negative_prompt_textbox,
+ sampler_dropdown,
+ sample_step_slider,
+ width_slider,
+ length_slider,
+ height_slider,
+ cfg_scale_slider,
+ seed_textbox,
+ ],
+ outputs=[result_video]
+ )
+
+ return demo
+
+
+if __name__ == "__main__":
+ demo = ui()
+ demo.launch(share=True)
diff --git a/configs/inference/inference.yaml b/configs/inference/inference.yaml
index ff92cde..86f3777 100644
--- a/configs/inference/inference.yaml
+++ b/configs/inference/inference.yaml
@@ -21,9 +21,6 @@ unet_additional_kwargs:
temporal_attention_dim_div: 1
noise_scheduler_kwargs:
- num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "linear"
- steps_offset: 1
- clip_sample: false