diff --git a/README.md b/README.md
index 6b8067b..62e3192 100644
--- a/README.md
+++ b/README.md
@@ -124,6 +124,15 @@ Then run the following commands:
 python -m scripts.animate --config [path to the config file]
 ```
 
+## Gradio Demo
+We provide a Gradio demo to make usage easier. To launch it, run the following commands:
+```
+conda activate animatediff
+python app.py
+```
+By default, the demo runs at `localhost:7860`.
+<br><img src="__assets__/figs/gradio.jpg">
+
 ## Gallery
 Here we demonstrate several best results we found in our experiments.
 
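If the public share link or the default port is not wanted, the launch call at the bottom of the new `app.py` (shown later in this diff) can be adjusted. This is only a sketch: `server_name`, `server_port`, and `share` are standard Gradio `launch()` options, and the values below are illustrative, not part of the patch.

```
# Example tweak to the final lines of app.py (values are illustrative)
if __name__ == "__main__":
    demo = ui()
    demo.launch(
        server_name="0.0.0.0",  # listen on all interfaces instead of only localhost
        server_port=7860,       # Gradio's default port; change it if 7860 is taken
        share=False,            # skip the public gradio.live tunnel, open http://localhost:7860
    )
```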
diff --git a/__assets__/figs/gradio.jpg b/__assets__/figs/gradio.jpg
new file mode 100644
index 0000000..19aea7c
Binary files /dev/null and b/__assets__/figs/gradio.jpg differ
diff --git a/app.py b/app.py
new file mode 100644
index 0000000..488daf8
--- /dev/null
+++ b/app.py
@@ -0,0 +1,337 @@
+import gradio as gr
+import os
+import json
+import random
+from glob import glob
+from datetime import datetime
+
+import torch
+from omegaconf import OmegaConf
+from safetensors import safe_open
+
+from transformers import CLIPTextModel, CLIPTokenizer
+from diffusers import AutoencoderKL
+from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
+from diffusers.utils.import_utils import is_xformers_available
+
+from animatediff.models.unet import UNet3DConditionModel
+from animatediff.pipelines.pipeline_animation import AnimationPipeline
+from animatediff.utils.util import save_videos_grid
+from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
+
+
+sample_idx = 0
+
+scheduler_dict = {
+    "Euler": EulerDiscreteScheduler,
+    "PNDM": PNDMScheduler,
+    "DDIM": DDIMScheduler,
+}
+
+css = """
+.toolbutton {
+    margin-bottom: 0em;
+    max-width: 2.5em;
+    min-width: 2.5em !important;
+    height: 2.5em;
+}
+"""
+
+
+class AnimateController:
+    def __init__(self):
+
+        # config dirs
+        self.basedir = os.getcwd()
+        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
+        self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
+        self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
+        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
+        self.savedir_sample = os.path.join(self.savedir, "sample")
+        os.makedirs(self.savedir, exist_ok=True)
+
+        self.stable_diffusion_list = []
+        self.motion_module_list = []
+        self.personalized_model_list = []
+
+        self.refresh_stable_diffusion()
+        self.refresh_motion_module()
+        self.refresh_personalized_model()
+
+        # config models
+        self.tokenizer = None
+        self.text_encoder = None
+        self.vae = None
+        self.unet = None
+        self.pipeline = None
+        self.lora_model_state_dict = {}
+
+        self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
+
+    def refresh_stable_diffusion(self):
+        self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
+
+    def refresh_motion_module(self):
+        motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
+        self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
+
+    def refresh_personalized_model(self):
+        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
+        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
+
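+    # NOTE: each update_* method below is wired to a dropdown's .change() event in ui();
+    # selecting an entry loads the corresponding checkpoint into the controller
+    # (Stable Diffusion components, motion module weights, DreamBooth base model, or LoRA state dict).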
subfolder="text_encoder").cuda() + self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda() + self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda() + return gr.Dropdown.update() + + def update_motion_module(self, motion_module_dropdown): + if self.unet is None: + gr.Info(f"Please select a pretrained model path.") + return gr.Dropdown.update(value=None) + else: + motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown) + motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu") + missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False) + assert len(unexpected) == 0 + return gr.Dropdown.update() + + def update_base_model(self, base_model_dropdown): + if self.unet is None: + gr.Info(f"Please select a pretrained model path.") + return gr.Dropdown.update(value=None) + else: + base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown) + base_model_state_dict = {} + with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: + for key in f.keys(): + base_model_state_dict[key] = f.get_tensor(key) + + converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config) + self.vae.load_state_dict(converted_vae_checkpoint) + + converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config) + self.unet.load_state_dict(converted_unet_checkpoint, strict=False) + + self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict) + return gr.Dropdown.update() + + def update_lora_model(self, lora_model_dropdown): + lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown) + self.lora_model_state_dict = {} + if lora_model_dropdown == "none": pass + else: + with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f: + for key in f.keys(): + self.lora_model_state_dict[key] = f.get_tensor(key) + return gr.Dropdown.update() + + def animate( + self, + stable_diffusion_dropdown, + motion_module_dropdown, + base_model_dropdown, + lora_alpha_slider, + prompt_textbox, + negative_prompt_textbox, + sampler_dropdown, + sample_step_slider, + width_slider, + length_slider, + height_slider, + cfg_scale_slider, + seed_textbox + ): + if self.unet is None: + raise gr.Error(f"Please select a pretrained model path.") + if motion_module_dropdown == "": + raise gr.Error(f"Please select a motion module.") + if base_model_dropdown == "": + raise gr.Error(f"Please select a base DreamBooth model.") + + if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention() + + pipeline = AnimationPipeline( + vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet, + scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)) + ).to("cuda") + + if self.lora_model_state_dict != {}: + pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider) + + pipeline.to("cuda") + + if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox)) + else: torch.seed() + seed = torch.initial_seed() + + sample = pipeline( + prompt_textbox, + negative_prompt = negative_prompt_textbox, + num_inference_steps = sample_step_slider, + guidance_scale = cfg_scale_slider, + width = width_slider, + height = height_slider, + 
+    def animate(
+        self,
+        stable_diffusion_dropdown,
+        motion_module_dropdown,
+        base_model_dropdown,
+        lora_alpha_slider,
+        prompt_textbox,
+        negative_prompt_textbox,
+        sampler_dropdown,
+        sample_step_slider,
+        width_slider,
+        length_slider,
+        height_slider,
+        cfg_scale_slider,
+        seed_textbox
+    ):
+        if self.unet is None:
+            raise gr.Error(f"Please select a pretrained model path.")
+        if motion_module_dropdown == "":
+            raise gr.Error(f"Please select a motion module.")
+        if base_model_dropdown == "":
+            raise gr.Error(f"Please select a base DreamBooth model.")
+
+        if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
+
+        pipeline = AnimationPipeline(
+            vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
+            scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+        ).to("cuda")
+
+        if self.lora_model_state_dict != {}:
+            pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
+
+        pipeline.to("cuda")
+
+        # interpret an empty box or -1 as "use a random seed"
+        if seed_textbox != "" and int(seed_textbox) != -1:
+            torch.manual_seed(int(seed_textbox))
+        else:
+            torch.seed()
+        seed = torch.initial_seed()
+
+        sample = pipeline(
+            prompt_textbox,
+            negative_prompt     = negative_prompt_textbox,
+            num_inference_steps = sample_step_slider,
+            guidance_scale      = cfg_scale_slider,
+            width               = width_slider,
+            height              = height_slider,
+            video_length        = length_slider,
+        ).videos
+
+        global sample_idx
+        sample_idx += 1
+        save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
+        save_videos_grid(sample, save_sample_path)
+
+        sample_config = {
+            "prompt": prompt_textbox,
+            "n_prompt": negative_prompt_textbox,
+            "sampler": sampler_dropdown,
+            "num_inference_steps": sample_step_slider,
+            "guidance_scale": cfg_scale_slider,
+            "width": width_slider,
+            "height": height_slider,
+            "video_length": length_slider,
+            "seed": seed
+        }
+        json_str = json.dumps(sample_config, indent=4)
+        with open(os.path.join(self.savedir, "logs.json"), "a") as f:
+            f.write(json_str)
+            f.write("\n\n")
+
+        return gr.Video.update(value=save_sample_path)
+
+
+controller = AnimateController()
+
+
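+# ui() builds the Gradio Blocks interface: a checkpoint-selection panel, a generation-settings
+# panel, and a Generate button that routes its inputs to controller.animate().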
+def ui():
+    with gr.Blocks(css=css) as demo:
+        gr.Markdown(
+            """
+            # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
+            Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
+            [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
+            """
+        )
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. Model checkpoints (select pretrained model path first).
+                """
+            )
+            with gr.Row():
+                stable_diffusion_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path",
+                    choices=controller.stable_diffusion_list,
+                    interactive=True,
+                )
+                stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
+
+                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_stable_diffusion():
+                    controller.refresh_stable_diffusion()
+                    return gr.Dropdown.update(choices=controller.stable_diffusion_list)
+                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
+
+            with gr.Row():
+                motion_module_dropdown = gr.Dropdown(
+                    label="Select motion module",
+                    choices=controller.motion_module_list,
+                    interactive=True,
+                )
+                motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
+
+                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_motion_module():
+                    controller.refresh_motion_module()
+                    return gr.Dropdown.update(choices=controller.motion_module_list)
+                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
+
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model (required)",
+                    choices=controller.personalized_model_list,
+                    interactive=True,
+                )
+                base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
+
+                lora_model_dropdown = gr.Dropdown(
+                    label="Select LoRA model (optional)",
+                    choices=["none"] + controller.personalized_model_list,
+                    value="none",
+                    interactive=True,
+                )
+                lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
+
+                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
+
+                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
+                def update_personalized_model():
+                    controller.refresh_personalized_model()
+                    return [
+                        gr.Dropdown.update(choices=controller.personalized_model_list),
+                        gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
+                    ]
+                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
+
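+        # Panel 2 collects the generation settings (prompt, sampler, resolution, length,
+        # CFG scale, seed) and exposes the Generate button that triggers controller.animate().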
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 2. Configs for AnimateDiff.
+                """
+            )
+
+            prompt_textbox = gr.Textbox(label="Prompt", lines=2)
+            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
+
+            with gr.Row().style(equal_height=False):
+                with gr.Column():
+                    with gr.Row():
+                        sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
+
+                    width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
+                    height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
+                    length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
+                    cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
+
+                    with gr.Row():
+                        seed_textbox = gr.Textbox(label="Seed", value=-1)
+                        seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
+                        seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, int(1e8))), inputs=[], outputs=[seed_textbox])
+
+                    generate_button = gr.Button(value="Generate", variant='primary')
+
+                result_video = gr.Video(label="Generated Animation", interactive=False)
+
+            generate_button.click(
+                fn=controller.animate,
+                inputs=[
+                    stable_diffusion_dropdown,
+                    motion_module_dropdown,
+                    base_model_dropdown,
+                    lora_alpha_slider,
+                    prompt_textbox,
+                    negative_prompt_textbox,
+                    sampler_dropdown,
+                    sample_step_slider,
+                    width_slider,
+                    length_slider,
+                    height_slider,
+                    cfg_scale_slider,
+                    seed_textbox,
+                ],
+                outputs=[result_video]
+            )
+
+    return demo
+
+
+if __name__ == "__main__":
+    demo = ui()
+    demo.launch(share=True)
diff --git a/configs/inference/inference.yaml b/configs/inference/inference.yaml
index ff92cde..86f3777 100644
--- a/configs/inference/inference.yaml
+++ b/configs/inference/inference.yaml
@@ -21,9 +21,6 @@ unet_additional_kwargs:
   temporal_attention_dim_div: 1
 
 noise_scheduler_kwargs:
-  num_train_timesteps: 1000
   beta_start: 0.00085
   beta_end: 0.012
   beta_schedule: "linear"
-  steps_offset: 1
-  clip_sample: false
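A note on the `inference.yaml` change above (my reading, not stated in the diff): the demo now instantiates whichever sampler the user picks from `scheduler_dict` with this single kwargs dict, and the removed keys are either a default (`num_train_timesteps: 1000`) or options that some of the samplers may not accept in their constructors (`steps_offset`, `clip_sample`). A minimal sketch of that assumption, assuming only that `diffusers` is installed:

```
# Sketch: the trimmed noise_scheduler_kwargs should construct every sampler offered by the demo.
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler

noise_scheduler_kwargs = {"beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "linear"}

for cls in (DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler):
    # all three accept the shared beta_* keys; scheduler-specific options were dropped from the yaml
    scheduler = cls(**noise_scheduler_kwargs)
    print(cls.__name__, scheduler.config.beta_schedule)
```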