Merge pull request #47 from x-CK-x/main

Gradio WebUI
Yuwei Guo committed via GitHub on 2023-07-19 10:43:13 +08:00
6 changed files with 614 additions and 265 deletions

README.md

@@ -1,244 +1,244 @@
# AnimateDiff

This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725).

**[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)**
</br>
Yuwei Guo,
Ceyuan Yang*,
Anyi Rao,
Yaohui Wang,
Yu Qiao,
Dahua Lin,
Bo Dai
<p style="font-size: 0.8em; margin-top: -1em">*Corresponding Author</p>

[Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/)

## Todo
- [x] Code Release
- [x] Arxiv Report
- [x] GPU Memory Optimization
- [ ] Gradio Interface

## Common Issues
<details>
<summary>Installation</summary>
Please ensure that [xformers](https://github.com/facebookresearch/xformers) is installed; it is used to reduce inference memory.
</details>

<details>
<summary>Various resolutions or numbers of frames</summary>
Currently, we recommend generating animations with 16 frames at 512 resolution, which matches our training settings. Other resolutions or frame counts may degrade quality to some extent.
</details>

<details>
<summary>Animating a given image</summary>
We agree that animating a given image is an appealing feature, and we plan to support it officially in the future. For now, you may try the community effort at [talesofai/AnimateDiff](https://github.com/talesofai/AnimateDiff).
</details>

<details>
<summary>Contributions from the community</summary>
Contributions are always welcome! We will create a separate branch to which the community can contribute; as for the main branch, we would like to keep it aligned with the original technical report. :)
</details>

## Setup for Inference

### Prepare Environment
~~Our approach takes around 60 GB of GPU memory for inference; an NVIDIA A100 is recommended.~~

***We have updated our inference code with xformers and a sequential decoding trick. AnimateDiff now takes only ~12 GB of VRAM for inference and runs on a single RTX 3090!***

```
git clone https://github.com/guoyww/AnimateDiff.git
cd AnimateDiff

conda env create -f environment.yaml
conda activate animatediff
```

### Download Base T2I & Motion Module Checkpoints
We provide two versions of our Motion Module, trained on stable-diffusion-v1-4 and fine-tuned on stable-diffusion-v1-5, respectively.
We recommend trying both of them for the best results.

```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/

bash download_bashscripts/0-MotionModule.sh
```

You may also download the motion module checkpoints directly from [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) and put them in the `models/Motion_Module/` folder, or fetch them with the scripted sketch below.
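
As a scripted alternative, here is a minimal sketch using `gdown`. The Google Drive file IDs below are the ones hard-coded in the WebUI added in this PR; treat them as assumptions and verify them against the Drive folder before relying on them.

```
import os
import gdown

# File IDs as assumed from webui.py's motion_model_map.
MOTION_MODULES = {
    "mm_sd_v14.ckpt": "1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq",
    "mm_sd_v15.ckpt": "1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu",
}

os.makedirs("models/Motion_Module", exist_ok=True)
for name, file_id in MOTION_MODULES.items():
    # Download each checkpoint into models/Motion_Module/.
    gdown.download(f"https://drive.google.com/uc?id={file_id}",
                   os.path.join("models/Motion_Module", name), quiet=False)
```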
### Prepare Personalized T2I
Here we provide inference configs for several demo T2I models from CivitAI.
You may run the following bash scripts to download these checkpoints.

```
bash download_bashscripts/1-ToonYou.sh
bash download_bashscripts/2-Lyriel.sh
bash download_bashscripts/3-RcnzCartoon.sh
bash download_bashscripts/4-MajicMix.sh
bash download_bashscripts/5-RealisticVision.sh
bash download_bashscripts/6-Tusun.sh
bash download_bashscripts/7-FilmVelvia.sh
bash download_bashscripts/8-GhibliBackground.sh
```

### Inference
After downloading the above personalized T2I checkpoints, run the following commands to generate animations. The results will automatically be saved to the `samples/` folder.

```
python -m scripts.animate --config configs/prompts/1-ToonYou.yaml
python -m scripts.animate --config configs/prompts/2-Lyriel.yaml
python -m scripts.animate --config configs/prompts/3-RcnzCartoon.yaml
python -m scripts.animate --config configs/prompts/4-MajicMix.yaml
python -m scripts.animate --config configs/prompts/5-RealisticVision.yaml
python -m scripts.animate --config configs/prompts/6-Tusun.yaml
python -m scripts.animate --config configs/prompts/7-FilmVelvia.yaml
python -m scripts.animate --config configs/prompts/8-GhibliBackground.yaml
```

To generate animations with a new DreamBooth/LoRA model, you may create a new `.yaml` config file in the following format:

```
NewModel:
  path: "[path to your DreamBooth/LoRA model .safetensors file]"
  base: "[path to LoRA base model .safetensors file, leave it empty string if not needed]"

  motion_module:
    - "models/Motion_Module/mm_sd_v14.ckpt"
    - "models/Motion_Module/mm_sd_v15.ckpt"

  steps:          25
  guidance_scale: 7.5

  prompt:
    - "[positive prompt]"

  n_prompt:
    - "[negative prompt]"
```
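
For illustration, such a config can also be written programmatically with PyYAML, mirroring the `dict_to_yaml` helper added in this PR. This is only a sketch: the config key and paths below are hypothetical placeholders, and a `seed` field is included because the WebUI's config loader expects one.

```
import yaml

config = {
    "MyModel": {  # hypothetical config key
        "path": "models/DreamBooth_LoRA/my_model.safetensors",  # placeholder path
        "base": "",  # empty string when no LoRA base model is needed
        "motion_module": ["models/Motion_Module/mm_sd_v14.ckpt"],
        "steps": 25,
        "guidance_scale": 7.5,
        "seed": [-1],  # -1 for a random seed
        "prompt": ["[positive prompt]"],
        "n_prompt": ["[negative prompt]"],
    }
}

# Write the config where the repo keeps its prompt configs.
with open("configs/prompts/MyModel.yaml", "w") as f:
    yaml.safe_dump(config, f)
```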
Then run the following command:

```
python -m scripts.animate --config [path to the config file]
```

## Gradio Demo
We have created a Gradio demo to make AnimateDiff easier to use. To launch the demo, run the following commands:

```
conda activate animatediff
python app.py
```

By default, the demo runs at `localhost:7860`.
<br><img src="__assets__/figs/gradio.jpg" style="width: 50em; margin-top: 1em">
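
The WebUI script added in this PR (`webui.py`) also exposes optional launch flags via its argument parser: basic authentication, a fixed port, and a public Gradio share link. For example (bracketed values are placeholders):

```
python webui.py --username [user] --password [password] --server_port 7860 --share
```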
## Gallery
Here we demonstrate some of the best results we found in our experiments.

<table class="center">
<tr>
<td><img src="__assets__/animations/model_01/01.gif"></td>
<td><img src="__assets__/animations/model_01/02.gif"></td>
<td><img src="__assets__/animations/model_01/03.gif"></td>
<td><img src="__assets__/animations/model_01/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model: <a href="https://civitai.com/models/30240/toonyou">ToonYou</a></p>

<table>
<tr>
<td><img src="__assets__/animations/model_02/01.gif"></td>
<td><img src="__assets__/animations/model_02/02.gif"></td>
<td><img src="__assets__/animations/model_02/03.gif"></td>
<td><img src="__assets__/animations/model_02/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model: <a href="https://civitai.com/models/4468/counterfeit-v30">Counterfeit V3.0</a></p>

<table>
<tr>
<td><img src="__assets__/animations/model_03/01.gif"></td>
<td><img src="__assets__/animations/model_03/02.gif"></td>
<td><img src="__assets__/animations/model_03/03.gif"></td>
<td><img src="__assets__/animations/model_03/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model: <a href="https://civitai.com/models/4201/realistic-vision-v20">Realistic Vision V2.0</a></p>

<table>
<tr>
<td><img src="__assets__/animations/model_04/01.gif"></td>
<td><img src="__assets__/animations/model_04/02.gif"></td>
<td><img src="__assets__/animations/model_04/03.gif"></td>
<td><img src="__assets__/animations/model_04/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model: <a href="https://civitai.com/models/43331/majicmix-realistic">majicMIX Realistic</a></p>

<table>
<tr>
<td><img src="__assets__/animations/model_05/01.gif"></td>
<td><img src="__assets__/animations/model_05/02.gif"></td>
<td><img src="__assets__/animations/model_05/03.gif"></td>
<td><img src="__assets__/animations/model_05/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model: <a href="https://civitai.com/models/66347/rcnz-cartoon-3d">RCNZ Cartoon</a></p>

<table>
<tr>
<td><img src="__assets__/animations/model_06/01.gif"></td>
<td><img src="__assets__/animations/model_06/02.gif"></td>
<td><img src="__assets__/animations/model_06/03.gif"></td>
<td><img src="__assets__/animations/model_06/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model: <a href="https://civitai.com/models/33208/filmgirl-film-grain-lora-and-loha">FilmVelvia</a></p>

#### Community Cases
Here are some samples contributed by community artists. Create a pull request if you would like to show your results here 😚.

<table>
<tr>
<td><img src="__assets__/animations/model_07/init.jpg"></td>
<td><img src="__assets__/animations/model_07/01.gif"></td>
<td><img src="__assets__/animations/model_07/02.gif"></td>
<td><img src="__assets__/animations/model_07/03.gif"></td>
<td><img src="__assets__/animations/model_07/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">
Character Model: <a href="https://civitai.com/models/13237/genshen-impact-yoimiya">Yoimiya</a>
(with an initial reference image; see the <a href="https://github.com/talesofai/AnimateDiff">WIP fork</a> for the extended implementation).</p>

<table>
<tr>
<td><img src="__assets__/animations/model_08/01.gif"></td>
<td><img src="__assets__/animations/model_08/02.gif"></td>
<td><img src="__assets__/animations/model_08/03.gif"></td>
<td><img src="__assets__/animations/model_08/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">
Character Model: <a href="https://civitai.com/models/9850/paimon-genshin-impact">Paimon</a>;
Pose Model: <a href="https://civitai.com/models/107295/or-holdingsign">Hold Sign</a></p>

## BibTeX
```
@article{guo2023animatediff,
  title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning},
  author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Wang, Yaohui and Qiao, Yu and Lin, Dahua and Dai, Bo},
  journal={arXiv preprint arXiv:2307.04725},
  year={2023}
}
```

## Contact Us
**Yuwei Guo**: [guoyuwei@pjlab.org.cn](mailto:guoyuwei@pjlab.org.cn)
**Ceyuan Yang**: [yangceyuan@pjlab.org.cn](mailto:yangceyuan@pjlab.org.cn)
**Bo Dai**: [daibo@pjlab.org.cn](mailto:daibo@pjlab.org.cn)

## Acknowledgements
Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video).

BIN  __assets__/ui/first-tab.png (new binary file, 48 KiB; not shown)
BIN  (new binary file, 84 KiB; not shown)

environment.yaml

@@ -1,21 +1,22 @@
name: animatediff
channels:
  - pytorch
  - xformers
dependencies:
  - python=3.10
  - pytorch==1.12.1
  - torchvision==0.13.1
  - torchaudio==0.12.1
  - cudatoolkit=11.3
  - xformers
  - pip
  - pip:
    - diffusers[torch]==0.11.1
    - transformers==4.25.1
    - imageio==2.27.0
    - gdown
    - einops
    - omegaconf
    - safetensors
    - gradio
    - pyyaml
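
Since this change adds `pyyaml` (used by the new WebUI to read and write prompt configs), an existing environment can be updated in place rather than recreated:

```
conda env update -f environment.yaml
```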

utils.py (new file)

@@ -0,0 +1,41 @@
import copy
import os
import subprocess as sub

import gdown
import yaml


def yaml_to_dict(filename):
    # Load a YAML file into a plain dictionary.
    with open(filename, 'r') as file:
        data = yaml.safe_load(file)
    return copy.deepcopy(data)


def dict_to_yaml(data, filename):
    # Serialize a dictionary back to a YAML file.
    with open(filename, 'w') as file:
        yaml.safe_dump(data, file)


def download_from_drive_gdown(file_id, output_path):
    # Download a file from Google Drive by its file ID.
    url = f'https://drive.google.com/uc?id={file_id}'
    gdown.download(url, output_path, quiet=False)


def verbose_print(text):
    print(f"{text}")


def create_dirs(arb_path):
    if not os.path.exists(arb_path):
        os.makedirs(arb_path)


def make_all_dirs(list_of_paths):
    for path in list_of_paths:
        create_dirs(path)


def execute(cmd):
    # Run a command and yield its stdout line by line as it is produced;
    # raise if the process exits with a non-zero return code.
    popen = sub.Popen(cmd, stdout=sub.PIPE, universal_newlines=True)
    for stdout_line in iter(popen.stdout.readline, ""):
        yield stdout_line
    popen.stdout.close()
    return_code = popen.wait()
    if return_code:
        raise sub.CalledProcessError(return_code, cmd)


def is_windows():
    return os.name == 'nt'
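
For reference, a minimal usage sketch of the `execute` helper above, which is how `webui.py` streams generation logs; the command here is an arbitrary placeholder:

```
import utils as help

# Stream a subprocess's stdout line by line as it runs.
for line in help.execute(["python", "--version"]):
    help.verbose_print(line)
```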

webui.py (new file)

@@ -0,0 +1,307 @@
import argparse
import copy
import glob
import os

import gradio as gr

import utils as help

# Path separator used when stripping directory prefixes from paths.
sep = '\\' if help.is_windows() else '/'

all_motion_model_opts = ["mm_sd_v14.ckpt", "mm_sd_v15.ckpt"]


def get_available_motion_models():
    motion_model_opts_path = os.path.join(os.getcwd(), "models", "Motion_Module")
    return sorted(glob.glob(os.path.join(motion_model_opts_path, "*.ckpt")))


def get_available_sd_models():
    sd_model_opts_path = os.path.join(os.getcwd(), "models", "StableDiffusion")
    return sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(sd_model_opts_path, "*.safetensors"))])


def get_available_db_models():
    db_model_opts_path = os.path.join(os.getcwd(), "models", "DreamBooth_LoRA")
    return sorted([safetensor.split(sep)[-1] for safetensor in glob.glob(os.path.join(db_model_opts_path, "*.safetensors"))])


def get_db_config():
    prompt_configs_path = os.path.join(os.getcwd(), "configs", "prompts")
    return sorted([prompt_yaml.split(sep)[-1] for prompt_yaml in glob.glob(os.path.join(prompt_configs_path, "*.yaml"))])


def get_sd_config():
    inference_configs_path = os.path.join(os.getcwd(), "configs", "inference")
    return sorted([inference_yaml.split(sep)[-1] for inference_yaml in glob.glob(os.path.join(inference_configs_path, "*.yaml"))])


def set_motion_model(menu_opt: gr.SelectData):
    # Resolve the selected motion module, downloading it first if missing.
    model_name = menu_opt.value
    motion_model_opts_path = os.path.join(os.getcwd(), "models", "Motion_Module")
    motion_model_opts = sorted(glob.glob(os.path.join(motion_model_opts_path, "*.ckpt")))
    # Google Drive file IDs of the two official motion module checkpoints.
    motion_model_map = {"mm_sd_v14.ckpt": "1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq",
                        "mm_sd_v15.ckpt": "1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu"}
    if os.path.join(motion_model_opts_path, model_name) not in motion_model_opts:  # download
        help.download_from_drive_gdown(motion_model_map[model_name], os.path.join(motion_model_opts_path, model_name))
    return gr.update(value=os.path.join(motion_model_opts_path, model_name))  # model path


def set_sd_model(menu_opt: gr.SelectData):
    model_name = menu_opt.value
    sd_model_opts_path = os.path.join(os.getcwd(), "models", "StableDiffusion")
    # sd path, pretrained path
    return gr.update(value=os.path.join(sd_model_opts_path, model_name)), gr.update(value=os.path.join(sd_model_opts_path, model_name))


def set_db_model(menu_opt: gr.SelectData):
    model_name = menu_opt.value
    db_model_opts_path = os.path.join(os.getcwd(), "models", "DreamBooth_LoRA")
    return gr.update(value=os.path.join(db_model_opts_path, model_name))  # db path


def update_available_sd_models():
    return gr.update(choices=get_available_sd_models())


def update_available_db_models():
    return gr.update(choices=get_available_db_models())


def update_sd_config():
    return gr.update(choices=get_sd_config())


def update_db_config():
    return gr.update(choices=get_db_config())


def load_db_config(filename: gr.SelectData):
    # Load the selected prompt config and fan its fields out to the UI.
    filename = filename.value
    global prompt_config_dict
    prompt_configs_path = os.path.join(os.getcwd(), "configs", "prompts")
    # populate the dictionary
    prompt_config_dict = help.yaml_to_dict(os.path.join(prompt_configs_path, filename))
    name_only = list(prompt_config_dict.keys())[0]
    help.verbose_print(f"Config Key Name:\t{name_only}")
    # return populated UI components
    config_name = name_only
    motion_model_path = list(prompt_config_dict[name_only]["motion_module"])[0]
    base_path = str(prompt_config_dict[name_only]["base"])
    db_path = str(prompt_config_dict[name_only]["path"])
    steps = int(prompt_config_dict[name_only]["steps"])
    guidance_scale = float(prompt_config_dict[name_only]["guidance_scale"])
    lora_alpha = float(prompt_config_dict[name_only]["lora_alpha"]) if "lora_alpha" in prompt_config_dict[name_only] else 1.0
    seed_list = list(prompt_config_dict[name_only]["seed"])
    prompt_list = list(prompt_config_dict[name_only]["prompt"])
    n_prompt_list = list(prompt_config_dict[name_only]["n_prompt"])
    # Up to four prompt slots; missing entries fall back to defaults.
    seed1 = str(seed_list[0]) if len(seed_list) > 0 else "-1"
    prompt1 = str(prompt_list[0]) if len(prompt_list) > 0 else ""
    n_prompt1 = str(n_prompt_list[0]) if len(n_prompt_list) > 0 else ""
    seed2 = str(seed_list[1]) if len(seed_list) > 1 else "-1"
    prompt2 = str(prompt_list[1]) if len(prompt_list) > 1 else ""
    n_prompt2 = str(n_prompt_list[1]) if len(n_prompt_list) > 1 else ""
    seed3 = str(seed_list[2]) if len(seed_list) > 2 else "-1"
    prompt3 = str(prompt_list[2]) if len(prompt_list) > 2 else ""
    n_prompt3 = str(n_prompt_list[2]) if len(n_prompt_list) > 2 else ""
    seed4 = str(seed_list[3]) if len(seed_list) > 3 else "-1"
    prompt4 = str(prompt_list[3]) if len(prompt_list) > 3 else ""
    n_prompt4 = str(n_prompt_list[3]) if len(n_prompt_list) > 3 else ""
    help.verbose_print("Done Loading Prompt Config!")
    motion_model_dropdown = gr.update(value=motion_model_path.split(sep)[-1])
    sd_model_dropdown = gr.update(value=base_path.split(sep)[-1])
    db_model_dropdown = gr.update(value=db_path.split(sep)[-1])
    pretrained_model_path = gr.update(value=base_path)
    return config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha, \
        seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4, \
        motion_model_dropdown, sd_model_dropdown, db_model_dropdown, pretrained_model_path


def save_db_config(filename):
    # Write the current in-memory config to a new YAML file.
    global prompt_config_dict
    prompt_configs_path = os.path.join(os.getcwd(), "configs", "prompts")
    help.dict_to_yaml(copy.deepcopy(prompt_config_dict), os.path.join(prompt_configs_path, f"{filename}.yaml"))
    help.verbose_print("Done Creating NEW Prompt Config!")


def save_prompt_dict(config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
                     seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4):
    # Rebuild the config entry from the UI fields and persist it to disk.
    global prompt_config_dict
    prompt_config_dict[config_name] = {}
    prompt_config_dict[config_name]["base"] = base_path
    prompt_config_dict[config_name]["path"] = db_path
    prompt_config_dict[config_name]["motion_module"] = [motion_model_path]
    prompt_config_dict[config_name]["seed"] = [0] * 4
    prompt_config_dict[config_name]["steps"] = steps
    prompt_config_dict[config_name]["guidance_scale"] = guidance_scale
    prompt_config_dict[config_name]["lora_alpha"] = lora_alpha
    prompt_config_dict[config_name]["prompt"] = [""] * 4
    prompt_config_dict[config_name]["n_prompt"] = [""] * 4
    prompt_config_dict[config_name]["seed"][0] = int(seed1)
    prompt_config_dict[config_name]["prompt"][0] = prompt1
    prompt_config_dict[config_name]["n_prompt"][0] = n_prompt1
    prompt_config_dict[config_name]["seed"][1] = int(seed2)
    prompt_config_dict[config_name]["prompt"][1] = prompt2
    prompt_config_dict[config_name]["n_prompt"][1] = n_prompt2
    prompt_config_dict[config_name]["seed"][2] = int(seed3)
    prompt_config_dict[config_name]["prompt"][2] = prompt3
    prompt_config_dict[config_name]["n_prompt"][2] = n_prompt3
    prompt_config_dict[config_name]["seed"][3] = int(seed4)
    prompt_config_dict[config_name]["prompt"][3] = prompt4
    prompt_config_dict[config_name]["n_prompt"][3] = n_prompt4
    prompt_configs_path = os.path.join(os.getcwd(), "configs", "prompts")
    help.dict_to_yaml(copy.deepcopy(prompt_config_dict), os.path.join(prompt_configs_path, f"{config_name}.yaml"))
    help.verbose_print("Done Updating Prompt Config!")


def animate(pretrained_model_path, frame_count, width, height, inference_yaml_select, prompt_yaml_select):
    # Assemble the scripts.animate command line and stream its output.
    prompt_configs_path = os.path.join(os.getcwd(), "configs", "prompts")
    inference_configs_path = os.path.join(os.getcwd(), "configs", "inference")
    command_str = f"python -m scripts.animate --config {os.path.join(prompt_configs_path, prompt_yaml_select)}"
    if pretrained_model_path is not None and len(pretrained_model_path) > 0:
        command_str += f" --pretrained_model_path {pretrained_model_path}"
    command_str += f" --L {frame_count}"
    command_str += f" --W {width}"
    command_str += f" --H {height}"
    if inference_yaml_select is not None and len(inference_yaml_select) > 0:
        command_str += f" --inference_config {os.path.join(inference_configs_path, inference_yaml_select)}"
    help.verbose_print(f"Running Command:\t{command_str}")
    for line in help.execute(command_str.split(" ")):
        help.verbose_print(line)
    help.verbose_print("Done Generating!")


def build_ui():
    with gr.Blocks() as demo:
        with gr.Tab("Model Selection & Setup"):
            with gr.Row():
                motion_model_dropdown = gr.Dropdown(interactive=True, label="Select Motion Model", info="Downloads model if not present", choices=all_motion_model_opts)
                sd_model_dropdown = gr.Dropdown(interactive=True, label="Select Stable Diffusion Model", info="At user's discretion to download", choices=get_available_sd_models())
                db_model_dropdown = gr.Dropdown(interactive=True, label="Select LoRA/Dreambooth Model", info="At user's discretion to download", choices=get_available_db_models())
            with gr.Row():
                pretrained_model_path = gr.Textbox(info="Pretrained Model Path", interactive=True, show_label=False)
            with gr.Row():
                frame_count = gr.Slider(info="Total Frames", minimum=0, maximum=1000, step=1, value=16, show_label=False)
                width = gr.Slider(info="Width", minimum=0, maximum=4096, step=1, value=512, show_label=False)
                height = gr.Slider(info="Height", minimum=0, maximum=4096, step=1, value=512, show_label=False)
            with gr.Row():
                inference_yaml_select = gr.Dropdown(info='YAML Select', interactive=True, choices=get_sd_config(), show_label=False)
            animate_button = gr.Button(value="Generate", variant='primary')
        with gr.Tab("LoRA/Dreambooth Prompt Config"):
            with gr.Row():
                config_save = gr.Button(value="Apply & Save Settings", variant='primary')
                create_prompt_yaml = gr.Button(value="Create Prompt Config", variant='secondary')
            with gr.Row():
                prompt_yaml_select = gr.Dropdown(info='YAML Select', interactive=True, choices=get_db_config(), show_label=False)
                config_name = gr.Textbox(info="Config Name", interactive=True, show_label=False)
                motion_model_path = gr.Textbox(info="Motion Model Path", interactive=True, show_label=False)
            with gr.Row():
                base_path = gr.Textbox(info="Base Model Path", interactive=True, show_label=False)
                db_path = gr.Textbox(info="LoRA/Dreambooth Path", interactive=True, show_label=False)
            with gr.Row():
                steps = gr.Slider(info="Steps", minimum=0, maximum=1000, step=1, value=25, show_label=False)
            with gr.Row():
                guidance_scale = gr.Slider(info="Guidance Scale", minimum=0.0, maximum=100.0, step=0.05, value=6.5, show_label=False)
                lora_alpha = gr.Slider(info="LoRA Alpha", minimum=0.0, maximum=1.0, step=0.025, value=1.0, show_label=False)
            with gr.Accordion("Prompt 1", visible=True, open=False):
                with gr.Column():
                    seed1 = gr.Textbox(info="Seed", interactive=True, show_label=False)
                    prompt1 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
                    n_prompt1 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
            with gr.Accordion("Prompt 2", visible=True, open=False):
                with gr.Column():
                    seed2 = gr.Textbox(info="Seed", interactive=True, show_label=False)
                    prompt2 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
                    n_prompt2 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
            with gr.Accordion("Prompt 3", visible=True, open=False):
                with gr.Column():
                    seed3 = gr.Textbox(info="Seed", interactive=True, show_label=False)
                    prompt3 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
                    n_prompt3 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)
            with gr.Accordion("Prompt 4", visible=True, open=False):
                with gr.Column():
                    seed4 = gr.Textbox(info="Seed", interactive=True, show_label=False)
                    prompt4 = gr.Textbox(info="Prompt", interactive=True, show_label=False)
                    n_prompt4 = gr.Textbox(info="Negative Prompt", interactive=True, show_label=False)

        # Event wiring.
        motion_model_dropdown.select(fn=set_motion_model, inputs=[], outputs=[motion_model_path])
        sd_model_dropdown.select(fn=set_sd_model, inputs=[], outputs=[base_path, pretrained_model_path])
        db_model_dropdown.select(fn=set_db_model, inputs=[], outputs=[db_path])
        prompt_yaml_select.select(fn=load_db_config, inputs=[],
                                  outputs=[config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
                                           seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3,
                                           seed4, prompt4, n_prompt4, motion_model_dropdown, sd_model_dropdown,
                                           db_model_dropdown, pretrained_model_path]).then(
            fn=update_db_config, inputs=[], outputs=[prompt_yaml_select])
        create_prompt_yaml.click(fn=save_db_config, inputs=[config_name], outputs=[])
        config_save.click(fn=save_prompt_dict,
                          inputs=[config_name, motion_model_path, base_path, db_path, steps, guidance_scale, lora_alpha,
                                  seed1, prompt1, n_prompt1, seed2, prompt2, n_prompt2, seed3, prompt3, n_prompt3, seed4, prompt4, n_prompt4],
                          outputs=[])
        animate_button.click(fn=animate, inputs=[pretrained_model_path, frame_count, width, height, inference_yaml_select, prompt_yaml_select], outputs=[])
    return demo


def UI(**kwargs):
    # Show the interface; uses the module-level `demo` built in __main__.
    launch_kwargs = {}
    if not kwargs.get('username', None) == '':
        launch_kwargs['auth'] = (
            kwargs.get('username', None),
            kwargs.get('password', None),
        )
    if kwargs.get('server_port', 0) > 0:
        launch_kwargs['server_port'] = kwargs.get('server_port', 0)
    if kwargs.get('share', True):
        launch_kwargs['share'] = True
    print(launch_kwargs)
    demo.queue().launch(**launch_kwargs)


if __name__ == "__main__":
    # init client & server connection
    HOST = "127.0.0.1"  # currently unused
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--username', type=str, default='', help='Username for authentication'
    )
    parser.add_argument(
        '--password', type=str, default='', help='Password for authentication'
    )
    parser.add_argument(
        '--server_port',
        type=int,
        default=0,
        help='Port to run the server listener on',
    )
    parser.add_argument(
        '--share',
        action='store_true',
        help='Share live gradio link',
    )
    args = parser.parse_args()

    demo = build_ui()
    prompt_config_dict = {}  # shared state read and written by the UI callbacks
    help.verbose_print(f"Motion models available to use:\t{get_available_motion_models()}")
    help.verbose_print(f"Stable Diffusion models available to use:\t{get_available_sd_models()}")
    help.verbose_print(f"LoRA/Dreambooth models available to use:\t{get_available_db_models()}")
    UI(
        username=args.username,
        password=args.password,
        server_port=args.server_port,
        share=args.share,
    )