diff --git a/README.md b/README.md index aaa6525..4224071 100644 --- a/README.md +++ b/README.md @@ -1,121 +1,10 @@ -# AnimateDiff +# A beta-version of motion module for SDXL -This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725). - -**[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)** -
-Yuwei Guo, -Ceyuan Yang*, -Anyi Rao, -Yaohui Wang, -Yu Qiao, -Dahua Lin, -Bo Dai -

*Corresponding Author

- - -[![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://arxiv.org/abs/2307.04725) -[![Project Page](https://img.shields.io/badge/Project-Website-green)](https://animatediff.github.io/) -[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff) -[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/guoyww/AnimateDiff) - -## Next -One with better controllability and quality is coming soon. Stay tuned. - -## Features -- **[2023/11/10]** Release the Motion Module (beta version) on SDXL, available at [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link -) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt -) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules). High resolution videos (i.e., 1024x1024x16 frames with various aspect ratios) could be produced **with/without** personalized models. Inference usually requires ~13GB VRAM and tuned hyperparameters (e.g., #sampling steps), depending on the chosen personalized models. Checkout to the branch `sdxl` for more details of the inference. More checkpoints with better-quality would be available soon. Stay tuned. Examples below are manually downsampled for fast loading. - - - - - - - - - - - - -
Original SDXLPersonalized SDXLPersonalized SDXL
- - - -- **[2023/09/25]** Release **MotionLoRA** and its model zoo, **enabling camera movement controls**! Please download the MotionLoRA models (**74 MB per model**, available at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules) ) and save them to the `models/MotionLoRA` folder. Example: - ``` - python -m scripts.animate --config configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml - ``` - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
Zoom InZoom OutZoom Pan LeftZoom Pan Right
Tilt UpTilt DownRolling Anti-ClockwiseRolling Clockwise
- -- **[2023/09/10]** New Motion Module release! `mm_sd_v15_v2.ckpt` was trained on larger resolution & batch size, and gains noticeable quality improvements. Check it out at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules) and use it with `configs/inference/inference-v2.yaml`. Example: - ``` - python -m scripts.animate --config configs/prompts/v2/5-RealisticVision.yaml - ``` - Here is a qualitative comparison between `mm_sd_v15.ckpt` (left) and `mm_sd_v15_v2.ckpt` (right): - - - - - - - - - - - -
-- GPU Memory Optimization, ~12GB VRAM to inference - - -## Quick Demo - -User Interface developed by community: - - A1111 Extension [sd-webui-animatediff](https://github.com/continue-revolution/sd-webui-animatediff) (by [@continue-revolution](https://github.com/continue-revolution)) - - ComfyUI Extension [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) (by [@Kosinkadink](https://github.com/Kosinkadink)) - - Google Colab: [Colab](https://colab.research.google.com/github/camenduru/AnimateDiff-colab/blob/main/AnimateDiff_colab.ipynb) (by [@camenduru](https://github.com/camenduru)) - -We also create a Gradio demo to make AnimateDiff easier to use. To launch the demo, please run the following commands: -``` -conda activate animatediff -python app.py -``` -By default, the demo will run at `localhost:7860`. -
+Now you can generate high-resolution videos on SDXL **with/without** personalized models. Checkpoint with better quality would be available soon. Stay tuned. +## Somethings Important +- Generate videos with high-resolution (we provide recommended ones) as SDXL usually leads to worse quality for low-resolution images. +- Follow and slightly adjust the hyperparameters (e.g., #sampling steps, #guidance scale) of various personalized SDXL since these models are carefully tuned to various extent. ## Model Zoo
@@ -123,86 +12,140 @@ By default, the demo will run at `localhost:7860`. | Name | Parameter | Storage Space | |----------------------|-----------|---------------| - | mm_sd_v14.ckpt | 417 M | 1.6 GB | - | mm_sd_v15.ckpt | 417 M | 1.6 GB | - | mm_sd_v15_v2.ckpt | 453 M | 1.7 GB | + | mm_sdxl_v10_beta.ckpt | 238 M | 0.9 GB |
-MotionLoRAs +Recommended Resolution - | Name | Parameter | Storage Space | - |--------------------------------------|-----------|---------------| - | v2_lora_ZoomIn.ckpt | 19 M | 74 MB | - | v2_lora_ZoomOut.ckpt | 19 M | 74 MB | - | v2_lora_PanLeft.ckpt | 19 M | 74 MB | - | v2_lora_PanRight.ckpt | 19 M | 74 MB | - | v2_lora_TiltUp.ckpt | 19 M | 74 MB | - | v2_lora_TiltDown.ckpt | 19 M | 74 MB | - | v2_lora_RollingClockwise.ckpt | 19 M | 74 MB | - | v2_lora_RollingAnticlockwise.ckpt | 19 M | 74 MB | + | Resolution | Aspect Ratio | + |----------------------|-----------| + | 768x1344 | 9:16 | + | 832x1216 | 2:3 | + | 1024x1024 | 1:1 | + | 1216x832 | 3:2 | + | 1344x768 | 16:9 |
-## Common Issues -
-Installation - -Please ensure the installation of [xformer](https://github.com/facebookresearch/xformers) that is applied to reduce the inference memory. -
- - -
-Various resolution or number of frames -Currently, we recommend users to generate animation with 16 frames and 512 resolution that are aligned with our training settings. Notably, various resolution/frames may affect the quality more or less. -
- - -
-How to use it without any coding - -1) Get lora models: train lora model with [A1111](https://github.com/continue-revolution/sd-webui-animatediff) based on a collection of your own favorite images (e.g., tutorials [English](https://www.youtube.com/watch?v=mfaqqL5yOO4), [Japanese](https://www.youtube.com/watch?v=N1tXVR9lplM), [Chinese](https://www.bilibili.com/video/BV1fs4y1x7p2/)) -or download Lora models from [Civitai](https://civitai.com/). - -2) Animate lora models: using gradio interface or A1111 -(e.g., tutorials [English](https://github.com/continue-revolution/sd-webui-animatediff), [Japanese](https://www.youtube.com/watch?v=zss3xbtvOWw), [Chinese](https://941ai.com/sd-animatediff-webui-1203.html)) - -3) Be creative togther with other techniques, such as, super resolution, frame interpolation, music generation, etc. -
- - -
-Animating a given image - -We totally agree that animating a given image is an appealing feature, which we would try to support officially in future. For now, you may enjoy other efforts from the [talesofai](https://github.com/talesofai/AnimateDiff). -
- -
-Contributions from community -Contributions are always welcome!! The dev branch is for community contributions. As for the main branch, we would like to align it with the original technical report :) -
- -## Training and inference -Please refer to [ANIMATEDIFF](./__assets__/docs/animatediff.md) for the detailed setup. - ## Gallery -We collect several generated results in [GALLERY](./__assets__/docs/gallery.md). +We demonstrate some results with our model. The GIFs below are **manually downsampled** after generation for fast loading. + +**Original SDXL** + + + + + +
+ +**LoRA** + + + + + + +
+

Model:DynaVision

+ + + + + + +
+

Model:DreamShaper

+ + + + + + +
+

Model:DeepBlue

+ + +## Inference Example + +Inference at recommended resolution of 16 frames usually requires ~13GB VRAM. +### Step-1: Prepare Environment -## BibTeX ``` -@article{guo2023animatediff, - title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning}, - author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Wang, Yaohui and Qiao, Yu and Lin, Dahua and Dai, Bo}, - journal={arXiv preprint arXiv:2307.04725}, - year={2023} -} +git clone https://github.com/guoyww/AnimateDiff.git +cd AnimateDiff +git checkout sdxl + + +conda env create -f environment.yaml +conda activate animatediff_xl ``` -## Contact Us -**Yuwei Guo**: [guoyuwei@pjlab.org.cn](mailto:guoyuwei@pjlab.org.cn) -**Ceyuan Yang**: [yangceyuan@pjlab.org.cn](mailto:yangceyuan@pjlab.org.cn) -**Bo Dai**: [daibo@pjlab.org.cn](mailto:daibo@pjlab.org.cn) +### Step-2: Download Base T2I & Motion Module Checkpoints +We provide a beta version of motion module on SDXL. You can download the base model of SDXL 1.0 and Motion Module following instructions below. +``` +git lfs install +git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 models/StableDiffusion/ -## Acknowledgements -Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video). +bash download_bashscripts/0-MotionModule.sh +``` +You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link +) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt +) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules), then put them in `models/Motion_Module/` folder. + +### Step-3: Download Personalized SDXL (you can skip this if generating videos on the original SDXL) +You may run the following bash scripts to download the LoRA checkpoint from CivitAI. +``` +bash download_bashscripts/1-DynaVision.sh +bash download_bashscripts/2-DreamShaper.sh +bash download_bashscripts/3-DeepBlue.sh +``` + +### Step-4: Generate Videos +Run the following commands to generate videos of **original SDXL**. +``` +python -m scripts.animate --exp_config configs/prompts/1-original_sdxl.yaml --H 1024 --W 1024 --L 16 --xformers +``` +Run the following commands to generate videos of **personalized SDXL**. DO NOT skip Step-3. +``` +python -m scripts.animate --config configs/prompts/2-DynaVision.yaml --H 1024 --W 1024 --L 16 --xformers +python -m scripts.animate --config configs/prompts/3-DreamShaper.yaml --H 1024 --W 1024 --L 16 --xformers +python -m scripts.animate --config configs/prompts/4-DeepBlue.yaml --H 1024 --W 1024 --L 16 --xformers +``` +The results will automatically be saved to `samples/` folder. + + +## Customized Inference +To generate videos with a new Checkpoint/LoRA model, you may create a new config `.yaml` file in the following format: +``` + +motion_module_path: "models/Motion_Module/mm_sdxl_v10_beta.ckpt" # Specify the Motion Module + +# We support 3 types of T2I models. + # 1. Checkpoint: a safetensors model contains UNet, Text_Encoders, VAE. + # 2. LoRA: a safetensors model contains only the LoRA modules. + # 3. You can convert the Checkpoint into a folder with the same structure as SDXL_1.0 base model. + + +ckpt_path: "YOUR_CKPT_PATH" # path to the checkpoint type model from CivitAI. +lora_path: "YOUR_LORA_PATH" # path to the LORA type model from CivitAI. +base_model_path: "YOUR_BASE_MODEL_PATH" # path to the folder converted from a checkpoint + + +steps: 50 +guidance_scale: 8.5 + +seed: -1 # You can specify seed for each prompt. + +prompt: + - "[positive prompt]" + +n_prompt: + - "[negative prompt]" +``` + +Then run the following commands. +``` +python -m scripts.animate --exp_config [path to the personalized config] --L [video frames] --H [Height of the videos] --W [Width of the videos] --xformers +``` \ No newline at end of file diff --git a/__assets__/animations/compare/new_0.gif b/__assets__/animations/compare/new_0.gif deleted file mode 100644 index 8681fa5..0000000 Binary files a/__assets__/animations/compare/new_0.gif and /dev/null differ diff --git a/__assets__/animations/compare/new_1.gif b/__assets__/animations/compare/new_1.gif deleted file mode 100644 index dd0b296..0000000 Binary files a/__assets__/animations/compare/new_1.gif and /dev/null differ diff --git a/__assets__/animations/compare/new_2.gif b/__assets__/animations/compare/new_2.gif deleted file mode 100644 index 7baeb7b..0000000 Binary files a/__assets__/animations/compare/new_2.gif and /dev/null differ diff --git a/__assets__/animations/compare/new_3.gif b/__assets__/animations/compare/new_3.gif deleted file mode 100644 index 07dc320..0000000 Binary files a/__assets__/animations/compare/new_3.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_0.gif b/__assets__/animations/compare/old_0.gif deleted file mode 100644 index 70709b8..0000000 Binary files a/__assets__/animations/compare/old_0.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_1.gif b/__assets__/animations/compare/old_1.gif deleted file mode 100644 index 5c605be..0000000 Binary files a/__assets__/animations/compare/old_1.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_2.gif b/__assets__/animations/compare/old_2.gif deleted file mode 100644 index 2e20b7b..0000000 Binary files a/__assets__/animations/compare/old_2.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_3.gif b/__assets__/animations/compare/old_3.gif deleted file mode 100644 index b035a95..0000000 Binary files a/__assets__/animations/compare/old_3.gif and /dev/null differ diff --git a/__assets__/animations/model_01/01.gif b/__assets__/animations/model_01/01.gif index 0d869f0..f19dfec 100644 Binary files a/__assets__/animations/model_01/01.gif and b/__assets__/animations/model_01/01.gif differ diff --git a/__assets__/animations/model_01/02.gif b/__assets__/animations/model_01/02.gif index 3924190..480cbc8 100644 Binary files a/__assets__/animations/model_01/02.gif and b/__assets__/animations/model_01/02.gif differ diff --git a/__assets__/animations/model_01/03.gif b/__assets__/animations/model_01/03.gif index 1b1b970..ac3677c 100644 Binary files a/__assets__/animations/model_01/03.gif and b/__assets__/animations/model_01/03.gif differ diff --git a/__assets__/animations/model_01/04.gif b/__assets__/animations/model_01/04.gif deleted file mode 100644 index 593a500..0000000 Binary files a/__assets__/animations/model_01/04.gif and /dev/null differ diff --git a/__assets__/animations/model_02/01.gif b/__assets__/animations/model_02/01.gif index 8d905e7..bfb6ada 100644 Binary files a/__assets__/animations/model_02/01.gif and b/__assets__/animations/model_02/01.gif differ diff --git a/__assets__/animations/model_02/02.gif b/__assets__/animations/model_02/02.gif index 734bee4..5d35b48 100644 Binary files a/__assets__/animations/model_02/02.gif and b/__assets__/animations/model_02/02.gif differ diff --git a/__assets__/animations/model_02/03.gif b/__assets__/animations/model_02/03.gif deleted file mode 100644 index 228bd2a..0000000 Binary files a/__assets__/animations/model_02/03.gif and /dev/null differ diff --git a/__assets__/animations/model_02/04.gif b/__assets__/animations/model_02/04.gif deleted file mode 100644 index fb9f69e..0000000 Binary files a/__assets__/animations/model_02/04.gif and /dev/null differ diff --git a/__assets__/animations/model_03/01.gif b/__assets__/animations/model_03/01.gif index 32f0379..4e54f5f 100644 Binary files a/__assets__/animations/model_03/01.gif and b/__assets__/animations/model_03/01.gif differ diff --git a/__assets__/animations/model_03/02.gif b/__assets__/animations/model_03/02.gif index 42de8ef..44b8827 100644 Binary files a/__assets__/animations/model_03/02.gif and b/__assets__/animations/model_03/02.gif differ diff --git a/__assets__/animations/model_03/03.gif b/__assets__/animations/model_03/03.gif index 28a0446..1f9e2fb 100644 Binary files a/__assets__/animations/model_03/03.gif and b/__assets__/animations/model_03/03.gif differ diff --git a/__assets__/animations/model_03/04.gif b/__assets__/animations/model_03/04.gif deleted file mode 100644 index a920f12..0000000 Binary files a/__assets__/animations/model_03/04.gif and /dev/null differ diff --git a/__assets__/animations/model_04/01.gif b/__assets__/animations/model_04/01.gif deleted file mode 100644 index 492f0cc..0000000 Binary files a/__assets__/animations/model_04/01.gif and /dev/null differ diff --git a/__assets__/animations/model_04/02.gif b/__assets__/animations/model_04/02.gif deleted file mode 100644 index 97bad3c..0000000 Binary files a/__assets__/animations/model_04/02.gif and /dev/null differ diff --git a/__assets__/animations/model_04/03.gif b/__assets__/animations/model_04/03.gif deleted file mode 100644 index 11f1855..0000000 Binary files a/__assets__/animations/model_04/03.gif and /dev/null differ diff --git a/__assets__/animations/model_04/04.gif b/__assets__/animations/model_04/04.gif deleted file mode 100644 index 0e81050..0000000 Binary files a/__assets__/animations/model_04/04.gif and /dev/null differ diff --git a/__assets__/animations/model_05/01.gif b/__assets__/animations/model_05/01.gif deleted file mode 100644 index 782f03a..0000000 Binary files a/__assets__/animations/model_05/01.gif and /dev/null differ diff --git a/__assets__/animations/model_05/02.gif b/__assets__/animations/model_05/02.gif deleted file mode 100644 index 2197447..0000000 Binary files a/__assets__/animations/model_05/02.gif and /dev/null differ diff --git a/__assets__/animations/model_05/03.gif b/__assets__/animations/model_05/03.gif deleted file mode 100644 index 44d922d..0000000 Binary files a/__assets__/animations/model_05/03.gif and /dev/null differ diff --git a/__assets__/animations/model_05/04.gif b/__assets__/animations/model_05/04.gif deleted file mode 100644 index 43e859f..0000000 Binary files a/__assets__/animations/model_05/04.gif and /dev/null differ diff --git a/__assets__/animations/model_06/01.gif b/__assets__/animations/model_06/01.gif deleted file mode 100644 index 123d2a6..0000000 Binary files a/__assets__/animations/model_06/01.gif and /dev/null differ diff --git a/__assets__/animations/model_06/02.gif b/__assets__/animations/model_06/02.gif deleted file mode 100644 index fd64b6e..0000000 Binary files a/__assets__/animations/model_06/02.gif and /dev/null differ diff --git a/__assets__/animations/model_06/03.gif b/__assets__/animations/model_06/03.gif deleted file mode 100644 index 53cbb28..0000000 Binary files a/__assets__/animations/model_06/03.gif and /dev/null differ diff --git a/__assets__/animations/model_06/04.gif b/__assets__/animations/model_06/04.gif deleted file mode 100644 index ddd0e01..0000000 Binary files a/__assets__/animations/model_06/04.gif and /dev/null differ diff --git a/__assets__/animations/model_07/01.gif b/__assets__/animations/model_07/01.gif deleted file mode 100644 index be0eaf4..0000000 Binary files a/__assets__/animations/model_07/01.gif and /dev/null differ diff --git a/__assets__/animations/model_07/02.gif b/__assets__/animations/model_07/02.gif deleted file mode 100644 index bbcad05..0000000 Binary files a/__assets__/animations/model_07/02.gif and /dev/null differ diff --git a/__assets__/animations/model_07/03.gif b/__assets__/animations/model_07/03.gif deleted file mode 100644 index 447b70b..0000000 Binary files a/__assets__/animations/model_07/03.gif and /dev/null differ diff --git a/__assets__/animations/model_07/04.gif b/__assets__/animations/model_07/04.gif deleted file mode 100644 index 2c175ff..0000000 Binary files a/__assets__/animations/model_07/04.gif and /dev/null differ diff --git a/__assets__/animations/model_07/init.jpg b/__assets__/animations/model_07/init.jpg deleted file mode 100644 index 3b83812..0000000 Binary files a/__assets__/animations/model_07/init.jpg and /dev/null differ diff --git a/__assets__/animations/model_08/01.gif b/__assets__/animations/model_08/01.gif deleted file mode 100644 index e183a8b..0000000 Binary files a/__assets__/animations/model_08/01.gif and /dev/null differ diff --git a/__assets__/animations/model_08/02.gif b/__assets__/animations/model_08/02.gif deleted file mode 100644 index 326ef26..0000000 Binary files a/__assets__/animations/model_08/02.gif and /dev/null differ diff --git a/__assets__/animations/model_08/03.gif b/__assets__/animations/model_08/03.gif deleted file mode 100644 index cb0c4a8..0000000 Binary files a/__assets__/animations/model_08/03.gif and /dev/null differ diff --git a/__assets__/animations/model_08/04.gif b/__assets__/animations/model_08/04.gif deleted file mode 100644 index a062247..0000000 Binary files a/__assets__/animations/model_08/04.gif and /dev/null differ diff --git a/__assets__/animations/motion_xl/01.gif b/__assets__/animations/model_original/01.gif similarity index 100% rename from __assets__/animations/motion_xl/01.gif rename to __assets__/animations/model_original/01.gif diff --git a/__assets__/animations/model_original/02.gif b/__assets__/animations/model_original/02.gif new file mode 100644 index 0000000..57bf14e Binary files /dev/null and b/__assets__/animations/model_original/02.gif differ diff --git a/__assets__/animations/motion_lora/model_01/01.gif b/__assets__/animations/motion_lora/model_01/01.gif deleted file mode 100644 index a4e85ea..0000000 Binary files a/__assets__/animations/motion_lora/model_01/01.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/02.gif b/__assets__/animations/motion_lora/model_01/02.gif deleted file mode 100644 index 5b33305..0000000 Binary files a/__assets__/animations/motion_lora/model_01/02.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/03.gif b/__assets__/animations/motion_lora/model_01/03.gif deleted file mode 100644 index 16a5ae2..0000000 Binary files a/__assets__/animations/motion_lora/model_01/03.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/04.gif b/__assets__/animations/motion_lora/model_01/04.gif deleted file mode 100644 index 73b1d6a..0000000 Binary files a/__assets__/animations/motion_lora/model_01/04.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/05.gif b/__assets__/animations/motion_lora/model_01/05.gif deleted file mode 100644 index 9fb3661..0000000 Binary files a/__assets__/animations/motion_lora/model_01/05.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/06.gif b/__assets__/animations/motion_lora/model_01/06.gif deleted file mode 100644 index aa18797..0000000 Binary files a/__assets__/animations/motion_lora/model_01/06.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/07.gif b/__assets__/animations/motion_lora/model_01/07.gif deleted file mode 100644 index 8308862..0000000 Binary files a/__assets__/animations/motion_lora/model_01/07.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/08.gif b/__assets__/animations/motion_lora/model_01/08.gif deleted file mode 100644 index 75ba8fa..0000000 Binary files a/__assets__/animations/motion_lora/model_01/08.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/01.gif b/__assets__/animations/motion_lora/model_02/01.gif deleted file mode 100644 index 36db48e..0000000 Binary files a/__assets__/animations/motion_lora/model_02/01.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/02.gif b/__assets__/animations/motion_lora/model_02/02.gif deleted file mode 100644 index ead0fd4..0000000 Binary files a/__assets__/animations/motion_lora/model_02/02.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/03.gif b/__assets__/animations/motion_lora/model_02/03.gif deleted file mode 100644 index 1b5136b..0000000 Binary files a/__assets__/animations/motion_lora/model_02/03.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/04.gif b/__assets__/animations/motion_lora/model_02/04.gif deleted file mode 100644 index b409fd9..0000000 Binary files a/__assets__/animations/motion_lora/model_02/04.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/05.gif b/__assets__/animations/motion_lora/model_02/05.gif deleted file mode 100644 index 5216871..0000000 Binary files a/__assets__/animations/motion_lora/model_02/05.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/06.gif b/__assets__/animations/motion_lora/model_02/06.gif deleted file mode 100644 index 25c7ee7..0000000 Binary files a/__assets__/animations/motion_lora/model_02/06.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/07.gif b/__assets__/animations/motion_lora/model_02/07.gif deleted file mode 100644 index 3b239aa..0000000 Binary files a/__assets__/animations/motion_lora/model_02/07.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/08.gif b/__assets__/animations/motion_lora/model_02/08.gif deleted file mode 100644 index d1348e0..0000000 Binary files a/__assets__/animations/motion_lora/model_02/08.gif and /dev/null differ diff --git a/__assets__/animations/motion_xl/02.gif b/__assets__/animations/motion_xl/02.gif deleted file mode 100644 index f19dfec..0000000 Binary files a/__assets__/animations/motion_xl/02.gif and /dev/null differ diff --git a/__assets__/animations/motion_xl/03.gif b/__assets__/animations/motion_xl/03.gif deleted file mode 100644 index 4e54f5f..0000000 Binary files a/__assets__/animations/motion_xl/03.gif and /dev/null differ diff --git a/__assets__/docs/animatediff.md b/__assets__/docs/animatediff.md deleted file mode 100644 index 6e1f26b..0000000 --- a/__assets__/docs/animatediff.md +++ /dev/null @@ -1,112 +0,0 @@ -# AnimateDiff: training and inference setup -## Setups for Inference - -### Prepare Environment - -***We updated our inference code with xformers and a sequential decoding trick. Now AnimateDiff takes only ~12GB VRAM to inference, and run on a single RTX3090 !!*** - -``` -git clone https://github.com/guoyww/AnimateDiff.git -cd AnimateDiff - -conda env create -f environment.yaml -conda activate animatediff -``` - -### Download Base T2I & Motion Module Checkpoints -We provide two versions of our Motion Module, which are trained on stable-diffusion-v1-4 and finetuned on v1-5 seperately. -It's recommanded to try both of them for best results. -``` -git lfs install -git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/ - -bash download_bashscripts/0-MotionModule.sh -``` -You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules), then put them in `models/Motion_Module/` folder. - -### Prepare Personalize T2I -Here we provide inference configs for 6 demo T2I on CivitAI. -You may run the following bash scripts to download these checkpoints. -``` -bash download_bashscripts/1-ToonYou.sh -bash download_bashscripts/2-Lyriel.sh -bash download_bashscripts/3-RcnzCartoon.sh -bash download_bashscripts/4-MajicMix.sh -bash download_bashscripts/5-RealisticVision.sh -bash download_bashscripts/6-Tusun.sh -bash download_bashscripts/7-FilmVelvia.sh -bash download_bashscripts/8-GhibliBackground.sh -``` - -### Inference -After downloading the above peronalized T2I checkpoints, run the following commands to generate animations. The results will automatically be saved to `samples/` folder. -``` -python -m scripts.animate --config configs/prompts/1-ToonYou.yaml -python -m scripts.animate --config configs/prompts/2-Lyriel.yaml -python -m scripts.animate --config configs/prompts/3-RcnzCartoon.yaml -python -m scripts.animate --config configs/prompts/4-MajicMix.yaml -python -m scripts.animate --config configs/prompts/5-RealisticVision.yaml -python -m scripts.animate --config configs/prompts/6-Tusun.yaml -python -m scripts.animate --config configs/prompts/7-FilmVelvia.yaml -python -m scripts.animate --config configs/prompts/8-GhibliBackground.yaml -``` - -To generate animations with a new DreamBooth/LoRA model, you may create a new config `.yaml` file in the following format: -``` -NewModel: - inference_config: "[path to motion module config file]" - - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - motion_module_lora_configs: - - path: "[path to MotionLoRA model]" - alpha: 1.0 - - ... - - dreambooth_path: "[path to your DreamBooth model .safetensors file]" - lora_model_path: "[path to your LoRA model .safetensors file, leave it empty string if not needed]" - - steps: 25 - guidance_scale: 7.5 - - prompt: - - "[positive prompt]" - - n_prompt: - - "[negative prompt]" -``` -Then run the following commands: -``` -python -m scripts.animate --config [path to the config file] -``` - - -## Steps for Training - -### Dataset -Before training, download the videos files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local mechine. -Note that our examplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`. - -### Configuration -After dataset preparations, update the below data paths in the config `.yaml` files in `configs/training/` folder: -``` -train_data: - csv_path: [Replace with .csv Annotation File Path] - video_folder: [Replace with Video Folder Path] - sample_size: 256 -``` -Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files. - -### Training -To train motion modules -``` -torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml -``` - -To finetune the unet's image layers -``` -torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml -``` - diff --git a/__assets__/docs/gallery.md b/__assets__/docs/gallery.md deleted file mode 100644 index 8891dd2..0000000 --- a/__assets__/docs/gallery.md +++ /dev/null @@ -1,93 +0,0 @@ -# Gallery -Here we demonstrate several best results we found in our experiments. - - - - - - - - -
-

Model:ToonYou

- - - - - - - - -
-

Model:Counterfeit V3.0

- - - - - - - - -
-

Model:Realistic Vision V2.0

- - - - - - - - -
-

Model: majicMIX Realistic

- - - - - - - - -
-

Model:RCNZ Cartoon

- - - - - - - - -
-

Model:FilmVelvia

- -#### Community Cases -Here are some samples contributed by the community artists. Create a Pull Request if you would like to show your results here😚. - - - - - - - - - -
-

-Character Model:Yoimiya -(with an initial reference image, see WIP fork for the extended implementation.) - - - - - - - - - -
-

-Character Model:Paimon; -Pose Model:Hold Sign

- - diff --git a/__assets__/figs/gradio.jpg b/__assets__/figs/gradio.jpg deleted file mode 100644 index 19aea7c..0000000 Binary files a/__assets__/figs/gradio.jpg and /dev/null differ diff --git a/animatediff/.DS_Store b/animatediff/.DS_Store new file mode 100644 index 0000000..5fd8fa5 Binary files /dev/null and b/animatediff/.DS_Store differ diff --git a/animatediff/data/dataset.py b/animatediff/data/dataset.py deleted file mode 100644 index 3f6ec10..0000000 --- a/animatediff/data/dataset.py +++ /dev/null @@ -1,98 +0,0 @@ -import os, io, csv, math, random -import numpy as np -from einops import rearrange -from decord import VideoReader - -import torch -import torchvision.transforms as transforms -from torch.utils.data.dataset import Dataset -from animatediff.utils.util import zero_rank_print - - - -class WebVid10M(Dataset): - def __init__( - self, - csv_path, video_folder, - sample_size=256, sample_stride=4, sample_n_frames=16, - is_image=False, - ): - zero_rank_print(f"loading annotations from {csv_path} ...") - with open(csv_path, 'r') as csvfile: - self.dataset = list(csv.DictReader(csvfile)) - self.length = len(self.dataset) - zero_rank_print(f"data scale: {self.length}") - - self.video_folder = video_folder - self.sample_stride = sample_stride - self.sample_n_frames = sample_n_frames - self.is_image = is_image - - sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) - self.pixel_transforms = transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.Resize(sample_size[0]), - transforms.CenterCrop(sample_size), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), - ]) - - def get_batch(self, idx): - video_dict = self.dataset[idx] - videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] - - video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") - video_reader = VideoReader(video_dir) - video_length = len(video_reader) - - if not self.is_image: - clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) - start_idx = random.randint(0, video_length - clip_length) - batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) - else: - batch_index = [random.randint(0, video_length - 1)] - - pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() - pixel_values = pixel_values / 255. - del video_reader - - if self.is_image: - pixel_values = pixel_values[0] - - return pixel_values, name - - def __len__(self): - return self.length - - def __getitem__(self, idx): - while True: - try: - pixel_values, name = self.get_batch(idx) - break - - except Exception as e: - idx = random.randint(0, self.length-1) - - pixel_values = self.pixel_transforms(pixel_values) - sample = dict(pixel_values=pixel_values, text=name) - return sample - - - -if __name__ == "__main__": - from animatediff.utils.util import save_videos_grid - - dataset = WebVid10M( - csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv", - video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val", - sample_size=256, - sample_stride=4, sample_n_frames=16, - is_image=True, - ) - import pdb - pdb.set_trace() - - dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,) - for idx, batch in enumerate(dataloader): - print(batch["pixel_values"].shape, len(batch["text"])) - # for i in range(batch["pixel_values"].shape[0]): - # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True) diff --git a/animatediff/models/attention.py b/animatediff/models/attention.py deleted file mode 100644 index ad23583..0000000 --- a/animatediff/models/attention.py +++ /dev/null @@ -1,300 +0,0 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py - -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin -from diffusers.utils import BaseOutput -from diffusers.utils.import_utils import is_xformers_available -from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm - -from einops import rearrange, repeat -import pdb - -@dataclass -class Transformer3DModelOutput(BaseOutput): - sample: torch.FloatTensor - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class Transformer3DModel(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # Define input layers - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - # Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) - else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): - # Input - assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." - video_length = hidden_states.shape[2] - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") - encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) - - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - hidden_states = self.proj_in(hidden_states) - - # Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - video_length=video_length - ) - - # Output - if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) - - output = hidden_states + residual - - output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) - if not return_dict: - return (output,) - - return Transformer3DModelOutput(sample=output) - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - - unet_use_cross_frame_attention = None, - unet_use_temporal_attention = None, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - self.use_ada_layer_norm = num_embeds_ada_norm is not None - self.unet_use_cross_frame_attention = unet_use_cross_frame_attention - self.unet_use_temporal_attention = unet_use_temporal_attention - - # SC-Attn - assert unet_use_cross_frame_attention is not None - if unet_use_cross_frame_attention: - self.attn1 = SparseCausalAttention2D( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - else: - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - # Cross-Attn - if cross_attention_dim is not None: - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - else: - self.attn2 = None - - if cross_attention_dim is not None: - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - else: - self.norm2 = None - - # Feed-forward - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) - self.norm3 = nn.LayerNorm(dim) - - # Temp-Attn - assert unet_use_temporal_attention is not None - if unet_use_temporal_attention: - self.attn_temp = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - nn.init.zeros_(self.attn_temp.to_out[0].weight.data) - self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - if not is_xformers_available(): - print("Here is how to install it") - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" - " available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - if self.attn2 is not None: - self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): - # SparseCausal-Attention - norm_hidden_states = ( - self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) - ) - - # if self.only_cross_attention: - # hidden_states = ( - # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states - # ) - # else: - # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states - - # pdb.set_trace() - if self.unet_use_cross_frame_attention: - hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states - else: - hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states - - if self.attn2 is not None: - # Cross-Attention - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - hidden_states = ( - self.attn2( - norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask - ) - + hidden_states - ) - - # Feed-forward - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - - # Temporal-Attention - if self.unet_use_temporal_attention: - d = hidden_states.shape[1] - hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) - norm_hidden_states = ( - self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) - ) - hidden_states = self.attn_temp(norm_hidden_states) + hidden_states - hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) - - return hidden_states diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py index 2359e71..2638a50 100644 --- a/animatediff/models/motion_module.py +++ b/animatediff/models/motion_module.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import numpy as np @@ -8,324 +8,418 @@ from torch import nn import torchvision from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin +from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available -from diffusers.models.attention import CrossAttention, FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.attention import FeedForward + +from animatediff.utils.util import zero_rank_print from einops import rearrange, repeat -import math +import math, pdb +import random def zero_module(module): - # Zero out the parameters of a module and return it. - for p in module.parameters(): - p.detach().zero_() - return module + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module @dataclass class TemporalTransformer3DModelOutput(BaseOutput): - sample: torch.FloatTensor - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None + sample: torch.FloatTensor def get_motion_module( - in_channels, - motion_module_type: str, - motion_module_kwargs: dict + in_channels, + motion_module_type: str, + motion_module_kwargs: dict ): - if motion_module_type == "Vanilla": - return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) - else: - raise ValueError - + if motion_module_type == "Vanilla": + return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs) + elif motion_module_type == "Conv": + return ConvTemporalModule(in_channels=in_channels, **motion_module_kwargs) + else: + raise ValueError class VanillaTemporalModule(nn.Module): - def __init__( - self, - in_channels, - num_attention_heads = 8, - num_transformer_block = 2, - attention_block_types =( "Temporal_Self", "Temporal_Self" ), - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - temporal_attention_dim_div = 1, - zero_initialize = True, - ): - super().__init__() - - self.temporal_transformer = TemporalTransformer3DModel( - in_channels=in_channels, - num_attention_heads=num_attention_heads, - attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, - num_layers=num_transformer_block, - attention_block_types=attention_block_types, - cross_frame_attention_mode=cross_frame_attention_mode, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - - if zero_initialize: - self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + attention_block_types =( "Temporal_Self", ), + spatial_position_encoding = False, + temporal_position_encoding = True, + temporal_position_encoding_max_len = 32, + temporal_attention_dim_div = 1, + zero_initialize = True, + + causal_temporal_attention = False, + causal_temporal_attention_mask_type = "", + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + spatial_position_encoding = spatial_position_encoding, + causal_temporal_attention=causal_temporal_attention, + causal_temporal_attention_mask_type=causal_temporal_attention_mask_type, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) - def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): - hidden_states = input_tensor - hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) + def forward(self, input_tensor, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = input_tensor + hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) - output = hidden_states - return output + output = hidden_states + return output -class TemporalTransformer3DModel(nn.Module): - def __init__( - self, - in_channels, - num_attention_heads, - attention_head_dim, +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + spatial_position_encoding = False, + + causal_temporal_attention = None, + causal_temporal_attention_mask_type = "", + ): + super().__init__() + assert causal_temporal_attention is not None + self.causal_temporal_attention = causal_temporal_attention - num_layers, - attention_block_types = ( "Temporal_Self", "Temporal_Self", ), - dropout = 0.0, - norm_num_groups = 32, - cross_attention_dim = 768, - activation_fn = "geglu", - attention_bias = False, - upcast_attention = False, - - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - ): - super().__init__() + assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "") + self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type + self.causal_temporal_attention_mask = None + self.spatial_position_encoding = spatial_position_encoding + inner_dim = num_attention_heads * attention_head_dim - inner_dim = num_attention_heads * attention_head_dim + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + if spatial_position_encoding: + self.pos_encoder_2d = PositionalEncoding2D(inner_dim) + - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def get_causal_temporal_attention_mask(self, hidden_states): + batch_size, sequence_length, dim = hidden_states.shape + + if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (batch_size, sequence_length, sequence_length): + zero_rank_print(f"build attn mask of type {self.causal_temporal_attention_mask_type}") + if self.causal_temporal_attention_mask_type == "causal": + # 1. vanilla causal mask + mask = torch.tril(torch.ones(sequence_length, sequence_length)) - self.transformer_blocks = nn.ModuleList( - [ - TemporalTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - attention_block_types=attention_block_types, - dropout=dropout, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - upcast_attention=upcast_attention, - cross_frame_attention_mode=cross_frame_attention_mode, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - for d in range(num_layers) - ] - ) - self.proj_out = nn.Linear(inner_dim, in_channels) - - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): - assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." - video_length = hidden_states.shape[2] - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + elif self.causal_temporal_attention_mask_type == "2-seq": + # 2. 2-seq + mask = torch.zeros(sequence_length, sequence_length) + mask[:sequence_length // 2, :sequence_length // 2] = 1 + mask[-sequence_length // 2:, -sequence_length // 2:] = 1 + + elif self.causal_temporal_attention_mask_type == "0-prev": + # attn to the previous frame + indices = torch.arange(sequence_length) + indices_prev = indices - 1 + indices_prev[0] = 0 + mask = torch.zeros(sequence_length, sequence_length) + mask[:, 0] = 1. + mask[indices, indices_prev] = 1. - batch, channel, height, weight = hidden_states.shape - residual = hidden_states + elif self.causal_temporal_attention_mask_type == "0": + # only attn to first frame + mask = torch.zeros(sequence_length, sequence_length) + mask[:,0] = 1 - hidden_states = self.norm(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - hidden_states = self.proj_in(hidden_states) + elif self.causal_temporal_attention_mask_type == "wo-self": + indices = torch.arange(sequence_length) + mask = torch.ones(sequence_length, sequence_length) + mask[indices, indices] = 0 - # Transformer Blocks - for block in self.transformer_blocks: - hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) - - # output - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + elif self.causal_temporal_attention_mask_type == "circle": + indices = torch.arange(sequence_length) + indices_prev = indices - 1 + indices_prev[0] = 0 - output = hidden_states + residual - output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) - - return output + mask = torch.eye(sequence_length) + mask[indices, indices_prev] = 1 + mask[0,-1] = 1 + else: raise ValueError + + # for sanity check + if dim == 320: zero_rank_print(mask) + + # generate attention mask fron binary values + mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + mask = mask.unsqueeze(0) + mask = mask.repeat(batch_size, 1, 1) + + self.causal_temporal_attention_mask = mask.to(hidden_states.device) + + return self.causal_temporal_attention_mask + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + height, width = hidden_states.shape[-2:] + + hidden_states = self.norm(hidden_states) + + hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c") + hidden_states = self.proj_in(hidden_states) + if self.spatial_position_encoding: + + video_length = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b h w) f c -> (b f) h w c", h=height, w=width) + pos_encoding = self.pos_encoder_2d(hidden_states) + pos_encoding = rearrange(pos_encoding, "(b f) h w c -> (b h w) f c", f = video_length) + hidden_states = rearrange(hidden_states, "(b f) h w c -> (b h w) f c", f=video_length) + + attention_mask = self.get_causal_temporal_attention_mask(hidden_states) if self.causal_temporal_attention else attention_mask + + # Transformer Blocks + for block in self.transformer_blocks: + if not self.spatial_position_encoding : + pos_encoding = None + + hidden_states = block(hidden_states, pos_encoding=pos_encoding, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask) + + hidden_states = self.proj_out(hidden_states) + + hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width) + + output = hidden_states + residual + # output = hidden_states + + return output class TemporalTransformerBlock(nn.Module): - def __init__( - self, - dim, - num_attention_heads, - attention_head_dim, - attention_block_types = ( "Temporal_Self", "Temporal_Self", ), - dropout = 0.0, - norm_num_groups = 32, - cross_attention_dim = 768, - activation_fn = "geglu", - attention_bias = False, - upcast_attention = False, - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - ): - super().__init__() + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + ): + super().__init__() - attention_blocks = [] - norms = [] - - for block_name in attention_block_types: - attention_blocks.append( - VersatileAttention( - attention_mode=block_name.split("_")[0], - cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, - - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - - cross_frame_attention_mode=cross_frame_attention_mode, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - ) - norms.append(nn.LayerNorm(dim)) - - self.attention_blocks = nn.ModuleList(attention_blocks) - self.norms = nn.ModuleList(norms) - - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) - self.ff_norm = nn.LayerNorm(dim) + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + TemporalSelfAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): - for attention_block, norm in zip(self.attention_blocks, self.norms): - norm_hidden_states = norm(hidden_states) - hidden_states = attention_block( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, - video_length=video_length, - ) + hidden_states - - hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states - - output = hidden_states - return output + def forward(self, hidden_states, pos_encoding=None, encoder_hidden_states=None, attention_mask=None): + for attention_block, norm in zip(self.attention_blocks, self.norms): + if pos_encoding is not None: + hidden_states += pos_encoding + norm_hidden_states = norm(hidden_states) + hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + hidden_states + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +def get_emb(sin_inp): + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + +class PositionalEncoding2D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(PositionalEncoding2D, self).__init__() + self.org_channels = channels + channels = int(np.ceil(channels / 4) * 2) + self.channels = channels + inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer("inv_freq", inv_freq) + self.register_buffer("cached_penc", None) + + def forward(self, tensor): + """ + :param tensor: A 4d tensor of size (batch_size, x, y, ch) + :return: Positional Encoding Matrix of size (batch_size, x, y, ch) + """ + if len(tensor.shape) != 4: + raise RuntimeError("The input tensor has to be 4d!") + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + return self.cached_penc + + self.cached_penc = None + batch_size, x, y, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) + pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + emb_x = get_emb(sin_inp_x).unsqueeze(1) + emb_y = get_emb(sin_inp_y) + emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( + tensor.type() + ) + emb[:, :, : self.channels] = emb_x + emb[:, :, self.channels : 2 * self.channels] = emb_y + + self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1) + return self.cached_penc class PositionalEncoding(nn.Module): - def __init__( - self, - d_model, - dropout = 0., - max_len = 24 - ): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - pe = torch.zeros(1, max_len, d_model) - pe[0, :, 0::2] = torch.sin(position * div_term) - pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + def __init__( + self, + d_model, + dropout = 0., + max_len = 32, + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) - def forward(self, x): - x = x + self.pe[:, :x.size(1)] - return self.dropout(x) + def forward(self, x): + # if x.size(1) < 16: + # start_idx = random.randint(0, 12) + # else: + # start_idx = 0 + + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) -class VersatileAttention(CrossAttention): - def __init__( - self, - attention_mode = None, - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - *args, **kwargs - ): - super().__init__(*args, **kwargs) - assert attention_mode == "Temporal" +class TemporalSelfAttention(Attention): + def __init__( + self, + attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" - self.attention_mode = attention_mode - self.is_cross_attention = kwargs["cross_attention_dim"] is not None - - self.pos_encoder = PositionalEncoding( - kwargs["query_dim"], - dropout=0., - max_len=temporal_position_encoding_max_len - ) if (temporal_position_encoding and attention_mode == "Temporal") else None + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + max_len=temporal_position_encoding_max_len + ) if temporal_position_encoding else None - def extra_repr(self): - return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ): + # disable motion module efficient xformers to avoid bad results, don't know why + # TODO: fix this bug + pass - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): - batch_size, sequence_length, _ = hidden_states.shape + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty - if self.attention_mode == "Temporal": - d = hidden_states.shape[1] - hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) - - if self.pos_encoder is not None: - hidden_states = self.pos_encoder(hidden_states) - - encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states - else: - raise NotImplementedError + # add position encoding + hidden_states = self.pos_encoder(hidden_states) - encoder_hidden_states = encoder_hidden_states + if hasattr(self.processor, "__call__"): + return self.processor.__call__( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) - if self.group_norm is not None: - hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - dim = query.shape[-1] - query = self.reshape_heads_to_batch_dim(query) - - if self.added_kv_proj_dim is not None: - raise NotImplementedError - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if attention_mask is not None: - if attention_mask.shape[-1] != query.shape[1]: - target_length = query.shape[1] - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) - - # attention, what we cannot get enough of - if self._use_memory_efficient_attention_xformers: - hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) - # Some versions of xformers return output in fp32, cast it back to the dtype of the input - hidden_states = hidden_states.to(query.dtype) - else: - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value, attention_mask) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - - # dropout - hidden_states = self.to_out[1](hidden_states) - - if self.attention_mode == "Temporal": - hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) - - return hidden_states + else: + return self.processor( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) diff --git a/animatediff/models/resnet.py b/animatediff/models/resnet.py deleted file mode 100644 index da80f17..0000000 --- a/animatediff/models/resnet.py +++ /dev/null @@ -1,217 +0,0 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange - - -class InflatedConv3d(nn.Conv2d): - def forward(self, x): - video_length = x.shape[2] - - x = rearrange(x, "b c f h w -> (b f) c h w") - x = super().forward(x) - x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) - - return x - - -class InflatedGroupNorm(nn.GroupNorm): - def forward(self, x): - video_length = x.shape[2] - - x = rearrange(x, "b c f h w -> (b f) c h w") - x = super().forward(x) - x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) - - return x - - -class Upsample3D(nn.Module): - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - raise NotImplementedError - elif use_conv: - self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - raise NotImplementedError - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - # if self.use_conv: - # if self.name == "conv": - # hidden_states = self.conv(hidden_states) - # else: - # hidden_states = self.Conv2d_0(hidden_states) - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class Downsample3D(nn.Module): - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - raise NotImplementedError - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - if self.use_conv and self.padding == 0: - raise NotImplementedError - - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - output_scale_factor=1.0, - use_in_shortcut=None, - use_inflated_groupnorm=None, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - assert use_inflated_groupnorm != None - if use_inflated_groupnorm: - self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - else: - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if temb_channels is not None: - if self.time_embedding_norm == "default": - time_emb_proj_out_channels = out_channels - elif self.time_embedding_norm == "scale_shift": - time_emb_proj_out_channels = out_channels * 2 - else: - raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") - - self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) - else: - self.time_emb_proj = None - - if use_inflated_groupnorm: - self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - else: - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class Mish(torch.nn.Module): - def forward(self, hidden_states): - return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/animatediff/models/unet.py b/animatediff/models/unet.py index 18aa955..0ee1aea 100644 --- a/animatediff/models/unet.py +++ b/animatediff/models/unet.py @@ -1,41 +1,157 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py - +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - +from typing import Any, Dict, List, Optional, Tuple, Union import os import json -import pdb + import torch import torch.nn as nn import torch.utils.checkpoint +from einops import rearrange, repeat from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin +from diffusers.loaders import UNet2DConditionLoadersMixin, AttnProcsLayers from diffusers.utils import BaseOutput, logging -from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, LoRAAttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + PositionNet, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin from .unet_blocks import ( - CrossAttnDownBlock3D, - CrossAttnUpBlock3D, - DownBlock3D, UNetMidBlock3DCrossAttn, - UpBlock3D, get_down_block, get_up_block, ) -from .resnet import InflatedConv3d, InflatedGroupNorm +from animatediff.utils.util import zero_rank_print logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet3DConditionOutput(BaseOutput): - sample: torch.FloatTensor + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None -class UNet3DConditionModel(ModelMixin, ConfigMixin): +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be either `UNetMidBlock2DCrossAttn` or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + _supports_gradient_checkpointing = True @register_to_config @@ -46,114 +162,308 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, - freq_shift: int = 0, + freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - mid_block_type: str = "UNetMidBlock3DCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" - ), + mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, + layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", - norm_num_groups: int = 32, + norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", - - use_inflated_groupnorm=False, - - # Additional - use_motion_module = False, - motion_module_resolutions = ( 1,2,4,8 ), - motion_module_mid_block = False, - motion_module_decoder_only = False, - motion_module_type = None, - motion_module_kwargs = {}, - unet_use_cross_frame_attention = None, - unet_use_temporal_attention = None, + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + + # motion module + use_motion_module=False, + motion_module_resolutions = (1,2,4,8), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type=None, + motion_module_kwargs=None, ): super().__init__() - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) # input - self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + self.down_blocks = nn.ModuleList([]) - self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + only_cross_attention = [only_cross_attention] * len(down_block_types) + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): - res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - + res = 2 ** i down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -163,46 +473,48 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): # mid if mid_block_type == "UNetMidBlock3DCrossAttn": self.mid_block = UNetMidBlock3DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) + elif mid_block_type is None: + self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the videos + + # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): - res = 2 ** (3 - i) is_final_block = i == len(block_out_channels) - 1 - + res = 2 ** (2 - i) prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - + # add upsample block for all BUT final layer if not is_final_block: add_upsample = True @@ -212,27 +524,28 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -241,41 +554,215 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): prev_output_channel = output_channel # out - if use_inflated_groupnorm: - self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type == "gated": + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim) + + def set_image_layer_lora(self, image_layer_lora_rank: int = 128): + lora_attn_procs = {} + for name in self.attn_processors.keys(): + zero_rank_print(f"(add lora) {name}") + cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=image_layer_lora_rank if image_layer_lora_rank > 16 else hidden_size // image_layer_lora_rank, + ) + self.set_attn_processor(lora_attn_procs) + + lora_layers = AttnProcsLayers(self.attn_processors) + zero_rank_print(f"(lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M") + del lora_layers + + def set_image_layer_lora_scale(self, lora_scale: float = 1.0): + for block in self.down_blocks: setattr(block, "lora_scale", lora_scale) + for block in self.up_blocks: setattr(block, "lora_scale", lora_scale) + setattr(self.mid_block, "lora_scale", lora_scale) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + if not "motion_modules." in name: + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], is_motion_module=False): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) if not is_motion_module else len(self.motion_module_attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if ((not is_motion_module) and (not "motion_modules." in name)) or (is_motion_module and ("motion_modules." in name)): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + @property + def motion_module_attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + # filter out processors in motion module + if hasattr(module, "set_processor"): + if "motion_modules." in name: + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_motion_module_lora(self, motion_module_lora_rank: int = 256, motion_lora_resolution=[32, 64, 128]): + lora_attn_procs = {} + #motion_name = [] + #if 32 in motion_lora_resolution: + # motion_name.append('up_blocks.0') + # motion_name.append('down_blocks.2') + # if 64 in motion_lora_resolution: + # motion_name.append('up_blocks.1') + # motion_name.append('down_blocks.1') + # if 128 in motion_lora_resolution: + # motion_name.append('up_blocks.2') + # motion_name.append('down_blocks.0') + for name in self.motion_module_attn_processors.keys(): + #prefix = '.'.join(name.split('.')[:2]) + #if prefix not in motion_name: + # continue + print(f"(add motion lora) {name}") + + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=None, + rank=motion_module_lora_rank, + ) + self.set_attn_processor(lora_attn_procs, is_motion_module=True) + + lora_layers = AttnProcsLayers(self.motion_module_attn_processors) + print(f"(motion lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M") + del lora_layers + + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] - def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) + fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) + fn_recursive_retrieve_sliceable_dims(module) - num_slicable_layers = len(sliceable_head_dims) + num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between @@ -283,9 +770,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible - slice_size = num_slicable_layers * [1] + slice_size = num_sliceable_layers * [1] - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( @@ -314,7 +801,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( @@ -323,28 +810,57 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" + The [`UNet2DConditionModel`] forward method. + Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers + # convert the time, size, and text embedding into (b f) c h w + video_length = sample.shape[2] + timestep = repeat(timestep, "b-> (b f)", f=video_length) + encoder_hidden_states = repeat(encoder_hidden_states, "b c d-> (b f) c d", f=video_length) + added_cond_kwargs['time_ids'] = repeat(added_cond_kwargs['time_ids'], "b c -> (b f) c", f=video_length) + added_cond_kwargs['text_embeds'] = repeat(added_cond_kwargs['text_embeds'], "b c -> (b f) c", f=video_length) + + #sample = rearrange(sample, "b c f h w -> (b f) c h w") + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None @@ -353,18 +869,35 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - # center input if necessary + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 - # time + # 1. time timesteps = timestep if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): @@ -376,15 +909,17 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) + # timesteps = timesteps t_emb = self.time_proj(timesteps) - # timesteps does not contain any weights and will always return f32 tensors + # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None if self.class_embedding is not None: if class_labels is None: @@ -393,33 +928,167 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) - # pre-process + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + video_length = sample.shape[2] + sample = rearrange(sample, "b c f h w -> (b f) c h w") sample = self.conv_in(sample) + sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None - # down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples - # mid - sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask - ) + if is_controlnet: + new_down_block_res_samples = () - # up + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_block_additional_residuals) > 0 + and sample.shape == down_block_additional_residuals[0].shape + ): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 @@ -437,24 +1106,32 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) + + video_length = sample.shape[2] + sample = rearrange(sample, "b c f h w -> (b f) c h w") + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) sample = self.conv_out(sample) + sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) - + @classmethod def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): if subfolder is not None: @@ -468,24 +1145,29 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): config = json.load(f) config["_class_name"] = cls.__name__ config["down_block_types"] = [ + "DownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", - "DownBlock3D" + ] config["up_block_types"] = [ + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", "UpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" ] - from diffusers.utils import WEIGHTS_NAME + config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + from diffusers.utils import SAFETENSORS_WEIGHTS_NAME model = cls.from_config(config, **unet_additional_kwargs) - model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") - state_dict = torch.load(model_file, map_location="cpu") + + state_dict = {} + from safetensors import safe_open + with safe_open(model_file, framework='pt') as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") @@ -494,4 +1176,4 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") - return model + return model \ No newline at end of file diff --git a/animatediff/models/unet_blocks.py b/animatediff/models/unet_blocks.py index 711ad6c..c7e397a 100644 --- a/animatediff/models/unet_blocks.py +++ b/animatediff/models/unet_blocks.py @@ -1,13 +1,20 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +from typing import Any, Dict, Optional, Tuple +import numpy as np import torch +import torch.nn.functional as F +from einops import rearrange from torch import nn -from .attention import Transformer3DModel -from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from diffusers.utils import is_torch_version, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from diffusers.models.transformer_2d import Transformer2DModel from .motion_module import get_motion_module -import pdb +logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_down_block( down_block_type, @@ -18,7 +25,8 @@ def get_down_block( add_downsample, resnet_eps, resnet_act_fn, - attn_num_head_channels, + transformer_layers_per_block=1, + num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, @@ -27,16 +35,23 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - + attention_type="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + downsample_type=None, use_motion_module=None, - motion_module_type=None, motion_module_kwargs=None, ): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock3D": return DownBlock3D( @@ -50,18 +65,16 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, - - use_inflated_groupnorm=use_inflated_groupnorm, - use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif down_block_type == "CrossAttnDownBlock3D": if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") return CrossAttnDownBlock3D( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -71,17 +84,13 @@ def get_down_block( resnet_groups=resnet_groups, downsample_padding=downsample_padding, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, + num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -99,7 +108,8 @@ def get_up_block( add_upsample, resnet_eps, resnet_act_fn, - attn_num_head_channels, + transformer_layers_per_block=1, + num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, dual_cross_attention=False, @@ -107,15 +117,23 @@ def get_up_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - + attention_type="default", + resnet_skip_time_act=False, + resnet_out_scale_factor=1.0, + cross_attention_norm=None, + attention_head_dim=None, + upsample_type=None, use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, ): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." + ) + attention_head_dim = num_attention_heads + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type if up_block_type == "UpBlock3D": return UpBlock3D( @@ -129,18 +147,16 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, resnet_time_scale_shift=resnet_time_scale_shift, - - use_inflated_groupnorm=use_inflated_groupnorm, - use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) elif up_block_type == "CrossAttnUpBlock3D": if cross_attention_dim is None: - raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") return CrossAttnUpBlock3D( num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, in_channels=in_channels, out_channels=out_channels, prev_output_channel=prev_output_channel, @@ -150,23 +166,19 @@ def get_up_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attn_num_head_channels, + num_attention_heads=num_attention_heads, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention, upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, use_motion_module=use_motion_module, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) - raise ValueError(f"{up_block_type} does not exist.") + raise ValueError(f"{up_block_type} does not exist.") class UNetMidBlock3DCrossAttn(nn.Module): def __init__( @@ -175,36 +187,32 @@ class UNetMidBlock3DCrossAttn(nn.Module): temb_channels: int, dropout: float = 0.0, num_layers: int = 1, + transformer_layers_per_block: int = 1, resnet_eps: float = 1e-6, resnet_time_scale_shift: str = "default", resnet_act_fn: str = "swish", resnet_groups: int = 32, resnet_pre_norm: bool = True, - attn_num_head_channels=1, + num_attention_heads=1, output_scale_factor=1.0, cross_attention_dim=1280, dual_cross_attention=False, use_linear_projection=False, upcast_attention=False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - + attention_type="default", use_motion_module=None, - motion_module_type=None, motion_module_kwargs=None, ): super().__init__() self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels + self.num_attention_heads = num_attention_heads resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) # there is always at least one resnet resnets = [ - ResnetBlock3D( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -215,40 +223,46 @@ class UNetMidBlock3DCrossAttn(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ] attentions = [] motion_modules = [] for _ in range(num_layers): - if dual_cross_attention: - raise NotImplementedError - attentions.append( - Transformer3DModel( - attn_num_head_channels, - in_channels // attn_num_head_channels, - in_channels=in_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) motion_modules.append( get_motion_module( in_channels=in_channels, - motion_module_type=motion_module_type, + motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) resnets.append( - ResnetBlock3D( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -259,8 +273,6 @@ class UNetMidBlock3DCrossAttn(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ) @@ -268,121 +280,22 @@ class UNetMidBlock3DCrossAttn(nn.Module): self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class CrossAttnDownBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - - use_motion_module=None, - - motion_module_type=None, - motion_module_kwargs=None, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, - ) - ) - if dual_cross_attention: - raise NotImplementedError - attentions.append( - Transformer3DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - ) - ) - motion_modules.append( - get_motion_module( - in_channels=out_channels, - motion_module_type=motion_module_type, - motion_module_kwargs=motion_module_kwargs, - ) if use_motion_module else None - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): - output_states = () - - for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + for attn, motion_module, resnet in zip(self.attentions, self.motion_modules, self.resnets[1:]): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -393,33 +306,46 @@ class CrossAttnDownBlock3D(nn.Module): return module(*inputs) return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), + + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( hidden_states, - encoder_hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, )[0] - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) - + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states, temb, + encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - # add motion module + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = resnet(hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - + return hidden_states class DownBlock3D(nn.Module): def __init__( @@ -437,9 +363,6 @@ class DownBlock3D(nn.Module): output_scale_factor=1.0, add_downsample=True, downsample_padding=1, - - use_inflated_groupnorm=None, - use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, @@ -451,7 +374,7 @@ class DownBlock3D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock3D( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -462,8 +385,6 @@ class DownBlock3D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( @@ -473,14 +394,14 @@ class DownBlock3D(nn.Module): motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) - + self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ - Downsample3D( + Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] @@ -495,178 +416,45 @@ class DownBlock3D(nn.Module): for resnet, motion_module in zip(self.resnets, self.motion_modules): if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states, temb, encoder_hidden_states, use_reentrant=False) if motion_module is not None else hidden_states else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = resnet(hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states) if motion_module is not None else hidden_states - # add motion module - hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states - - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = downsampler(hidden_states) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states -class CrossAttnUpBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - - use_motion_module=None, - - motion_module_type=None, - motion_module_kwargs=None, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock3D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, - ) - ) - if dual_cross_attention: - raise NotImplementedError - attentions.append( - Transformer3DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - ) - ) - motion_modules.append( - get_motion_module( - in_channels=out_channels, - motion_module_type=motion_module_type, - motion_module_kwargs=motion_module_kwargs, - ) if use_motion_module else None - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - ): - for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - )[0] - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - # add motion module - hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - class UpBlock3D(nn.Module): def __init__( self, @@ -683,9 +471,6 @@ class UpBlock3D(nn.Module): resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, - - use_inflated_groupnorm=None, - use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, @@ -699,7 +484,7 @@ class UpBlock3D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock3D( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -710,8 +495,6 @@ class UpBlock3D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( @@ -726,35 +509,405 @@ class UpBlock3D(nn.Module): self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): - for resnet, motion_module in zip(self.resnets, self.motion_modules): + def forward(self, hidden_states, res_hidden_states_tuple, encoder_hidden_states=None, temb=None, upsample_size=None): + for (resnet, motion_module) in zip(self.resnets, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states, + temb, encoder_hidden_states, use_reentrant=False) if motion_module is not None else hidden_states else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = resnet(hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, temb, + encoder_hidden_states, use_reentrant=False) if motion_module is not None else hidden_states + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = downsampler(hidden_states) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, temb, encoder_hidden_states, + use_reentrant=False) if motion_module is not None else hidden_states + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) - return hidden_states + return hidden_states \ No newline at end of file diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py index 58f22d1..d07794c 100644 --- a/animatediff/pipelines/pipeline_animation.py +++ b/animatediff/pipelines/pipeline_animation.py @@ -1,200 +1,361 @@ -# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import inspect -from typing import Callable, List, Optional, Union +import os +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +import numpy as np +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer from dataclasses import dataclass -import numpy as np -import torch -from tqdm import tqdm - -from diffusers.utils import is_accelerate_available -from packaging import version -from transformers import CLIPTextModel, CLIPTokenizer - -from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL -from diffusers.pipeline_utils import DiffusionPipeline -from diffusers.schedulers import ( - DDIMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, +from animatediff.models.unet import UNet3DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, + BaseOutput, ) -from diffusers.utils import deprecate, logging, BaseOutput - from einops import rearrange +from diffusers.pipelines.pipeline_utils import DiffusionPipeline + + +@dataclass +class AnimatePipelineOutput(BaseOutput): + """ + Output class for Stable Diffusion pipelines. + + Args: + images (`List[PIL.Image.Image]` or `np.ndarray`) + List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width, + num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline. + """ + + videos: Union[torch.Tensor, np.ndarray] -from ..models.unet import UNet3DConditionModel logger = logging.get_logger(__name__) # pylint: disable=invalid-name +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionXLPipeline -@dataclass -class AnimationPipelineOutput(BaseOutput): - videos: Union[torch.Tensor, np.ndarray] + >>> pipe = StableDiffusionXLPipeline.from_pretrained( + ... "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16 + ... ) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" -class AnimationPipeline(DiffusionPipeline): - _optional_components = [] +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +class AnimationPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`] + - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`] + + as well as the following saving methods: + - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion XL uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ def __init__( self, vae: AutoencoderKL, text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, unet: UNet3DConditionModel, - scheduler: Union[ - DDIMScheduler, - PNDMScheduler, - LMSDiscreteScheduler, - EulerDiscreteScheduler, - EulerAncestralDiscreteScheduler, - DPMSolverMultistepScheduler, - ], + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + add_watermarker: Optional[bool] = None, ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - self.register_modules( vae=vae, text_encoder=text_encoder, + text_encoder_2=text_encoder_2, tokenizer=tokenizer, + tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ self.vae.enable_slicing() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ self.vae.disable_slicing() - def enable_sequential_cpu_offload(self, gpu_id=0): - if is_accelerate_available(): - from accelerate import cpu_offload + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook else: - raise ImportError("Please install accelerate via `pip install accelerate`") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - @property - def _execution_device(self): - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + model_sequence.extend([self.unet, self.vae]) - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) + # We'll offload the last model manually. + self.final_offload_hook = hook + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) else: - attention_mask = None + batch_size = prompt_embeds.shape[0] - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) - text_embeddings = text_embeddings[0] - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." ) elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] + uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" @@ -202,55 +363,58 @@ class AnimationPipeline(DiffusionPipeline): " the batch size of `prompt`." ) else: - uncond_tokens = negative_prompt + uncond_tokens = [negative_prompt, negative_prompt_2] - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) - uncond_embeddings = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - return text_embeddings - - def decode_latents(self, latents): - video_length = latents.shape[2] - latents = 1 / 0.18215 * latents - latents = rearrange(latents, "b c f h w -> (b f) c h w") - # video = self.vae.decode(latents).sample - video = [] - for frame_idx in tqdm(range(latents.shape[0])): - video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) - video = torch.cat(video) - video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) - video = (video / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - video = video.cpu().float().numpy() - return video + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -268,10 +432,20 @@ class AnimationPipeline(DiffusionPipeline): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -283,131 +457,393 @@ class AnimationPipeline(DiffusionPipeline): f" {type(callback_steps)}." ) - def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, single_model_length, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, single_model_length, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - if latents is None: - rand_device = "cpu" if device.type == "mps" else device - if isinstance(generator, list): - shape = shape - # shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], - video_length: Optional[int], + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + single_model_length: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, - guidance_scale: float = 7.5, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "tensor", + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: Optional[int] = 1, - **kwargs, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, ): - # Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + r""" + Function invoked when calling the pipeline for generation. - # Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.7): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - # Define call parameters - # batch_size = 1 if isinstance(prompt, str) else len(prompt) - batch_size = 1 - if latents is not None: - batch_size = latents.shape[0] - if isinstance(prompt, list): + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is a list with the generated images. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # Encode input prompt - prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size - if negative_prompt is not None: - negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size - text_embeddings = self._encode_prompt( - prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) - # Prepare timesteps + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps - # Prepare latent variables - num_channels_latents = self.unet.in_channels + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, + single_model_length, num_channels_latents, - video_length, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, ) - latents_dtype = latents.dtype - # Prepare extra step kwargs. + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 8. Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 7.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) - # noise_pred = [] - # import pdb - # pdb.set_trace() - # for batch_idx in range(latent_model_input.shape[0]): - # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype) - # noise_pred.append(noise_pred_single) - # noise_pred = torch.cat(noise_pred) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + ts = torch.tensor([t], dtype=latent_model_input.dtype, device=latent_model_input.device) + if do_classifier_free_guidance: + ts = ts.repeat(2) + + noise_pred = self.unet( + latent_model_input, + ts, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -415,14 +851,98 @@ class AnimationPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # Post-processing - video = self.decode_latents(latents) + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float32 and latents.dtype == torch.float16: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - # Convert to tensor - if output_type == "tensor": - video = torch.from_numpy(video) + if not output_type == "latent": + latents = rearrange(latents, "b c f h w -> (b f) c h w") + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + #image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + image = ((image + 1) / 2).clamp(0, 1) + video = rearrange(image, "(b f) c h w -> b c f h w", f=single_model_length).cpu() if not return_dict: - return video + return (video,) - return AnimationPipelineOutput(videos=video) + return AnimatePipelineOutput(videos=video) + + + # Overrride to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) \ No newline at end of file diff --git a/animatediff/utils/convert_from_ckpt.py b/animatediff/utils/convert_from_ckpt.py deleted file mode 100644 index 9ee269d..0000000 --- a/animatediff/utils/convert_from_ckpt.py +++ /dev/null @@ -1,959 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Conversion script for the Stable Diffusion checkpoints.""" - -import re -from io import BytesIO -from typing import Optional - -import requests -import torch -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from diffusers.models import ( - AutoencoderKL, - PriorTransformer, - UNet2DConditionModel, -) -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils.import_utils import BACKENDS_MAPPING - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. - """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. - if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - if controlnet: - unet_params = original_config.model.params.control_stage_config.params - else: - unet_params = original_config.model.params.unet_config.params - - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - - head_dim = unet_params.num_heads if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim = [5, 10, 20, 20] - - class_embed_type = None - projection_class_embeddings_input_dim = None - - if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels - else: - raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, - "cross_attention_dim": unet_params.context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - } - - if not controlnet: - config["out_channels"] = unet_params.out_channels - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - Creates a config for the diffusers based on the config of the LDM model. - """ - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim - - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, - } - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config.model.parms.cond_stage_config.params - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - print(f"Checkpoint {path} has both EMA and non-EMA weights.") - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - print( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... - elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. - if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." - keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), - ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint): - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - if "cond_stage_model.model.text_projection" in checkpoint: - d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) - else: - d_model = 1024 - - text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer - continue - if key in textenc_conversion_map: - text_model_dict[textenc_conversion_map[key]] = checkpoint[key] - if key.startswith("cond_stage_model.model.transformer."): - new_key = key[len("cond_stage_model.model.transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. - """ - - image_embedder_config = original_config.model.params.embedder_config - - sd_clip_image_embedder_class = image_embedder_config.target - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. - """ - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - - controlnet_model = ControlNetModel(**ctrlnet_config) - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True - ) - - controlnet_model.load_state_dict(converted_ctrl_checkpoint) - - return controlnet_model diff --git a/animatediff/utils/convert_lora_safetensor_to_diffusers.py b/animatediff/utils/convert_lora_safetensor_to_diffusers.py index 7490e38..543f03b 100644 --- a/animatediff/utils/convert_lora_safetensor_to_diffusers.py +++ b/animatediff/utils/convert_lora_safetensor_to_diffusers.py @@ -5,7 +5,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -23,132 +23,112 @@ from safetensors.torch import load_file from diffusers import StableDiffusionPipeline import pdb - - -def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0): - # directly update weight in diffusers model - for key in state_dict: - # only process lora down key - if "up." in key: continue - - up_key = key.replace(".down.", ".up.") - model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "") - model_key = model_key.replace("to_out.", "to_out.0.") - layer_infos = model_key.split(".")[:-1] - - curr_layer = pipeline.unet - while len(layer_infos) > 0: - temp_name = layer_infos.pop(0) - curr_layer = curr_layer.__getattr__(temp_name) - - weight_down = state_dict[key] - weight_up = state_dict[up_key] - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) - - return pipeline - - - def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6): - # load base model - # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) + # load base model + # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32) - # load LoRA weight from .safetensors - # state_dict = load_file(checkpoint_path) + # load LoRA weight from .safetensors + # state_dict = load_file(checkpoint_path) - visited = [] + visited = [] + # directly update weight in diffusers model + for lora_name in state_dict: + # it is suggested to print out the key, it usually will be something like below + # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" - # directly update weight in diffusers model - for key in state_dict: - # it is suggested to print out the key, it usually will be something like below - # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight" + # as we have set the alpha beforehand, so just skip + if ".alpha" in lora_name or lora_name in visited: + continue - # as we have set the alpha beforehand, so just skip - if ".alpha" in key or key in visited: - continue + if "te" in lora_name: + if "lora_te1" in key: + LORA_PREFIX_TEXT_ENCODER = "lora_te1" + elif "lora_te2" in key: + LORA_PREFIX_TEXT_ENCODER = "lora_te2" + else: + LORA_PREFIX_TEXT_ENCODER = "lora_te" + layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") + - if "text" in key: - layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_") - curr_layer = pipeline.text_encoder - else: - layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") - curr_layer = pipeline.unet + else: + layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_") + curr_layer = pipeline.unet - # find the target layer - temp_name = layer_infos.pop(0) - while len(layer_infos) > -1: - try: - curr_layer = curr_layer.__getattr__(temp_name) - if len(layer_infos) > 0: - temp_name = layer_infos.pop(0) - elif len(layer_infos) == 0: - break - except Exception: - if len(temp_name) > 0: - temp_name += "_" + layer_infos.pop(0) - else: - temp_name = layer_infos.pop(0) + # find the target layer + temp_name = layer_infos.pop(0) + while len(layer_infos) > -1: + try: + curr_layer = curr_layer.__getattr__(temp_name) + if len(layer_infos) > 0: + temp_name = layer_infos.pop(0) + elif len(layer_infos) == 0: + break + except Exception: + if len(temp_name) > 0: + temp_name += "_" + layer_infos.pop(0) + else: + temp_name = layer_infos.pop(0) - pair_keys = [] - if "lora_down" in key: - pair_keys.append(key.replace("lora_down", "lora_up")) - pair_keys.append(key) - else: - pair_keys.append(key) - pair_keys.append(key.replace("lora_up", "lora_down")) + pair_keys = [] + if "lora.down" in key: + pair_keys.append(key.replace("lora.down", "lora.up")) + pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora.up", "lora.down")) - # update weight - if len(state_dict[pair_keys[0]].shape) == 4: - weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) - weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) - else: - weight_up = state_dict[pair_keys[0]].to(torch.float32) - weight_down = state_dict[pair_keys[1]].to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) - # update visited list - for item in pair_keys: - visited.append(item) + # update visited list + for item in pair_keys: + visited.append(item) - return pipeline + return pipeline if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() - parser.add_argument( - "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." - ) - parser.add_argument( - "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." - ) - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") - parser.add_argument( - "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" - ) - parser.add_argument( - "--lora_prefix_text_encoder", - default="lora_te", - type=str, - help="The prefix of text encoder weight in safetensors", - ) - parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") - parser.add_argument( - "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." - ) - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") - args = parser.parse_args() + args = parser.parse_args() - base_model_path = args.base_model_path - checkpoint_path = args.checkpoint_path - dump_path = args.dump_path - lora_prefix_unet = args.lora_prefix_unet - lora_prefix_text_encoder = args.lora_prefix_text_encoder - alpha = args.alpha + base_model_path = args.base_model_path + checkpoint_path = args.checkpoint_path + dump_path = args.dump_path + lora_prefix_unet = args.lora_prefix_unet + lora_prefix_text_encoder = args.lora_prefix_text_encoder + alpha = args.alpha - pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) + pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha) - pipe = pipe.to(args.device) - pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) + pipe = pipe.to(args.device) + pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors) diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py index 5393385..1fced56 100644 --- a/animatediff/utils/util.py +++ b/animatediff/utils/util.py @@ -3,19 +3,23 @@ import imageio import numpy as np from typing import Union +import random import torch -import torchvision import torch.distributed as dist +import torchvision -from safetensors import safe_open +import diffusers from tqdm import tqdm from einops import rearrange -from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint -from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers +from safetensors import safe_open +from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline + +from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora def zero_rank_print(s): - if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) + if not isinstance(s, str): s = repr(s) + if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s) def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8): @@ -91,67 +95,1714 @@ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt) return ddim_latents -def load_weights( - animation_pipeline, - # motion module - motion_module_path = "", - motion_module_lora_configs = [], - # image layers - dreambooth_model_path = "", - lora_model_path = "", - lora_alpha = 0.8, -): - # 1.1 motion module - unet_state_dict = {} - if motion_module_path != "": - print(f"load motion module from {motion_module_path}") - motion_module_state_dict = torch.load(motion_module_path, map_location="cpu") - motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict - unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules." in name}) - - missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) - assert len(unexpected) == 0 - del unet_state_dict - if dreambooth_model_path != "": - print(f"load dreambooth model from {dreambooth_model_path}") - if dreambooth_model_path.endswith(".safetensors"): - dreambooth_state_dict = {} - with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: - for key in f.keys(): - dreambooth_state_dict[key] = f.get_tensor(key) - elif dreambooth_model_path.endswith(".ckpt"): - dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") - - # 1. vae - converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) - animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) - # 2. unet - converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) - animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) - # 3. text_model - animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) - del dreambooth_state_dict +def load_weights(pipeline, motion_module_path, ckpt_path, lora_path, lora_alpha): + # Load ckpt + + if ckpt_path != "": + ( + text_model1, + text_model2, + vae, + unet, + logit_scale, + ckpt_info + ) = load_models_from_sdxl_checkpoint(MODEL_VERSION_SDXL_BASE_V1_0, ckpt_path, 'cpu') + - if lora_model_path != "": - print(f"load lora model from {lora_model_path}") - assert lora_model_path.endswith(".safetensors") + unet_state_dict = unet.state_dict() + pipeline.unet.load_state_dict(unet_state_dict, strict=False) + pipeline.vae = vae + pipeline.text_encoder = text_model1 + pipeline.text_encoder_2 = text_model2 + del unet + del unet_state_dict + del vae + del text_model1 + del text_model2 + print(f'Loading ckpt model from {ckpt_path}') + + # Load Motion Module + if motion_module_path != "": + motion_module_ckpt = torch.load(motion_module_path, map_location='cpu') + + motion_module_state_dict = {} + m_k = None + for k, v in motion_module_ckpt.items(): + if 'motion_module' in k and k in pipeline.unet.state_dict().keys(): + motion_module_state_dict[k] = v + m_k = k + elif 'motion_module' in k and k not in pipeline.unet.state_dict().keys(): + print(k) + + pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) + del motion_module_ckpt + del motion_module_state_dict + print(f'Loading motion module from {motion_module_path}...') + + # Load LoRA + if lora_path != "": lora_state_dict = {} - with safe_open(lora_model_path, framework="pt", device="cpu") as f: - for key in f.keys(): - lora_state_dict[key] = f.get_tensor(key) - - animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) - del lora_state_dict + with safe_open(lora_path, framework='pt', device='cpu') as f: + for k in f.keys(): + lora_state_dict[k] = f.get_tensor(k) + for k, v in lora_state_dict.items(): + if 'lora.up' in k: + + down_key = k.replace('lora.up', 'lora.down') + if 'to_out' not in k: + original_key = k.replace('processor.', '').replace('_lora.up', '') + else: + original_key = k.replace('processor.', '').replace('_lora.up', '.0') + pipeline.unet.state_dict()[original_key] += lora_alpha * torch.mm(v, lora_state_dict[down_key]) + print(f'Loading lora model from {lora_path}') + + return pipeline - for motion_module_lora_config in motion_module_lora_configs: - path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] - print(f"load motion LoRA from {path}") - motion_lora_state_dict = torch.load(path, map_location="cpu") - motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts=0, caption_column='text', is_train=True): + prompt_embeds_list = [] + prompt_batch = batch[caption_column] - animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha) + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} + + +MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" +from safetensors.torch import load_file +from accelerate.utils.modeling import set_module_tensor_to_device +from animatediff.utils.xl_lora_util import SdxlUNet2DConditionModel +from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from packaging import version +from accelerate import init_empty_weights +import transformers +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, unet_only=False): + # model_version is reserved for future use + # dtype is reserved for full_fp16/bf16 integration. Text Encoder will remain fp32, because it runs on CPU when caching + + # Load the state dict + if is_safetensors(ckpt_path): + checkpoint = None + try: + state_dict = load_file(ckpt_path, device=map_location) + except: + state_dict = load_file(ckpt_path) # prevent device invalid Error + epoch = None + global_step = None + else: + checkpoint = torch.load(ckpt_path, map_location=map_location) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + epoch = checkpoint.get("epoch", 0) + global_step = checkpoint.get("global_step", 0) + else: + state_dict = checkpoint + epoch = 0 + global_step = 0 + checkpoint = None + + # U-Net + print("building U-Net") + with init_empty_weights(): + unet = SdxlUNet2DConditionModel() + + print("loading U-Net from checkpoint") + unet_sd = {} + for k in list(state_dict.keys()): + if k.startswith("model.diffusion_model."): + unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k) + info = _load_state_dict_on_device(unet, unet_sd, device=map_location) + print("U-Net: ", info) + unet_sd= unet.state_dict() + du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd) + from animatediff.models.unet import UNet3DConditionModel + diffusers_unet = UNet3DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG) + diffusers_unet.load_state_dict(du_unet_sd) + + if unet_only: + return None, None, None, diffusers_unet, None, None + + # Text Encoders + print("building text encoders") + + # Text Encoder 1 is same to Stability AI's SDXL + text_model1_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=768, + intermediate_size=3072, + num_hidden_layers=12, + num_attention_heads=12, + max_position_embeddings=77, + hidden_act="quick_gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=768, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + text_model1 = CLIPTextModel._from_config(text_model1_cfg) + + # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace. + # Note: Tokenizer from HuggingFace is different from SDXL. We must use open clip's tokenizer. + text_model2_cfg = CLIPTextConfig( + vocab_size=49408, + hidden_size=1280, + intermediate_size=5120, + num_hidden_layers=32, + num_attention_heads=20, + max_position_embeddings=77, + hidden_act="gelu", + layer_norm_eps=1e-05, + dropout=0.0, + attention_dropout=0.0, + initializer_range=0.02, + initializer_factor=1.0, + pad_token_id=1, + bos_token_id=0, + eos_token_id=2, + model_type="clip_text_model", + projection_dim=1280, + # torch_dtype="float32", + # transformers_version="4.25.0.dev0", + ) + text_model2 = CLIPTextModelWithProjection(text_model2_cfg) + + print("loading text encoders from checkpoint") + te1_sd = {} + te2_sd = {} + for k in list(state_dict.keys()): + if k.startswith("conditioner.embedders.0.transformer."): + te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k) + elif k.startswith("conditioner.embedders.1.model."): + te2_sd[k] = state_dict.pop(k) + + if version.parse(transformers.__version__) > version.parse("4.31"): + # After transformers==4.31, position_ids becomes a persistent=False buffer (so we musn't supply it) + # https://github.com/huggingface/transformers/pull/24505 + # https://github.com/mlfoundations/open_clip/pull/595 + if 'text_model.embeddings.position_ids' in te1_sd: + del te1_sd['text_model.embeddings.position_ids'] + info1 = text_model1.load_state_dict(te1_sd) + print("text encoder 1:", info1) + + converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77) + info2 = text_model2.load_state_dict(converted_sd) + print("text encoder 2:", info2) + + # prepare vae + print("building VAE") + vae_config = create_vae_diffusers_config() + vae = AutoencoderKL(**vae_config) # .to(device) + + print("loading VAE from checkpoint") + converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config) + info = vae.load_state_dict(converted_vae_checkpoint) + print("VAE:", info) + + ckpt_info = (epoch, global_step) if epoch is not None else None + return text_model1, text_model2, vae, diffusers_unet, logit_scale, ckpt_info + + + # load state_dict without allocating new tensors + + +def shave_segments(path, n_shave_prefix_segments=1): + """ + Removes segments. Positive values shave the first segments, negative shave the last segments. + """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + if diffusers.__version__ < "0.17.0": + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + else: + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + reshaping = False + if diffusers.__version__ < "0.17.0": + if "proj_attn.weight" in new_path: + reshaping = True + else: + if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + reshaping = True + + if reshaping: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." in key] for layer_id in range(num_output_blocks) + } + + for i in range(1, num_input_blocks): + block_id = (i - 1) // (config["layers_per_block"] + 1) + layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) + + resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key] + attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] + + if f"input_blocks.{i}.0.op.weight" in unet_state_dict: + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( + f"input_blocks.{i}.0.op.weight" + ) + new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias") + + paths = renew_resnet_paths(resnets) + meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + resnet_0 = middle_blocks[0] + attentions = middle_blocks[1] + resnet_1 = middle_blocks[2] + + resnet_0_paths = renew_resnet_paths(resnet_0) + assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) + + resnet_1_paths = renew_resnet_paths(resnet_1) + assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config) + + attentions_paths = renew_attention_paths(attentions) + meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + for i in range(num_output_blocks): + block_id = i // (config["layers_per_block"] + 1) + layer_in_block_id = i % (config["layers_per_block"] + 1) + output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] + output_block_list = {} + + for layer in output_block_layers: + layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) + if layer_id in output_block_list: + output_block_list[layer_id].append(layer_name) + else: + output_block_list[layer_id] = [layer_name] + + if len(output_block_list) > 1: + resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] + attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] + + resnet_0_paths = renew_resnet_paths(resnets) + paths = renew_resnet_paths(resnets) + + meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + + # オリジナル: + # if ["conv.weight", "conv.bias"] in output_block_list.values(): + # index = list(output_block_list.values()).index(["conv.weight", "conv.bias"]) + + # biasとweightの順番に依存しないようにする:もっといいやり方がありそうだが + for l in output_block_list.values(): + l.sort() + + if ["conv.bias", "conv.weight"] in output_block_list.values(): + index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.bias" + ] + new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ + f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっている + # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 + if v2 and not config.get("use_linear_projection", False): + linear_transformer_to_conv(new_checkpoint) + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # support checkpoint without position_ids (invalid checkpoint) + if "text_model.embeddings.position_ids" not in text_model_dict: + text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + + return text_model_dict + + +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = None # 使われない??? + elif ".logit_scale" in key: + key = None # 使われない??? + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + # remove resblocks 23 + if ".resblocks.23." in key: + continue + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks.23." in key: + continue + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # rename or add position_ids + ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids" + if ANOTHER_POSITION_IDS_KEY in new_sd: + # waifu diffusion v1.4 + position_ids = new_sd[ANOTHER_POSITION_IDS_KEY] + del new_sd[ANOTHER_POSITION_IDS_KEY] + else: + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + + new_sd["text_model.embeddings.position_ids"] = position_ids + return new_sd + + +# endregion + + +# region Diffusers->StableDiffusion の変換コード +# convert_diffusers_to_original_stable_diffusion をコピーして修正している(ASL 2.0) + + +def conv_transformer_to_linear(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + + +def convert_unet_state_dict_to_sd(v2, unet_state_dict): + unet_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("out.0.weight", "conv_norm_out.weight"), + ("out.0.bias", "conv_norm_out.bias"), + ("out.2.weight", "conv_out.weight"), + ("out.2.bias", "conv_out.bias"), + ] + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + # loop over downblocks/upblocks + + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict + + +def controlnet_conversion_map(): + unet_conversion_map = [ + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("middle_block_out.0.weight", "controlnet_mid_block.weight"), + ("middle_block_out.0.bias", "controlnet_mid_block.bias"), + ] + + unet_conversion_map_resnet = [ + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + for j in range(2): + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"] + for i, hf_prefix in enumerate(controlnet_cond_embedding_names): + hf_prefix = f"controlnet_cond_embedding.{hf_prefix}." + sd_prefix = f"input_hint_block.{i*2}." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + for i in range(12): + hf_prefix = f"controlnet_down_blocks.{i}." + sd_prefix = f"zero_convs.{i}.0." + unet_conversion_map_layer.append((sd_prefix, hf_prefix)) + + return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer + + +def convert_controlnet_state_dict_to_sd(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[diffusers_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(diffusers_part, sd_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict): + unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map() + + mapping = {k: k for k in controlnet_state_dict.keys()} + for sd_name, diffusers_name in unet_conversion_map: + mapping[sd_name] = diffusers_name + for k, v in mapping.items(): + for sd_part, diffusers_part in unet_conversion_map_layer: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + for k, v in mapping.items(): + if "resnets" in v: + for sd_part, diffusers_part in unet_conversion_map_resnet: + v = v.replace(sd_part, diffusers_part) + mapping[k] = v + new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()} + return new_state_dict + + +# ================# +# VAE Conversion # +# ================# + + +def reshape_weight_for_sd(w): + # convert HF linear weights to SD conv2d weights + return w.reshape(*w.shape, 1, 1) + + +def convert_vae_state_dict(vae_state_dict): + vae_conversion_map = [ + # (stable-diffusion, HF Diffusers) + ("nin_shortcut", "conv_shortcut"), + ("norm_out", "conv_norm_out"), + ("mid.attn_1.", "mid_block.attentions.0."), + ] + + for i in range(4): + # down_blocks have two resnets + for j in range(2): + hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}." + sd_down_prefix = f"encoder.down.{i}.block.{j}." + vae_conversion_map.append((sd_down_prefix, hf_down_prefix)) + + if i < 3: + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0." + sd_downsample_prefix = f"down.{i}.downsample." + vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix)) + + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"up.{3-i}.upsample." + vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix)) + + # up_blocks have three resnets + # also, up blocks in hf are numbered in reverse from sd + for j in range(3): + hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}." + sd_up_prefix = f"decoder.up.{3-i}.block.{j}." + vae_conversion_map.append((sd_up_prefix, hf_up_prefix)) + + # this part accounts for mid blocks in both the encoder and the decoder + for i in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{i}." + sd_mid_res_prefix = f"mid.block_{i+1}." + vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + if diffusers.__version__ < "0.17.0": + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "query."), + ("k.", "key."), + ("v.", "value."), + ("proj_out.", "proj_attn."), + ] + else: + vae_conversion_map_attn = [ + # (stable-diffusion, HF Diffusers) + ("norm.", "group_norm."), + ("q.", "to_q."), + ("k.", "to_k."), + ("v.", "to_v."), + ("proj_out.", "to_out.0."), + ] + + mapping = {k: k for k in vae_state_dict.keys()} + for k, v in mapping.items(): + for sd_part, hf_part in vae_conversion_map: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + if "attentions" in k: + for sd_part, hf_part in vae_conversion_map_attn: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()} + weights_to_convert = ["q", "k", "v", "proj_out"] + for k, v in new_state_dict.items(): + for weight_name in weights_to_convert: + if f"mid.attn_1.{weight_name}.weight" in k: + # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1") + new_state_dict[k] = reshape_weight_for_sd(v) + + return new_state_dict + + +# endregion + +# region 自作のモデル読み書きなど + + +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + + +def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"): + # text encoderの格納形式が違うモデルに対応する ('text_model'がない) + TEXT_ENCODER_KEY_REPLACEMENTS = [ + ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."), + ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."), + ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."), + ] + + if is_safetensors(ckpt_path): + checkpoint = None + state_dict = load_file(ckpt_path) # , device) # may causes error + else: + checkpoint = torch.load(ckpt_path, map_location=device) + if "state_dict" in checkpoint: + state_dict = checkpoint["state_dict"] + else: + state_dict = checkpoint + checkpoint = None + + key_reps = [] + for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS: + for key in state_dict.keys(): + if key.startswith(rep_from): + new_key = rep_to + key[len(rep_from) :] + key_reps.append((key, new_key)) + + for key, new_key in key_reps: + state_dict[new_key] = state_dict[key] + del state_dict[key] + + return checkpoint, state_dict + + +def get_model_version_str_for_sd1_sd2(v2, v_parameterization): + # only for reference + version_str = "sd" + if v2: + version_str += "_v2" + else: + version_str += "_v1" + if v_parameterization: + version_str += "_v" + return version_str + + +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd + + +def save_stable_diffusion_checkpoint( + v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None +): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 参照元のcheckpointがない場合は最後の層を前の層から複製して作るなどダミーの重みを入れる + text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy) + update_sd("cond_stage_model.model.", text_enc_dict) + else: + text_enc_dict = text_encoder.state_dict() + update_sd("cond_stage_model.transformer.", text_enc_dict) + + # Convert the VAE + if vae is not None: + vae_dict = convert_vae_state_dict(vae.state_dict()) + update_sd("first_stage_model.", vae_dict) + + # Put together new checkpoint + key_count = len(state_dict.keys()) + new_ckpt = {"state_dict": state_dict} + + # epoch and global_step are sometimes not int + try: + if "epoch" in checkpoint: + epochs += checkpoint["epoch"] + if "global_step" in checkpoint: + steps += checkpoint["global_step"] + except: + pass + + new_ckpt["epoch"] = epochs + new_ckpt["global_step"] = steps + + if is_safetensors(output_file): + # TODO Tensor以外のdictの値を削除したほうがいいか + save_file(state_dict, output_file, metadata) + else: + torch.save(new_ckpt, output_file) + + return key_count + + +def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False): + if pretrained_model_name_or_path is None: + # load default settings for v1/v2 + if v2: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2 + else: + pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1 + + scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer") + if vae is None: + vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae") + + pipeline = StableDiffusionPipeline( + unet=unet, + text_encoder=text_encoder, + vae=vae, + scheduler=scheduler, + tokenizer=tokenizer, + safety_checker=None, + feature_extractor=None, + requires_safety_checker=None, + ) + pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors) + + +VAE_PREFIX = "first_stage_model." + + +def load_vae(vae_id, dtype): + print(f"load VAE: {vae_id}") + if os.path.isdir(vae_id) or not os.path.isfile(vae_id): + # Diffusers local/remote + try: + vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype) + except EnvironmentError as e: + print(f"exception occurs in loading vae: {e}") + print("retry with subfolder='vae'") + vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype) + return vae + + # local + vae_config = create_vae_diffusers_config() + + if vae_id.endswith(".bin"): + # SD 1.5 VAE on Huggingface + converted_vae_checkpoint = torch.load(vae_id, map_location="cpu") + else: + # StableDiffusion + vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu") + vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model + + # vae only or full model + full_model = False + for vae_key in vae_sd: + if vae_key.startswith(VAE_PREFIX): + full_model = True + break + if not full_model: + sd = {} + for key, value in vae_sd.items(): + sd[VAE_PREFIX + key] = value + vae_sd = sd + del sd + + # Convert the VAE model. + converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + return vae + + +# load state_dict without allocating new tensors +def _load_state_dict_on_device(model, state_dict, device, dtype=None): + # dtype will use fp32 as default + missing_keys = list(model.state_dict().keys() - state_dict.keys()) + unexpected_keys = list(state_dict.keys() - model.state_dict().keys()) + + # similar to model.load_state_dict() + if not missing_keys and not unexpected_keys: + for k in list(state_dict.keys()): + set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype) + return "" + + # error_msgs + error_msgs: List[str] = [] + if missing_keys: + error_msgs.insert(0, "Missing key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in missing_keys))) + if unexpected_keys: + error_msgs.insert(0, "Unexpected key(s) in state_dict: {}. ".format(", ".join('"{}"'.format(k) for k in unexpected_keys))) + + raise RuntimeError("Error(s) in loading state_dict for {}:\n\t{}".format(model.__class__.__name__, "\n\t".join(error_msgs))) + + +def convert_sdxl_text_encoder_2_checkpoint(checkpoint, max_length): + SDXL_KEY_PREFIX = "conditioner.embedders.1.model." + + # SD2のと、基本的には同じ。logit_scaleを後で使うので、それを追加で返す + # logit_scaleはcheckpointの保存時に使用する + def convert_key(key): + # common conversion + key = key.replace(SDXL_KEY_PREFIX + "transformer.", "text_model.encoder.") + key = key.replace(SDXL_KEY_PREFIX, "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." in key: + key = key.replace(".c_fc.", ".fc1.") + key = key.replace(".c_proj.", ".fc2.") + elif ".attn.out_proj" in key: + key = key.replace(".attn.out_proj.", ".self_attn.out_proj.") + elif ".attn.in_proj" in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in SD: {key}") + elif ".positional_embedding" in key: + key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight") + elif ".text_projection" in key: + key = key.replace("text_model.text_projection", "text_projection.weight") + elif ".logit_scale" in key: + key = None # 後で処理する + elif ".token_embedding" in key: + key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight") + elif ".ln_final" in key: + key = key.replace(".ln_final", ".final_layer_norm") + # ckpt from comfy has this key: text_model.encoder.text_model.embeddings.position_ids + elif ".embeddings.position_ids" in key: + key = None # remove this key: make position_ids by ourselves + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if ".resblocks" in key and ".attn.in_proj_" in key: + # 三つに分割 + values = torch.chunk(checkpoint[key], 3) + + key_suffix = ".weight" if "weight" in key else ".bias" + key_pfx = key.replace(SDXL_KEY_PREFIX + "transformer.resblocks.", "text_model.encoder.layers.") + key_pfx = key_pfx.replace("_weight", "") + key_pfx = key_pfx.replace("_bias", "") + key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.") + new_sd[key_pfx + "q_proj" + key_suffix] = values[0] + new_sd[key_pfx + "k_proj" + key_suffix] = values[1] + new_sd[key_pfx + "v_proj" + key_suffix] = values[2] + + # Create position_ids only for *old* transformers versions. + # After transformers==4.31, position_ids becomes a persistent=False buffer (so we musn't supply it) + # https://github.com/huggingface/transformers/pull/24505 + # https://github.com/mlfoundations/open_clip/pull/595 + if version.parse(transformers.__version__) <= version.parse("4.31"): + # original SD にはないので、position_idsを追加 + position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64) + new_sd["text_model.embeddings.position_ids"] = position_ids + + # logit_scale はDiffusersには含まれないが、保存時に戻したいので別途返す + logit_scale = checkpoint.get(SDXL_KEY_PREFIX + "logit_scale", None) + + return new_sd, logit_scale + +VAE_SCALE_FACTOR = 0.13025 +MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" + +# Diffusersの設定を読み込むための参照モデル +DIFFUSERS_REF_MODEL_ID_SDXL = "stabilityai/stable-diffusion-xl-base-1.0" + +DIFFUSERS_SDXL_UNET_CONFIG = { + "act_fn": "silu", + "addition_embed_type": "text_time", + "addition_embed_type_num_heads": 64, + "addition_time_embed_dim": 256, + "attention_head_dim": [5, 10, 20], + "block_out_channels": [320, 640, 1280], + "center_input_sample": False, + "class_embed_type": None, + "class_embeddings_concat": False, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 2048, + "cross_attention_norm": None, + "down_block_types": ["DownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D"], + "downsample_padding": 1, + "dual_cross_attention": False, + "encoder_hid_dim": None, + "encoder_hid_dim_type": None, + "flip_sin_to_cos": True, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_only_cross_attention": None, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock3DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": None, + "num_class_embeds": None, + "only_cross_attention": False, + "out_channels": 4, + "projection_class_embeddings_input_dim": 2816, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": False, + "resnet_time_scale_shift": "default", + "sample_size": 128, + "time_cond_proj_dim": None, + "time_embedding_act_fn": None, + "time_embedding_dim": None, + "time_embedding_type": "positional", + "timestep_post_act": None, + "transformer_layers_per_block": [1, 2, 10], + "up_block_types": ["CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "UpBlock3D"], + "upcast_attention": False, + "use_linear_projection": True, +} + + +NUM_TRAIN_TIMESTEPS = 1000 +BETA_START = 0.00085 +BETA_END = 0.0120 + +UNET_PARAMS_MODEL_CHANNELS = 320 +UNET_PARAMS_CHANNEL_MULT = [1, 2, 4, 4] +UNET_PARAMS_ATTENTION_RESOLUTIONS = [4, 2, 1] +UNET_PARAMS_IMAGE_SIZE = 64 # fixed from old invalid value `32` +UNET_PARAMS_IN_CHANNELS = 4 +UNET_PARAMS_OUT_CHANNELS = 4 +UNET_PARAMS_NUM_RES_BLOCKS = 2 +UNET_PARAMS_CONTEXT_DIM = 768 +UNET_PARAMS_NUM_HEADS = 8 +# UNET_PARAMS_USE_LINEAR_PROJECTION = False + +VAE_PARAMS_Z_CHANNELS = 4 +VAE_PARAMS_RESOLUTION = 256 +VAE_PARAMS_IN_CHANNELS = 3 +VAE_PARAMS_OUT_CH = 3 +VAE_PARAMS_CH = 128 +VAE_PARAMS_CH_MULT = [1, 2, 4, 4] +VAE_PARAMS_NUM_RES_BLOCKS = 2 + +# V2 +V2_UNET_PARAMS_ATTENTION_HEAD_DIM = [5, 10, 20, 20] +V2_UNET_PARAMS_CONTEXT_DIM = 1024 +# V2_UNET_PARAMS_USE_LINEAR_PROJECTION = True + +def convert_sdxl_unet_state_dict_to_diffusers(sd): + unet_conversion_map = make_unet_conversion_map() + + conversion_dict = {sd: hf for sd, hf in unet_conversion_map} + return convert_unet_state_dict(sd, conversion_dict) + +def make_unet_conversion_map(): + unet_conversion_map_layer = [] + + for i in range(3): # num_blocks is 3 in sdxl + # loop over downblocks/upblocks + for j in range(2): + # loop over resnets/attentions for downblocks + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + # if i > 0: commentout for sdxl + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{2}." # change for sdxl + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + unet_conversion_map_resnet = [ + # (stable-diffusion, HF Diffusers) + ("in_layers.0.", "norm1."), + ("in_layers.2.", "conv1."), + ("out_layers.0.", "norm2."), + ("out_layers.3.", "conv2."), + ("emb_layers.1.", "time_emb_proj."), + ("skip_connection.", "conv_shortcut."), + ] + + unet_conversion_map = [] + for sd, hf in unet_conversion_map_layer: + if "resnets" in hf: + for sd_res, hf_res in unet_conversion_map_resnet: + unet_conversion_map.append((sd + sd_res, hf + hf_res)) + else: + unet_conversion_map.append((sd, hf)) + + for j in range(2): + hf_time_embed_prefix = f"time_embedding.linear_{j+1}." + sd_time_embed_prefix = f"time_embed.{j*2}." + unet_conversion_map.append((sd_time_embed_prefix, hf_time_embed_prefix)) + + for j in range(2): + hf_label_embed_prefix = f"add_embedding.linear_{j+1}." + sd_label_embed_prefix = f"label_emb.0.{j*2}." + unet_conversion_map.append((sd_label_embed_prefix, hf_label_embed_prefix)) + + unet_conversion_map.append(("input_blocks.0.0.", "conv_in.")) + unet_conversion_map.append(("out.0.", "conv_norm_out.")) + unet_conversion_map.append(("out.2.", "conv_out.")) + + return unet_conversion_map + + +def convert_unet_state_dict(src_sd, conversion_map): + converted_sd = {} + for src_key, value in src_sd.items(): + # さすがに全部回すのは時間がかかるので右から要素を削りつつprefixを探す + src_key_fragments = src_key.split(".")[:-1] # remove weight/bias + while len(src_key_fragments) > 0: + src_key_prefix = ".".join(src_key_fragments) + "." + if src_key_prefix in conversion_map: + converted_prefix = conversion_map[src_key_prefix] + converted_key = converted_prefix + src_key[len(src_key_prefix) :] + converted_sd[converted_key] = value + break + src_key_fragments.pop(-1) + assert len(src_key_fragments) > 0, f"key {src_key} not found in conversion map" + + return converted_sd - return animation_pipeline diff --git a/animatediff/utils/xl_lora_util.py b/animatediff/utils/xl_lora_util.py new file mode 100644 index 0000000..3d26044 --- /dev/null +++ b/animatediff/utils/xl_lora_util.py @@ -0,0 +1,1095 @@ +# Diffusersのコードをベースとした sd_xl_baseのU-Net +# state dictの形式をSDXLに合わせてある + +""" + target: sgm.modules.diffusionmodules.openaimodel.UNetModel + params: + adm_in_channels: 2816 + num_classes: sequential + use_checkpoint: True + in_channels: 4 + out_channels: 4 + model_channels: 320 + attention_resolutions: [4, 2] + num_res_blocks: 2 + channel_mult: [1, 2, 4] + num_head_channels: 64 + use_spatial_transformer: True + use_linear_in_transformer: True + transformer_depth: [1, 2, 10] # note: the first is unused (due to attn_res starting at 2) 32, 16, 8 --> 64, 32, 16 + context_dim: 2048 + spatial_transformer_attn_type: softmax-xformers + legacy: False +""" + +import math +from types import SimpleNamespace +from typing import Optional +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import functional as F +from einops import rearrange + + +IN_CHANNELS: int = 4 +OUT_CHANNELS: int = 4 +ADM_IN_CHANNELS: int = 2816 +CONTEXT_DIM: int = 2048 +MODEL_CHANNELS: int = 320 +TIME_EMBED_DIM = 320 * 4 + +USE_REENTRANT = True + +# region memory effcient attention + +# FlashAttentionを使うCrossAttention +# based on https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/memory_efficient_attention_pytorch/flash_attention.py +# LICENSE MIT https://github.com/lucidrains/memory-efficient-attention-pytorch/blob/main/LICENSE + +# constants + +EPSILON = 1e-6 + +# helper functions + + +def exists(val): + return val is not None + + +def default(val, d): + return val if exists(val) else d + + +# flash attention forwards and backwards + +# https://arxiv.org/abs/2205.14135 + + +class FlashAttentionFunction(torch.autograd.Function): + @staticmethod + @torch.no_grad() + def forward(ctx, q, k, v, mask, causal, q_bucket_size, k_bucket_size): + """Algorithm 2 in the paper""" + + device = q.device + dtype = q.dtype + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + o = torch.zeros_like(q) + all_row_sums = torch.zeros((*q.shape[:-1], 1), dtype=dtype, device=device) + all_row_maxes = torch.full((*q.shape[:-1], 1), max_neg_value, dtype=dtype, device=device) + + scale = q.shape[-1] ** -0.5 + + if not exists(mask): + mask = (None,) * math.ceil(q.shape[-2] / q_bucket_size) + else: + mask = rearrange(mask, "b n -> b 1 1 n") + mask = mask.split(q_bucket_size, dim=-1) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + mask, + all_row_sums.split(q_bucket_size, dim=-2), + all_row_maxes.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, row_mask, row_sums, row_maxes) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if exists(row_mask): + attn_weights.masked_fill_(~row_mask, max_neg_value) + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + block_row_maxes = attn_weights.amax(dim=-1, keepdims=True) + attn_weights -= block_row_maxes + exp_weights = torch.exp(attn_weights) + + if exists(row_mask): + exp_weights.masked_fill_(~row_mask, 0.0) + + block_row_sums = exp_weights.sum(dim=-1, keepdims=True).clamp(min=EPSILON) + + new_row_maxes = torch.maximum(block_row_maxes, row_maxes) + + exp_values = torch.einsum("... i j, ... j d -> ... i d", exp_weights, vc) + + exp_row_max_diff = torch.exp(row_maxes - new_row_maxes) + exp_block_row_max_diff = torch.exp(block_row_maxes - new_row_maxes) + + new_row_sums = exp_row_max_diff * row_sums + exp_block_row_max_diff * block_row_sums + + oc.mul_((row_sums / new_row_sums) * exp_row_max_diff).add_((exp_block_row_max_diff / new_row_sums) * exp_values) + + row_maxes.copy_(new_row_maxes) + row_sums.copy_(new_row_sums) + + ctx.args = (causal, scale, mask, q_bucket_size, k_bucket_size) + ctx.save_for_backward(q, k, v, o, all_row_sums, all_row_maxes) + + return o + + @staticmethod + @torch.no_grad() + def backward(ctx, do): + """Algorithm 4 in the paper""" + + causal, scale, mask, q_bucket_size, k_bucket_size = ctx.args + q, k, v, o, l, m = ctx.saved_tensors + + device = q.device + + max_neg_value = -torch.finfo(q.dtype).max + qk_len_diff = max(k.shape[-2] - q.shape[-2], 0) + + dq = torch.zeros_like(q) + dk = torch.zeros_like(k) + dv = torch.zeros_like(v) + + row_splits = zip( + q.split(q_bucket_size, dim=-2), + o.split(q_bucket_size, dim=-2), + do.split(q_bucket_size, dim=-2), + mask, + l.split(q_bucket_size, dim=-2), + m.split(q_bucket_size, dim=-2), + dq.split(q_bucket_size, dim=-2), + ) + + for ind, (qc, oc, doc, row_mask, lc, mc, dqc) in enumerate(row_splits): + q_start_index = ind * q_bucket_size - qk_len_diff + + col_splits = zip( + k.split(k_bucket_size, dim=-2), + v.split(k_bucket_size, dim=-2), + dk.split(k_bucket_size, dim=-2), + dv.split(k_bucket_size, dim=-2), + ) + + for k_ind, (kc, vc, dkc, dvc) in enumerate(col_splits): + k_start_index = k_ind * k_bucket_size + + attn_weights = torch.einsum("... i d, ... j d -> ... i j", qc, kc) * scale + + if causal and q_start_index < (k_start_index + k_bucket_size - 1): + causal_mask = torch.ones((qc.shape[-2], kc.shape[-2]), dtype=torch.bool, device=device).triu( + q_start_index - k_start_index + 1 + ) + attn_weights.masked_fill_(causal_mask, max_neg_value) + + exp_attn_weights = torch.exp(attn_weights - mc) + + if exists(row_mask): + exp_attn_weights.masked_fill_(~row_mask, 0.0) + + p = exp_attn_weights / lc + + dv_chunk = torch.einsum("... i j, ... i d -> ... j d", p, doc) + dp = torch.einsum("... i d, ... j d -> ... i j", doc, vc) + + D = (doc * oc).sum(dim=-1, keepdims=True) + ds = p * scale * (dp - D) + + dq_chunk = torch.einsum("... i j, ... j d -> ... i d", ds, kc) + dk_chunk = torch.einsum("... i j, ... i d -> ... j d", ds, qc) + + dqc.add_(dq_chunk) + dkc.add_(dk_chunk) + dvc.add_(dv_chunk) + + return dq, dk, dv, None, None, None, None + + +# endregion + + +def get_parameter_dtype(parameter: torch.nn.Module): + return next(parameter.parameters()).dtype + + +def get_parameter_device(parameter: torch.nn.Module): + return next(parameter.parameters()).device + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings: flipped from Diffusers original ver because always flip_sin_to_cos=True + emb = torch.cat([torch.cos(emb), torch.sin(emb)], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + if self.weight.dtype != torch.float32: + return super().forward(x) + return super().forward(x.float()).type(x.dtype) + + +class ResnetBlock2D(nn.Module): + def __init__( + self, + in_channels, + out_channels, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + + self.in_layers = nn.Sequential( + GroupNorm32(32, in_channels), + nn.SiLU(), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + self.emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(TIME_EMBED_DIM, out_channels)) + + self.out_layers = nn.Sequential( + GroupNorm32(32, out_channels), + nn.SiLU(), + nn.Identity(), # to make state_dict compatible with original model + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1), + ) + + if in_channels != out_channels: + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + else: + self.skip_connection = nn.Identity() + + self.gradient_checkpointing = False + + def forward_body(self, x, emb): + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + h = h + emb_out[:, :, None, None] + h = self.out_layers(h) + x = self.skip_connection(x) + return x + h + + def forward(self, x, emb): + if self.training and self.gradient_checkpointing: + # print("ResnetBlock2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + x = torch.utils.checkpoint.checkpoint(create_custom_forward(self.forward_body), x, emb, use_reentrant=USE_REENTRANT) + else: + x = self.forward_body(x, emb) + + return x + + +class Downsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + + self.channels = channels + self.out_channels = out_channels + + self.op = nn.Conv2d(self.channels, self.out_channels, 3, stride=2, padding=1) + + self.gradient_checkpointing = False + + def forward_body(self, hidden_states): + assert hidden_states.shape[1] == self.channels + hidden_states = self.op(hidden_states) + + return hidden_states + + def forward(self, hidden_states): + if self.training and self.gradient_checkpointing: + # print("Downsample2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, use_reentrant=USE_REENTRANT + ) + else: + hidden_states = self.forward_body(hidden_states) + + return hidden_states + + +class CrossAttention(nn.Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast_attention: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=False) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=False) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + # no dropout here + + self.use_memory_efficient_attention_xformers = False + self.use_memory_efficient_attention_mem_eff = False + self.use_sdpa = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + self.use_memory_efficient_attention_xformers = xformers + self.use_memory_efficient_attention_mem_eff = mem_eff + + def set_use_sdpa(self, sdpa): + self.use_sdpa = sdpa + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states, context=None, mask=None): + if self.use_memory_efficient_attention_xformers: + return self.forward_memory_efficient_xformers(hidden_states, context, mask) + if self.use_memory_efficient_attention_mem_eff: + return self.forward_memory_efficient_mem_eff(hidden_states, context, mask) + if self.use_sdpa: + return self.forward_sdpa(hidden_states, context, mask) + + query = self.to_q(hidden_states) + context = context if context is not None else hidden_states + key = self.to_k(context) + value = self.to_v(context) + + query = self.reshape_heads_to_batch_dim(query) + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + hidden_states = self._attention(query, key, value) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + # hidden_states = self.to_out[1](hidden_states) # no dropout + return hidden_states + + def _attention(self, query, key, value): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + # TODO support Hypernetworks + def forward_memory_efficient_xformers(self, x, context=None, mask=None): + import xformers.ops + + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b n h d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) # 最適なのを選んでくれる + del q, k, v + + out = rearrange(out, "b n h d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + def forward_memory_efficient_mem_eff(self, x, context=None, mask=None): + flash_func = FlashAttentionFunction + + q_bucket_size = 512 + k_bucket_size = 1024 + + h = self.heads + q = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k = self.to_k(context) + v = self.to_v(context) + del context, x + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v)) + + out = flash_func.apply(q, k, v, mask, False, q_bucket_size, k_bucket_size) + + out = rearrange(out, "b h n d -> b n (h d)") + + out = self.to_out[0](out) + return out + + def forward_sdpa(self, x, context=None, mask=None): + h = self.heads + q_in = self.to_q(x) + context = context if context is not None else x + context = context.to(x.dtype) + k_in = self.to_k(context) + v_in = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in + + out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + out = rearrange(out, "b h n d -> b n (h d)", h=h) + + out = self.to_out[0](out) + return out + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class FeedForward(nn.Module): + def __init__( + self, + dim: int, + ): + super().__init__() + inner_dim = int(dim * 4) # mult is always 4 + + self.net = nn.ModuleList([]) + # project in + self.net.append(GEGLU(dim, inner_dim)) + # project dropout + self.net.append(nn.Identity()) # nn.Dropout(0)) # dummy for dropout with 0 + # project out + self.net.append(nn.Linear(inner_dim, dim)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, dim: int, num_attention_heads: int, attention_head_dim: int, cross_attention_dim: int, upcast_attention: bool = False + ): + super().__init__() + + self.gradient_checkpointing = False + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + cross_attention_dim=None, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + self.ff = FeedForward(dim) + + # 2. Cross-Attn + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + upcast_attention=upcast_attention, + ) + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool): + self.attn1.set_use_memory_efficient_attention(xformers, mem_eff) + self.attn2.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool): + self.attn1.set_use_sdpa(sdpa) + self.attn2.set_use_sdpa(sdpa) + + def forward_body(self, hidden_states, context=None, timestep=None): + # 1. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + hidden_states = self.attn1(norm_hidden_states) + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + def forward(self, hidden_states, context=None, timestep=None): + if self.training and self.gradient_checkpointing: + # print("BasicTransformerBlock: checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + output = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, context, timestep, use_reentrant=USE_REENTRANT + ) + else: + output = self.forward_body(hidden_states, context, timestep) + + return output + + +class Transformer2DModel(nn.Module): + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + use_linear_projection: bool = False, + upcast_attention: bool = False, + num_transformer_layers: int = 1, + ): + super().__init__() + self.in_channels = in_channels + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.use_linear_projection = use_linear_projection + + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + # self.norm = GroupNorm32(32, in_channels, eps=1e-6, affine=True) + + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + blocks = [] + for _ in range(num_transformer_layers): + blocks.append( + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + upcast_attention=upcast_attention, + ) + ) + + self.transformer_blocks = nn.ModuleList(blocks) + + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + self.gradient_checkpointing = False + + def set_use_memory_efficient_attention(self, xformers, mem_eff): + for transformer in self.transformer_blocks: + transformer.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa): + for transformer in self.transformer_blocks: + transformer.set_use_sdpa(sdpa) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None): + # 1. Input + batch, _, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep) + + # 3. Output + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + return output + + +class Upsample2D(nn.Module): + def __init__(self, channels, out_channels): + super().__init__() + self.channels = channels + self.out_channels = out_channels + self.conv = nn.Conv2d(self.channels, self.out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward_body(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + # TODO(Suraj): Remove this cast once the issue is fixed in PyTorch + # https://github.com/pytorch/pytorch/issues/86679 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + hidden_states = self.conv(hidden_states) + + return hidden_states + + def forward(self, hidden_states, output_size=None): + if self.training and self.gradient_checkpointing: + # print("Upsample2D: gradient_checkpointing") + + def create_custom_forward(func): + def custom_forward(*inputs): + return func(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(self.forward_body), hidden_states, output_size, use_reentrant=USE_REENTRANT + ) + else: + hidden_states = self.forward_body(hidden_states, output_size) + + return hidden_states + + +class SdxlUNet2DConditionModel(nn.Module): + _supports_gradient_checkpointing = True + + def __init__( + self, + **kwargs, + ): + super().__init__() + + self.in_channels = IN_CHANNELS + self.out_channels = OUT_CHANNELS + self.model_channels = MODEL_CHANNELS + self.time_embed_dim = TIME_EMBED_DIM + self.adm_in_channels = ADM_IN_CHANNELS + + self.gradient_checkpointing = False + # self.sample_size = sample_size + + # time embedding + self.time_embed = nn.Sequential( + nn.Linear(self.model_channels, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + + # label embedding + self.label_emb = nn.Sequential( + nn.Sequential( + nn.Linear(self.adm_in_channels, self.time_embed_dim), + nn.SiLU(), + nn.Linear(self.time_embed_dim, self.time_embed_dim), + ) + ) + + # input + self.input_blocks = nn.ModuleList( + [ + nn.Sequential( + nn.Conv2d(self.in_channels, self.model_channels, kernel_size=3, padding=(1, 1)), + ) + ] + ) + + # level 0 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=1 * self.model_channels, + out_channels=1 * self.model_channels, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + self.input_blocks.append( + nn.Sequential( + Downsample2D( + channels=1 * self.model_channels, + out_channels=1 * self.model_channels, + ), + ) + ) + + # level 1 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=(1 if i == 0 else 2) * self.model_channels, + out_channels=2 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=2 * self.model_channels // 64, + attention_head_dim=64, + in_channels=2 * self.model_channels, + num_transformer_layers=2, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + self.input_blocks.append( + nn.Sequential( + Downsample2D( + channels=2 * self.model_channels, + out_channels=2 * self.model_channels, + ), + ) + ) + + # level 2 + for i in range(2): + layers = [ + ResnetBlock2D( + in_channels=(2 if i == 0 else 4) * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + self.input_blocks.append(nn.ModuleList(layers)) + + # mid + self.middle_block = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ResnetBlock2D( + in_channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ), + ] + ) + + # output + self.output_blocks = nn.ModuleList([]) + + # level 2 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=4 * self.model_channels + (4 if i <= 1 else 2) * self.model_channels, + out_channels=4 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=4 * self.model_channels // 64, + attention_head_dim=64, + in_channels=4 * self.model_channels, + num_transformer_layers=10, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + if i == 2: + layers.append( + Upsample2D( + channels=4 * self.model_channels, + out_channels=4 * self.model_channels, + ) + ) + + self.output_blocks.append(nn.ModuleList(layers)) + + # level 1 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=2 * self.model_channels + (4 if i == 0 else (2 if i == 1 else 1)) * self.model_channels, + out_channels=2 * self.model_channels, + ), + Transformer2DModel( + num_attention_heads=2 * self.model_channels // 64, + attention_head_dim=64, + in_channels=2 * self.model_channels, + num_transformer_layers=2, + use_linear_projection=True, + cross_attention_dim=2048, + ), + ] + if i == 2: + layers.append( + Upsample2D( + channels=2 * self.model_channels, + out_channels=2 * self.model_channels, + ) + ) + + self.output_blocks.append(nn.ModuleList(layers)) + + # level 0 + for i in range(3): + layers = [ + ResnetBlock2D( + in_channels=1 * self.model_channels + (2 if i == 0 else 1) * self.model_channels, + out_channels=1 * self.model_channels, + ), + ] + + self.output_blocks.append(nn.ModuleList(layers)) + + # output + self.out = nn.ModuleList( + [GroupNorm32(32, self.model_channels), nn.SiLU(), nn.Conv2d(self.model_channels, self.out_channels, 3, padding=1)] + ) + + # region diffusers compatibility + def prepare_config(self): + self.config = SimpleNamespace() + + @property + def dtype(self) -> torch.dtype: + # `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + return get_parameter_dtype(self) + + @property + def device(self) -> torch.device: + # `torch.device`: The device on which the module is (assuming that all the module parameters are on the same device). + return get_parameter_device(self) + + def set_attention_slice(self, slice_size): + raise NotImplementedError("Attention slicing is not supported for this model.") + + def is_gradient_checkpointing(self) -> bool: + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + self.gradient_checkpointing = True + self.set_gradient_checkpointing(value=True) + + def disable_gradient_checkpointing(self): + self.gradient_checkpointing = False + self.set_gradient_checkpointing(value=False) + + def set_use_memory_efficient_attention(self, xformers: bool, mem_eff: bool) -> None: + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block: + if hasattr(module, "set_use_memory_efficient_attention"): + # print(module.__class__.__name__) + module.set_use_memory_efficient_attention(xformers, mem_eff) + + def set_use_sdpa(self, sdpa: bool) -> None: + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block: + if hasattr(module, "set_use_sdpa"): + module.set_use_sdpa(sdpa) + + def set_gradient_checkpointing(self, value=False): + blocks = self.input_blocks + [self.middle_block] + self.output_blocks + for block in blocks: + for module in block.modules(): + if hasattr(module, "gradient_checkpointing"): + # print(module.__class__.__name__, module.gradient_checkpointing, "->", value) + module.gradient_checkpointing = value + + # endregion + + def forward(self, x, timesteps=None, context=None, y=None, **kwargs): + # broadcast timesteps to batch dimension + timesteps = timesteps.expand(x.shape[0]) + + hs = [] + t_emb = get_timestep_embedding(timesteps, self.model_channels) # , repeat_only=False) + t_emb = t_emb.to(x.dtype) + emb = self.time_embed(t_emb) + + assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}" + assert x.dtype == y.dtype, f"dtype mismatch: {x.dtype} != {y.dtype}" + # assert x.dtype == self.dtype + emb = emb + self.label_emb(y) + + def call_module(module, h, emb, context): + x = h + for layer in module: + # print(layer.__class__.__name__, x.dtype, emb.dtype, context.dtype if context is not None else None) + if isinstance(layer, ResnetBlock2D): + x = layer(x, emb) + elif isinstance(layer, Transformer2DModel): + x = layer(x, context) + else: + x = layer(x) + return x + + # h = x.type(self.dtype) + h = x + for module in self.input_blocks: + h = call_module(module, h, emb, context) + hs.append(h) + + h = call_module(self.middle_block, h, emb, context) + + for module in self.output_blocks: + h = torch.cat([h, hs.pop()], dim=1) + h = call_module(module, h, emb, context) + + h = h.type(x.dtype) + h = call_module(self.out, h, emb, context) + + return h + + diff --git a/configs/inference/inference-v2.yaml b/configs/inference/inference-v2.yaml deleted file mode 100644 index a33bc12..0000000 --- a/configs/inference/inference-v2.yaml +++ /dev/null @@ -1,27 +0,0 @@ -unet_additional_kwargs: - use_inflated_groupnorm: true - unet_use_cross_frame_attention: false - unet_use_temporal_attention: false - use_motion_module: true - motion_module_resolutions: - - 1 - - 2 - - 4 - - 8 - motion_module_mid_block: true - motion_module_decoder_only: false - motion_module_type: Vanilla - motion_module_kwargs: - num_attention_heads: 8 - num_transformer_block: 1 - attention_block_types: - - Temporal_Self - - Temporal_Self - temporal_position_encoding: true - temporal_position_encoding_max_len: 32 - temporal_attention_dim_div: 1 - -noise_scheduler_kwargs: - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "linear" diff --git a/configs/inference/inference-v1.yaml b/configs/inference/inference.yaml similarity index 64% rename from configs/inference/inference-v1.yaml rename to configs/inference/inference.yaml index 86f3777..e9e52a3 100644 --- a/configs/inference/inference-v1.yaml +++ b/configs/inference/inference.yaml @@ -1,6 +1,4 @@ unet_additional_kwargs: - unet_use_cross_frame_attention: false - unet_use_temporal_attention: false use_motion_module: true motion_module_resolutions: - 1 @@ -8,7 +6,6 @@ unet_additional_kwargs: - 4 - 8 motion_module_mid_block: false - motion_module_decoder_only: false motion_module_type: Vanilla motion_module_kwargs: num_attention_heads: 8 @@ -17,10 +14,14 @@ unet_additional_kwargs: - Temporal_Self - Temporal_Self temporal_position_encoding: true - temporal_position_encoding_max_len: 24 + temporal_position_encoding_max_len: 32 temporal_attention_dim_div: 1 + noise_scheduler_kwargs: - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "linear" + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.020 + beta_schedule: "scaled_linear" + + diff --git a/configs/prompts/1-ToonYou.yaml b/configs/prompts/1-ToonYou.yaml deleted file mode 100644 index 925e933..0000000 --- a/configs/prompts/1-ToonYou.yaml +++ /dev/null @@ -1,23 +0,0 @@ -ToonYou: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors" - lora_model_path: "" - - seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress" - - "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes," - - "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern" - - "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle," - - n_prompt: - - "" - - "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth" - - "" - - "" diff --git a/configs/prompts/1-original_sdxl.yaml b/configs/prompts/1-original_sdxl.yaml new file mode 100644 index 0000000..4ba4508 --- /dev/null +++ b/configs/prompts/1-original_sdxl.yaml @@ -0,0 +1,14 @@ +motion_module_path: Path to Motion Module + +seed: -1 + +guidance_scale: 8.5 +steps: 100 + +prompt: + - "A panda standing on a surfboard in the ocean in sunset, 4k, high resolution.Realistic, Cinematic, high resolution" + - "A disoriented astronaut, lost in a galaxy of swirling colors, floating in zero gravity, grasping at memories, poignant loneliness, stunning realism, cosmic chaos, emotional depth, 12K, hyperrealism, unforgettable, mixed media, celestial, dark, introspective" + +n_prompt: + - "" + - "" \ No newline at end of file diff --git a/configs/prompts/2-DynaVision.yaml b/configs/prompts/2-DynaVision.yaml new file mode 100644 index 0000000..74452aa --- /dev/null +++ b/configs/prompts/2-DynaVision.yaml @@ -0,0 +1,17 @@ +base_model_path: Path to DynaVision +motion_module_path: Path to Motion Module + +seed: -1 + +guidance_scale: 8.5 +steps: 100 + +prompt: + - "cinematic color grading lighting vintage realistic film grain scratches celluloid analog cool shadows warm highlights soft focus actor directed cinematography technicolor , confused, looking around scared , + / A lanky blonde haired man on vacation enjoying the local party scene in Brisbane at dawn" + - "Pixel Art pixelated pixel pixel , extremely content happy smile , + / A fit pink haired woman at night in the city, holding a trick or treat bag, wearing a kitty ears halloween costume, (big eyes:1.3) by Dreamworks Studios" + - "anime artwork of [water|a galaxy|thunderstorm] inside a bottle" + +n_prompt: + - "" + - "" + - "hand" \ No newline at end of file diff --git a/configs/prompts/2-Lyriel.yaml b/configs/prompts/2-Lyriel.yaml deleted file mode 100644 index 583b0ce..0000000 --- a/configs/prompts/2-Lyriel.yaml +++ /dev/null @@ -1,23 +0,0 @@ -Lyriel: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors" - lora_model_path: "" - - seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange" - - "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal" - - "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray" - - "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown." - - n_prompt: - - "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration" - - "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular" - - "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome" - - "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render" diff --git a/configs/prompts/3-DreamShaper.yaml b/configs/prompts/3-DreamShaper.yaml new file mode 100644 index 0000000..3855ad3 --- /dev/null +++ b/configs/prompts/3-DreamShaper.yaml @@ -0,0 +1,16 @@ +ckpt_path: Path to DreamShaper +motion_module_path: Path to Motion Module + +seed: -1 + +guidance_scale: 10 +steps: 100 + +prompt: + - "photo of a supercar, 8k uhd, high quality, road, sunset, motion blur, depth blur, cinematic, filmic image 4k, 8k. Natural sunlight, vibrant color, reflections" + - "night view, sea beach landscape, with a lighthouse, trending on artstation, by Noah Bradley, highly detailed, high quality, 4k HDR, path tracking, calm landscape, high consistency, soft lighting, in Serene grids, transcendent fields, quiet repetitions, sublime reduction in Brooding landscapes, epic scale, German myth, layered symbolic density" + +n_prompt: + - "embedding:BadDream, embedding:UnrealisticDream" + - "(worst quality), (low quality), (normal quality), lowres, blurry, hand, signature, normal quality" + \ No newline at end of file diff --git a/configs/prompts/3-RcnzCartoon.yaml b/configs/prompts/3-RcnzCartoon.yaml deleted file mode 100644 index fd76a2b..0000000 --- a/configs/prompts/3-RcnzCartoon.yaml +++ /dev/null @@ -1,23 +0,0 @@ -RcnzCartoon: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors" - lora_model_path: "" - - seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded" - - "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face" - - "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes" - - "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering" - - n_prompt: - - "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation" - - "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular" - - "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face," - - "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand" diff --git a/configs/prompts/4-DeepBlue.yaml b/configs/prompts/4-DeepBlue.yaml new file mode 100644 index 0000000..f982043 --- /dev/null +++ b/configs/prompts/4-DeepBlue.yaml @@ -0,0 +1,20 @@ +ckpt_path: Path to DeepBlud model +motion_module_path: Path to Motion Module + +seed: -1 + +guidance_scale: 7 +steps: 100 + +prompt: + - "highres,best quality,natural, A photo of a mummified Miffy, the iconic white bunny character, lying peacefully on a velvet cushion inside an intricately decorated glass display case, with its fur preserved in a pristine condition, its eyes closed peacefully, and its little paws gently crossed over each other, evoking a sense of tranquility and reverence.,cinematic photo official art, 8k wallpaper,ultra detailed, aesthetic quality,photorealistic,entangle,dynamic angle,the most beautiful form of chaos,elegant,a brutalist designed,vivid colours,romanticism,atmospheric . 35mm photograph, film, bokeh, professional, 4k, highly detailed, skin detail realistic, ultra realistic,Perspective" + - "highres,best quality,natural, Cat pressing the Delete key on the keyboard,cinematic photo official art, 8k wallpaper,ultra detailed, aesthetic quality,photorealistic,entangle,dynamic angle,the most beautiful form of chaos,elegant,a brutalist designed,vivid colours,romanticism,atmospheric . 35mm photograph, film, bokeh, professional, 4k, highly detailed, skin detail realistic, ultra realistic,Perspective" + - "highres,best quality,natural, A psychedelic world spreads out below, Witch kawaii girl floating in the sky Wearing a witch hat Frilled fantasy costume Blonde straight hair,cinematic photo official art,unity 8k wallpaper,ultra detailed,aesthetic,masterpiece,best quality,photorealistic,entangle,mandala,tangle,entangle,cstasy of flower,dynamic angle,the most beautiful form of chaos,elegant,a brutalist designed,vivid colours,romanticism,atmospheric . 35mm photograph, film, bokeh, professional, 4k, highly detailed, skin detail realistic, ultra realistic," + +n_prompt: + - "bad quality,worst quality" + - "bad quality,worst quality" + - "bad quality,worst quality" + + + diff --git a/configs/prompts/4-MajicMix.yaml b/configs/prompts/4-MajicMix.yaml deleted file mode 100644 index 6680bd5..0000000 --- a/configs/prompts/4-MajicMix.yaml +++ /dev/null @@ -1,23 +0,0 @@ -MajicMix: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors" - lora_model_path: "" - - seed: [1572448948722921032, 1099474677988590681, 6488833139725635347, 18339859844376517918] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic" - - "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting" - - "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below" - - "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic" - - n_prompt: - - "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles" - - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" - - "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome" - - "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people" diff --git a/configs/prompts/5-RealisticVision.yaml b/configs/prompts/5-RealisticVision.yaml deleted file mode 100644 index 520619c..0000000 --- a/configs/prompts/5-RealisticVision.yaml +++ /dev/null @@ -1,23 +0,0 @@ -RealisticVision: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: [5658137986800322009, 12099779162349365895, 10499524853910852697, 16768009035333711932] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot" - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" - - n_prompt: - - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" diff --git a/configs/prompts/6-Tusun.yaml b/configs/prompts/6-Tusun.yaml deleted file mode 100644 index 1275a20..0000000 --- a/configs/prompts/6-Tusun.yaml +++ /dev/null @@ -1,21 +0,0 @@ -Tusun: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/moonfilm_reality20.safetensors" - lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors" - lora_alpha: 0.6 - - seed: [10154078483724687116, 2664393535095473805, 4231566096207622938, 1713349740448094493] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - - "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - - "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing" - - "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body" - - n_prompt: - - "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative" diff --git a/configs/prompts/7-FilmVelvia.yaml b/configs/prompts/7-FilmVelvia.yaml deleted file mode 100644 index 46cc163..0000000 --- a/configs/prompts/7-FilmVelvia.yaml +++ /dev/null @@ -1,24 +0,0 @@ -FilmVelvia: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors" - lora_model_path: "models/DreamBooth_LoRA/FilmVelvia2.safetensors" - lora_alpha: 0.6 - - seed: [358675358833372813, 3519455280971923743, 11684545350557985081, 8696855302100399877] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name" - - ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir" - - "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark" - - "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, " - - n_prompt: - - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - - "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" - - "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg" diff --git a/configs/prompts/8-GhibliBackground.yaml b/configs/prompts/8-GhibliBackground.yaml deleted file mode 100644 index 45145db..0000000 --- a/configs/prompts/8-GhibliBackground.yaml +++ /dev/null @@ -1,21 +0,0 @@ -GhibliBackground: - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/CounterfeitV30_25.safetensors" - lora_model_path: "models/DreamBooth_LoRA/lora_Ghibli_n3.safetensors" - lora_alpha: 1.0 - - seed: [8775748474469046618, 5893874876080607656, 11911465742147695752, 12437784838692000640] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall" - - "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter" - - ",mysterious sea area, fantasy,build,concept" - - "Tomb Raider,Scenography,Old building" - - n_prompt: - - "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality" diff --git a/configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml b/configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml deleted file mode 100644 index 37d1aed..0000000 --- a/configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml +++ /dev/null @@ -1,189 +0,0 @@ -ZoomIn: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_ZoomIn.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -ZoomOut: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_ZoomOut.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -PanLeft: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_PanLeft.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -PanRight: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_PanRight.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -TiltUp: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_TiltUp.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -TiltDown: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_TiltDown.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -RollingAnticlockwise: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - - -RollingClockwise: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - motion_module_lora_configs: - - path: "models/MotionLoRA/v2_lora_RollingClockwise.ckpt" - alpha: 1.0 - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: 45987230 - steps: 25 - guidance_scale: 7.5 - - prompt: - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - n_prompt: - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" diff --git a/configs/prompts/v2/5-RealisticVision.yaml b/configs/prompts/v2/5-RealisticVision.yaml deleted file mode 100644 index 423b64f..0000000 --- a/configs/prompts/v2/5-RealisticVision.yaml +++ /dev/null @@ -1,23 +0,0 @@ -RealisticVision: - inference_config: "configs/inference/inference-v2.yaml" - motion_module: - - "models/Motion_Module/mm_sd_v15_v2.ckpt" - - dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors" - lora_model_path: "" - - seed: [13100322578370451493, 14752961627088720670, 9329399085567825781, 16987697414827649302] - steps: 25 - guidance_scale: 7.5 - - prompt: - - "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot" - - "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3" - - "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain" - - n_prompt: - - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - - "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck" - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" - - "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation" diff --git a/configs/training/image_finetune.yaml b/configs/training/image_finetune.yaml deleted file mode 100644 index ea05fd1..0000000 --- a/configs/training/image_finetune.yaml +++ /dev/null @@ -1,48 +0,0 @@ -image_finetune: true - -output_dir: "outputs" -pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" - -noise_scheduler_kwargs: - num_train_timesteps: 1000 - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "scaled_linear" - steps_offset: 1 - clip_sample: false - -train_data: - csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv" - video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val" - sample_size: 256 - -validation_data: - prompts: - - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." - - "A drone view of celebration with Christma tree and fireworks, starry sky - background." - - "Robot dancing in times square." - - "Pacific coast, carmel by the sea ocean and waves." - num_inference_steps: 25 - guidance_scale: 8. - -trainable_modules: - - "." - -unet_checkpoint_path: "" - -learning_rate: 1.e-5 -train_batch_size: 50 - -max_train_epoch: -1 -max_train_steps: 100 -checkpointing_epochs: -1 -checkpointing_steps: 60 - -validation_steps: 5000 -validation_steps_tuple: [2, 50] - -global_seed: 42 -mixed_precision_training: true -enable_xformers_memory_efficient_attention: True - -is_debug: False diff --git a/configs/training/training.yaml b/configs/training/training.yaml deleted file mode 100644 index 626f05c..0000000 --- a/configs/training/training.yaml +++ /dev/null @@ -1,66 +0,0 @@ -image_finetune: false - -output_dir: "outputs" -pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5" - -unet_additional_kwargs: - use_motion_module : true - motion_module_resolutions : [ 1,2,4,8 ] - unet_use_cross_frame_attention : false - unet_use_temporal_attention : false - - motion_module_type: Vanilla - motion_module_kwargs: - num_attention_heads : 8 - num_transformer_block : 1 - attention_block_types : [ "Temporal_Self", "Temporal_Self" ] - temporal_position_encoding : true - temporal_position_encoding_max_len : 24 - temporal_attention_dim_div : 1 - zero_initialize : true - -noise_scheduler_kwargs: - num_train_timesteps: 1000 - beta_start: 0.00085 - beta_end: 0.012 - beta_schedule: "linear" - steps_offset: 1 - clip_sample: false - -train_data: - csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv" - video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val" - sample_size: 256 - sample_stride: 4 - sample_n_frames: 16 - -validation_data: - prompts: - - "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons." - - "A drone view of celebration with Christma tree and fireworks, starry sky - background." - - "Robot dancing in times square." - - "Pacific coast, carmel by the sea ocean and waves." - num_inference_steps: 25 - guidance_scale: 8. - -trainable_modules: - - "motion_modules." - -unet_checkpoint_path: "" - -learning_rate: 1.e-4 -train_batch_size: 4 - -max_train_epoch: -1 -max_train_steps: 100 -checkpointing_epochs: -1 -checkpointing_steps: 60 - -validation_steps: 5000 -validation_steps_tuple: [2, 50] - -global_seed: 42 -mixed_precision_training: true -enable_xformers_memory_efficient_attention: True - -is_debug: False diff --git a/download_bashscripts/0-MotionModule.sh b/download_bashscripts/0-MotionModule.sh index 8e2007e..a115866 100644 --- a/download_bashscripts/0-MotionModule.sh +++ b/download_bashscripts/0-MotionModule.sh @@ -1,2 +1 @@ -gdown 1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq -O models/Motion_Module/ -gdown 1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu -O models/Motion_Module/ \ No newline at end of file +gdown 1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX -O models/Motion_Module/ diff --git a/download_bashscripts/1-DynaVision.sh b/download_bashscripts/1-DynaVision.sh new file mode 100644 index 0000000..4516ba4 --- /dev/null +++ b/download_bashscripts/1-DynaVision.sh @@ -0,0 +1,2 @@ +#!/bin/bash +wget https://civitai.com/api/download/models/169718 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/1-ToonYou.sh b/download_bashscripts/1-ToonYou.sh deleted file mode 100644 index 6b7c3b6..0000000 --- a/download_bashscripts/1-ToonYou.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/78775 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/2-DreamShaper.sh b/download_bashscripts/2-DreamShaper.sh new file mode 100644 index 0000000..ee528bc --- /dev/null +++ b/download_bashscripts/2-DreamShaper.sh @@ -0,0 +1,2 @@ +#!/bin/bash +wget https://civitai.com/api/download/models/126688 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/2-Lyriel.sh b/download_bashscripts/2-Lyriel.sh deleted file mode 100644 index f4f0215..0000000 --- a/download_bashscripts/2-Lyriel.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/72396 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/3-DeepBlue.sh b/download_bashscripts/3-DeepBlue.sh new file mode 100644 index 0000000..6b590d9 --- /dev/null +++ b/download_bashscripts/3-DeepBlue.sh @@ -0,0 +1,2 @@ +#!/bin/bash +wget https://civitai.com/api/download/models/189102 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/3-RcnzCartoon.sh b/download_bashscripts/3-RcnzCartoon.sh deleted file mode 100644 index 07f4f69..0000000 --- a/download_bashscripts/3-RcnzCartoon.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/71009 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/4-MajicMix.sh b/download_bashscripts/4-MajicMix.sh deleted file mode 100644 index b287167..0000000 --- a/download_bashscripts/4-MajicMix.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/79068 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/5-RealisticVision.sh b/download_bashscripts/5-RealisticVision.sh deleted file mode 100644 index bd7f6f2..0000000 --- a/download_bashscripts/5-RealisticVision.sh +++ /dev/null @@ -1,2 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/29460 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate \ No newline at end of file diff --git a/download_bashscripts/6-Tusun.sh b/download_bashscripts/6-Tusun.sh deleted file mode 100644 index 9fc18b3..0000000 --- a/download_bashscripts/6-Tusun.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/97261 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate -wget https://civitai.com/api/download/models/50705 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate diff --git a/download_bashscripts/7-FilmVelvia.sh b/download_bashscripts/7-FilmVelvia.sh deleted file mode 100644 index 53aa688..0000000 --- a/download_bashscripts/7-FilmVelvia.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/90115 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate -wget https://civitai.com/api/download/models/92475 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate diff --git a/download_bashscripts/8-GhibliBackground.sh b/download_bashscripts/8-GhibliBackground.sh deleted file mode 100644 index 39b9e76..0000000 --- a/download_bashscripts/8-GhibliBackground.sh +++ /dev/null @@ -1,3 +0,0 @@ -#!/bin/bash -wget https://civitai.com/api/download/models/102828 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate -wget https://civitai.com/api/download/models/57618 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate diff --git a/environment.yaml b/environment.yaml index 64d18c7..6fd4182 100644 --- a/environment.yaml +++ b/environment.yaml @@ -1,23 +1,34 @@ -name: animatediff +name: animatediff_xl channels: - pytorch - nvidia + - defaults dependencies: - - python=3.10 - - pytorch=1.13.1 - - torchvision=0.14.1 - - torchaudio=0.13.1 + - numpy=1.25.0 + - pillow=9.4.0 + - python=3.10.12 - pytorch-cuda=11.7 - pip - pip: - - diffusers==0.11.1 - - transformers==4.25.1 - - xformers==0.0.16 - - imageio==2.27.0 - - decord==0.6.0 - - gdown - - einops - - omegaconf - - safetensors - - gradio - - wandb + - absl-py==1.4.0 + - accelerate==0.21.0 + - av==10.0.0 + - beautifulsoup4==4.12.2 + - bitsandbytes==0.41.1 + - colorama==0.4.4 + - decord==0.6.0 + - diffusers==0.20.2 + - easydict==1.10 + - einops==0.7.0rc1 + - gdown==4.7.1 + - imageio==2.27.0 + - omegaconf==2.3.0 + - torch==2.1.0 + - torchaudio==2.1.0 + - torchvision==0.16.0 + - tqdm==4.65.0 + - transformers==4.30.0 + - wandb==0.15.8 + - xformers==0.0.22.post4 + - scipy + - imageio[ffmpeg] diff --git a/models/MotionLoRA/Put MotionLoRA checkpoints here.txt b/models/MotionLoRA/Put MotionLoRA checkpoints here.txt deleted file mode 100644 index e69de29..0000000 diff --git a/models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt b/models/StableDiffusion/Put diffusers stable-diffusion-v1-5 repo here.txt deleted file mode 100644 index e69de29..0000000 diff --git a/__assets__/animations/compare/ffmpeg b/models/StableDiffusion/Put diffusers stable-diffusion-xl repo here.txt similarity index 100% rename from __assets__/animations/compare/ffmpeg rename to models/StableDiffusion/Put diffusers stable-diffusion-xl repo here.txt