diff --git a/README.md b/README.md index aaa6525..4224071 100644 --- a/README.md +++ b/README.md @@ -1,121 +1,10 @@ -# AnimateDiff +# A beta-version of motion module for SDXL -This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725). - -**[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)** - -Yuwei Guo, -Ceyuan Yang*, -Anyi Rao, -Yaohui Wang, -Yu Qiao, -Dahua Lin, -Bo Dai -
*Corresponding Author
- - -[](https://arxiv.org/abs/2307.04725) -[](https://animatediff.github.io/) -[](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff) -[](https://huggingface.co/spaces/guoyww/AnimateDiff) - -## Next -One with better controllability and quality is coming soon. Stay tuned. - -## Features -- **[2023/11/10]** Release the Motion Module (beta version) on SDXL, available at [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link -) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt -) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules). High resolution videos (i.e., 1024x1024x16 frames with various aspect ratios) could be produced **with/without** personalized models. Inference usually requires ~13GB VRAM and tuned hyperparameters (e.g., #sampling steps), depending on the chosen personalized models. Checkout to the branch `sdxl` for more details of the inference. More checkpoints with better-quality would be available soon. Stay tuned. Examples below are manually downsampled for fast loading. - -| Original SDXL | -Personalized SDXL | -Personalized SDXL | -
-| Zoom In | Zoom Out | Zoom Pan Left | Zoom Pan Right |
-| Tilt Up | Tilt Down | Rolling Anti-Clockwise | Rolling Clockwise |
+Now you can generate high-resolution videos on SDXL **with/without** personalized models. Checkpoints with better quality will be available soon. Stay tuned.
+## Important Notes
+- Generate videos at high resolutions (we provide recommended ones), as SDXL usually produces worse quality at low resolutions.
+- Follow, and slightly adjust if needed, the hyperparameters (e.g., number of sampling steps, guidance scale) of each personalized SDXL model, since these models are tuned to varying extents; see the sketch below.
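+
+As a rough illustration (hypothetical placeholder values, not tuned recommendations), such an adjustment just means editing the sampling keys of the model's prompt config under `configs/prompts/`:
+```
+# Start from the values shipped with the model's own config
+# (e.g., steps: 50, guidance_scale: 8.5) and nudge them per model.
+steps: 40            # number of sampling steps
+guidance_scale: 7.0  # classifier-free guidance scale
+```
+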
## Model Zoo
The `dev` branch is for community contributions; we would like to keep the main branch aligned with the original technical report :)
Model: DynaVision
+ +
Model: DreamShaper
+
Model: DeepBlue
+
+
+## Inference Example
+
+Inference at the recommended resolution with 16 frames usually requires ~13GB VRAM.
+### Step-1: Prepare Environment
-## BibTeX
```
-@article{guo2023animatediff,
-  title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning},
-  author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Wang, Yaohui and Qiao, Yu and Lin, Dahua and Dai, Bo},
-  journal={arXiv preprint arXiv:2307.04725},
-  year={2023}
-}
+git clone https://github.com/guoyww/AnimateDiff.git
+cd AnimateDiff
+git checkout sdxl
+
+
+conda env create -f environment.yaml
+conda activate animatediff_xl
```
-## Contact Us
-**Yuwei Guo**: [guoyuwei@pjlab.org.cn](mailto:guoyuwei@pjlab.org.cn)
-**Ceyuan Yang**: [yangceyuan@pjlab.org.cn](mailto:yangceyuan@pjlab.org.cn)
-**Bo Dai**: [daibo@pjlab.org.cn](mailto:daibo@pjlab.org.cn)
+### Step-2: Download Base T2I & Motion Module Checkpoints
+We provide a beta version of the motion module on SDXL. You can download the SDXL 1.0 base model and the motion module following the instructions below.
+```
+git lfs install
+git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 models/StableDiffusion/
-## Acknowledgements
-Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video).
+
+bash download_bashscripts/0-MotionModule.sh
+```
+You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules), then put them in the `models/Motion_Module/` folder.
+
+### Step-3: Download Personalized SDXL (you can skip this if generating videos with the original SDXL)
+You may run the following bash scripts to download the LoRA checkpoints from CivitAI.
+```
+bash download_bashscripts/1-DynaVision.sh
+bash download_bashscripts/2-DreamShaper.sh
+bash download_bashscripts/3-DeepBlue.sh
+```
+
+### Step-4: Generate Videos
+Run the following command to generate videos with the **original SDXL**.
+```
+python -m scripts.animate --exp_config configs/prompts/1-original_sdxl.yaml --H 1024 --W 1024 --L 16 --xformers
+```
+Run the following commands to generate videos with **personalized SDXL** models. DO NOT skip Step-3.
+```
+python -m scripts.animate --config configs/prompts/2-DynaVision.yaml --H 1024 --W 1024 --L 16 --xformers
+python -m scripts.animate --config configs/prompts/3-DreamShaper.yaml --H 1024 --W 1024 --L 16 --xformers
+python -m scripts.animate --config configs/prompts/4-DeepBlue.yaml --H 1024 --W 1024 --L 16 --xformers
+```
+The results will automatically be saved to the `samples/` folder.
+
+
+## Customized Inference
+To generate videos with a new Checkpoint/LoRA model, you may create a new config `.yaml` file in the following format:
+```
+
+motion_module_path: "models/Motion_Module/mm_sdxl_v10_beta.ckpt" # Specify the Motion Module
+
+# We support 3 types of T2I models.
+  # 1. Checkpoint: a safetensors model that contains the UNet, Text_Encoders, and VAE.
+  # 2. LoRA: a safetensors model that contains only the LoRA modules.
+  # 3. You can convert the Checkpoint into a folder with the same structure as the SDXL_1.0 base model.
+
+
+ckpt_path: "YOUR_CKPT_PATH" # path to the Checkpoint-type model from CivitAI.
+lora_path: "YOUR_LORA_PATH" # path to the LoRA-type model from CivitAI.
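+# Presumably (an assumption carried over from the v1 AnimateDiff docs, which say to
+# leave unused model paths as empty strings), only the path matching your model type
+# above needs to be set; leave the others empty.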
+base_model_path: "YOUR_BASE_MODEL_PATH" # path to the folder converted from a checkpoint + + +steps: 50 +guidance_scale: 8.5 + +seed: -1 # You can specify seed for each prompt. + +prompt: + - "[positive prompt]" + +n_prompt: + - "[negative prompt]" +``` + +Then run the following commands. +``` +python -m scripts.animate --exp_config [path to the personalized config] --L [video frames] --H [Height of the videos] --W [Width of the videos] --xformers +``` \ No newline at end of file diff --git a/__assets__/animations/compare/new_0.gif b/__assets__/animations/compare/new_0.gif deleted file mode 100644 index 8681fa5..0000000 Binary files a/__assets__/animations/compare/new_0.gif and /dev/null differ diff --git a/__assets__/animations/compare/new_1.gif b/__assets__/animations/compare/new_1.gif deleted file mode 100644 index dd0b296..0000000 Binary files a/__assets__/animations/compare/new_1.gif and /dev/null differ diff --git a/__assets__/animations/compare/new_2.gif b/__assets__/animations/compare/new_2.gif deleted file mode 100644 index 7baeb7b..0000000 Binary files a/__assets__/animations/compare/new_2.gif and /dev/null differ diff --git a/__assets__/animations/compare/new_3.gif b/__assets__/animations/compare/new_3.gif deleted file mode 100644 index 07dc320..0000000 Binary files a/__assets__/animations/compare/new_3.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_0.gif b/__assets__/animations/compare/old_0.gif deleted file mode 100644 index 70709b8..0000000 Binary files a/__assets__/animations/compare/old_0.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_1.gif b/__assets__/animations/compare/old_1.gif deleted file mode 100644 index 5c605be..0000000 Binary files a/__assets__/animations/compare/old_1.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_2.gif b/__assets__/animations/compare/old_2.gif deleted file mode 100644 index 2e20b7b..0000000 Binary files a/__assets__/animations/compare/old_2.gif and /dev/null differ diff --git a/__assets__/animations/compare/old_3.gif b/__assets__/animations/compare/old_3.gif deleted file mode 100644 index b035a95..0000000 Binary files a/__assets__/animations/compare/old_3.gif and /dev/null differ diff --git a/__assets__/animations/model_01/01.gif b/__assets__/animations/model_01/01.gif index 0d869f0..f19dfec 100644 Binary files a/__assets__/animations/model_01/01.gif and b/__assets__/animations/model_01/01.gif differ diff --git a/__assets__/animations/model_01/02.gif b/__assets__/animations/model_01/02.gif index 3924190..480cbc8 100644 Binary files a/__assets__/animations/model_01/02.gif and b/__assets__/animations/model_01/02.gif differ diff --git a/__assets__/animations/model_01/03.gif b/__assets__/animations/model_01/03.gif index 1b1b970..ac3677c 100644 Binary files a/__assets__/animations/model_01/03.gif and b/__assets__/animations/model_01/03.gif differ diff --git a/__assets__/animations/model_01/04.gif b/__assets__/animations/model_01/04.gif deleted file mode 100644 index 593a500..0000000 Binary files a/__assets__/animations/model_01/04.gif and /dev/null differ diff --git a/__assets__/animations/model_02/01.gif b/__assets__/animations/model_02/01.gif index 8d905e7..bfb6ada 100644 Binary files a/__assets__/animations/model_02/01.gif and b/__assets__/animations/model_02/01.gif differ diff --git a/__assets__/animations/model_02/02.gif b/__assets__/animations/model_02/02.gif index 734bee4..5d35b48 100644 Binary files a/__assets__/animations/model_02/02.gif and 
b/__assets__/animations/model_02/02.gif differ diff --git a/__assets__/animations/model_02/03.gif b/__assets__/animations/model_02/03.gif deleted file mode 100644 index 228bd2a..0000000 Binary files a/__assets__/animations/model_02/03.gif and /dev/null differ diff --git a/__assets__/animations/model_02/04.gif b/__assets__/animations/model_02/04.gif deleted file mode 100644 index fb9f69e..0000000 Binary files a/__assets__/animations/model_02/04.gif and /dev/null differ diff --git a/__assets__/animations/model_03/01.gif b/__assets__/animations/model_03/01.gif index 32f0379..4e54f5f 100644 Binary files a/__assets__/animations/model_03/01.gif and b/__assets__/animations/model_03/01.gif differ diff --git a/__assets__/animations/model_03/02.gif b/__assets__/animations/model_03/02.gif index 42de8ef..44b8827 100644 Binary files a/__assets__/animations/model_03/02.gif and b/__assets__/animations/model_03/02.gif differ diff --git a/__assets__/animations/model_03/03.gif b/__assets__/animations/model_03/03.gif index 28a0446..1f9e2fb 100644 Binary files a/__assets__/animations/model_03/03.gif and b/__assets__/animations/model_03/03.gif differ diff --git a/__assets__/animations/model_03/04.gif b/__assets__/animations/model_03/04.gif deleted file mode 100644 index a920f12..0000000 Binary files a/__assets__/animations/model_03/04.gif and /dev/null differ diff --git a/__assets__/animations/model_04/01.gif b/__assets__/animations/model_04/01.gif deleted file mode 100644 index 492f0cc..0000000 Binary files a/__assets__/animations/model_04/01.gif and /dev/null differ diff --git a/__assets__/animations/model_04/02.gif b/__assets__/animations/model_04/02.gif deleted file mode 100644 index 97bad3c..0000000 Binary files a/__assets__/animations/model_04/02.gif and /dev/null differ diff --git a/__assets__/animations/model_04/03.gif b/__assets__/animations/model_04/03.gif deleted file mode 100644 index 11f1855..0000000 Binary files a/__assets__/animations/model_04/03.gif and /dev/null differ diff --git a/__assets__/animations/model_04/04.gif b/__assets__/animations/model_04/04.gif deleted file mode 100644 index 0e81050..0000000 Binary files a/__assets__/animations/model_04/04.gif and /dev/null differ diff --git a/__assets__/animations/model_05/01.gif b/__assets__/animations/model_05/01.gif deleted file mode 100644 index 782f03a..0000000 Binary files a/__assets__/animations/model_05/01.gif and /dev/null differ diff --git a/__assets__/animations/model_05/02.gif b/__assets__/animations/model_05/02.gif deleted file mode 100644 index 2197447..0000000 Binary files a/__assets__/animations/model_05/02.gif and /dev/null differ diff --git a/__assets__/animations/model_05/03.gif b/__assets__/animations/model_05/03.gif deleted file mode 100644 index 44d922d..0000000 Binary files a/__assets__/animations/model_05/03.gif and /dev/null differ diff --git a/__assets__/animations/model_05/04.gif b/__assets__/animations/model_05/04.gif deleted file mode 100644 index 43e859f..0000000 Binary files a/__assets__/animations/model_05/04.gif and /dev/null differ diff --git a/__assets__/animations/model_06/01.gif b/__assets__/animations/model_06/01.gif deleted file mode 100644 index 123d2a6..0000000 Binary files a/__assets__/animations/model_06/01.gif and /dev/null differ diff --git a/__assets__/animations/model_06/02.gif b/__assets__/animations/model_06/02.gif deleted file mode 100644 index fd64b6e..0000000 Binary files a/__assets__/animations/model_06/02.gif and /dev/null differ diff --git a/__assets__/animations/model_06/03.gif 
b/__assets__/animations/model_06/03.gif deleted file mode 100644 index 53cbb28..0000000 Binary files a/__assets__/animations/model_06/03.gif and /dev/null differ diff --git a/__assets__/animations/model_06/04.gif b/__assets__/animations/model_06/04.gif deleted file mode 100644 index ddd0e01..0000000 Binary files a/__assets__/animations/model_06/04.gif and /dev/null differ diff --git a/__assets__/animations/model_07/01.gif b/__assets__/animations/model_07/01.gif deleted file mode 100644 index be0eaf4..0000000 Binary files a/__assets__/animations/model_07/01.gif and /dev/null differ diff --git a/__assets__/animations/model_07/02.gif b/__assets__/animations/model_07/02.gif deleted file mode 100644 index bbcad05..0000000 Binary files a/__assets__/animations/model_07/02.gif and /dev/null differ diff --git a/__assets__/animations/model_07/03.gif b/__assets__/animations/model_07/03.gif deleted file mode 100644 index 447b70b..0000000 Binary files a/__assets__/animations/model_07/03.gif and /dev/null differ diff --git a/__assets__/animations/model_07/04.gif b/__assets__/animations/model_07/04.gif deleted file mode 100644 index 2c175ff..0000000 Binary files a/__assets__/animations/model_07/04.gif and /dev/null differ diff --git a/__assets__/animations/model_07/init.jpg b/__assets__/animations/model_07/init.jpg deleted file mode 100644 index 3b83812..0000000 Binary files a/__assets__/animations/model_07/init.jpg and /dev/null differ diff --git a/__assets__/animations/model_08/01.gif b/__assets__/animations/model_08/01.gif deleted file mode 100644 index e183a8b..0000000 Binary files a/__assets__/animations/model_08/01.gif and /dev/null differ diff --git a/__assets__/animations/model_08/02.gif b/__assets__/animations/model_08/02.gif deleted file mode 100644 index 326ef26..0000000 Binary files a/__assets__/animations/model_08/02.gif and /dev/null differ diff --git a/__assets__/animations/model_08/03.gif b/__assets__/animations/model_08/03.gif deleted file mode 100644 index cb0c4a8..0000000 Binary files a/__assets__/animations/model_08/03.gif and /dev/null differ diff --git a/__assets__/animations/model_08/04.gif b/__assets__/animations/model_08/04.gif deleted file mode 100644 index a062247..0000000 Binary files a/__assets__/animations/model_08/04.gif and /dev/null differ diff --git a/__assets__/animations/motion_xl/01.gif b/__assets__/animations/model_original/01.gif similarity index 100% rename from __assets__/animations/motion_xl/01.gif rename to __assets__/animations/model_original/01.gif diff --git a/__assets__/animations/model_original/02.gif b/__assets__/animations/model_original/02.gif new file mode 100644 index 0000000..57bf14e Binary files /dev/null and b/__assets__/animations/model_original/02.gif differ diff --git a/__assets__/animations/motion_lora/model_01/01.gif b/__assets__/animations/motion_lora/model_01/01.gif deleted file mode 100644 index a4e85ea..0000000 Binary files a/__assets__/animations/motion_lora/model_01/01.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/02.gif b/__assets__/animations/motion_lora/model_01/02.gif deleted file mode 100644 index 5b33305..0000000 Binary files a/__assets__/animations/motion_lora/model_01/02.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/03.gif b/__assets__/animations/motion_lora/model_01/03.gif deleted file mode 100644 index 16a5ae2..0000000 Binary files a/__assets__/animations/motion_lora/model_01/03.gif and /dev/null differ diff --git 
a/__assets__/animations/motion_lora/model_01/04.gif b/__assets__/animations/motion_lora/model_01/04.gif deleted file mode 100644 index 73b1d6a..0000000 Binary files a/__assets__/animations/motion_lora/model_01/04.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/05.gif b/__assets__/animations/motion_lora/model_01/05.gif deleted file mode 100644 index 9fb3661..0000000 Binary files a/__assets__/animations/motion_lora/model_01/05.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/06.gif b/__assets__/animations/motion_lora/model_01/06.gif deleted file mode 100644 index aa18797..0000000 Binary files a/__assets__/animations/motion_lora/model_01/06.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/07.gif b/__assets__/animations/motion_lora/model_01/07.gif deleted file mode 100644 index 8308862..0000000 Binary files a/__assets__/animations/motion_lora/model_01/07.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_01/08.gif b/__assets__/animations/motion_lora/model_01/08.gif deleted file mode 100644 index 75ba8fa..0000000 Binary files a/__assets__/animations/motion_lora/model_01/08.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/01.gif b/__assets__/animations/motion_lora/model_02/01.gif deleted file mode 100644 index 36db48e..0000000 Binary files a/__assets__/animations/motion_lora/model_02/01.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/02.gif b/__assets__/animations/motion_lora/model_02/02.gif deleted file mode 100644 index ead0fd4..0000000 Binary files a/__assets__/animations/motion_lora/model_02/02.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/03.gif b/__assets__/animations/motion_lora/model_02/03.gif deleted file mode 100644 index 1b5136b..0000000 Binary files a/__assets__/animations/motion_lora/model_02/03.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/04.gif b/__assets__/animations/motion_lora/model_02/04.gif deleted file mode 100644 index b409fd9..0000000 Binary files a/__assets__/animations/motion_lora/model_02/04.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/05.gif b/__assets__/animations/motion_lora/model_02/05.gif deleted file mode 100644 index 5216871..0000000 Binary files a/__assets__/animations/motion_lora/model_02/05.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/06.gif b/__assets__/animations/motion_lora/model_02/06.gif deleted file mode 100644 index 25c7ee7..0000000 Binary files a/__assets__/animations/motion_lora/model_02/06.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/07.gif b/__assets__/animations/motion_lora/model_02/07.gif deleted file mode 100644 index 3b239aa..0000000 Binary files a/__assets__/animations/motion_lora/model_02/07.gif and /dev/null differ diff --git a/__assets__/animations/motion_lora/model_02/08.gif b/__assets__/animations/motion_lora/model_02/08.gif deleted file mode 100644 index d1348e0..0000000 Binary files a/__assets__/animations/motion_lora/model_02/08.gif and /dev/null differ diff --git a/__assets__/animations/motion_xl/02.gif b/__assets__/animations/motion_xl/02.gif deleted file mode 100644 index f19dfec..0000000 Binary files a/__assets__/animations/motion_xl/02.gif and /dev/null differ diff --git a/__assets__/animations/motion_xl/03.gif b/__assets__/animations/motion_xl/03.gif deleted file mode 
100644 index 4e54f5f..0000000 Binary files a/__assets__/animations/motion_xl/03.gif and /dev/null differ diff --git a/__assets__/docs/animatediff.md b/__assets__/docs/animatediff.md deleted file mode 100644 index 6e1f26b..0000000 --- a/__assets__/docs/animatediff.md +++ /dev/null @@ -1,112 +0,0 @@ -# AnimateDiff: training and inference setup -## Setups for Inference - -### Prepare Environment - -***We updated our inference code with xformers and a sequential decoding trick. Now AnimateDiff takes only ~12GB VRAM to inference, and run on a single RTX3090 !!*** - -``` -git clone https://github.com/guoyww/AnimateDiff.git -cd AnimateDiff - -conda env create -f environment.yaml -conda activate animatediff -``` - -### Download Base T2I & Motion Module Checkpoints -We provide two versions of our Motion Module, which are trained on stable-diffusion-v1-4 and finetuned on v1-5 seperately. -It's recommanded to try both of them for best results. -``` -git lfs install -git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/ - -bash download_bashscripts/0-MotionModule.sh -``` -You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules), then put them in `models/Motion_Module/` folder. - -### Prepare Personalize T2I -Here we provide inference configs for 6 demo T2I on CivitAI. -You may run the following bash scripts to download these checkpoints. -``` -bash download_bashscripts/1-ToonYou.sh -bash download_bashscripts/2-Lyriel.sh -bash download_bashscripts/3-RcnzCartoon.sh -bash download_bashscripts/4-MajicMix.sh -bash download_bashscripts/5-RealisticVision.sh -bash download_bashscripts/6-Tusun.sh -bash download_bashscripts/7-FilmVelvia.sh -bash download_bashscripts/8-GhibliBackground.sh -``` - -### Inference -After downloading the above peronalized T2I checkpoints, run the following commands to generate animations. The results will automatically be saved to `samples/` folder. -``` -python -m scripts.animate --config configs/prompts/1-ToonYou.yaml -python -m scripts.animate --config configs/prompts/2-Lyriel.yaml -python -m scripts.animate --config configs/prompts/3-RcnzCartoon.yaml -python -m scripts.animate --config configs/prompts/4-MajicMix.yaml -python -m scripts.animate --config configs/prompts/5-RealisticVision.yaml -python -m scripts.animate --config configs/prompts/6-Tusun.yaml -python -m scripts.animate --config configs/prompts/7-FilmVelvia.yaml -python -m scripts.animate --config configs/prompts/8-GhibliBackground.yaml -``` - -To generate animations with a new DreamBooth/LoRA model, you may create a new config `.yaml` file in the following format: -``` -NewModel: - inference_config: "[path to motion module config file]" - - motion_module: - - "models/Motion_Module/mm_sd_v14.ckpt" - - "models/Motion_Module/mm_sd_v15.ckpt" - - motion_module_lora_configs: - - path: "[path to MotionLoRA model]" - alpha: 1.0 - - ... 
- - dreambooth_path: "[path to your DreamBooth model .safetensors file]" - lora_model_path: "[path to your LoRA model .safetensors file, leave it empty string if not needed]" - - steps: 25 - guidance_scale: 7.5 - - prompt: - - "[positive prompt]" - - n_prompt: - - "[negative prompt]" -``` -Then run the following commands: -``` -python -m scripts.animate --config [path to the config file] -``` - - -## Steps for Training - -### Dataset -Before training, download the videos files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local mechine. -Note that our examplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`. - -### Configuration -After dataset preparations, update the below data paths in the config `.yaml` files in `configs/training/` folder: -``` -train_data: - csv_path: [Replace with .csv Annotation File Path] - video_folder: [Replace with Video Folder Path] - sample_size: 256 -``` -Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files. - -### Training -To train motion modules -``` -torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml -``` - -To finetune the unet's image layers -``` -torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml -``` - diff --git a/__assets__/docs/gallery.md b/__assets__/docs/gallery.md deleted file mode 100644 index 8891dd2..0000000 --- a/__assets__/docs/gallery.md +++ /dev/null @@ -1,93 +0,0 @@ -# Gallery -Here we demonstrate several best results we found in our experiments. - -
Model:ToonYou
Model:Counterfeit V3.0
Model:Realistic Vision V2.0
Model: majicMIX Realistic
Model:RCNZ Cartoon
Model:FilmVelvia
-
-#### Community Cases
-Here are some samples contributed by the community artists. Create a Pull Request if you would like to show your results here😚.
-
-Character Model:Yoimiya -(with an initial reference image, see WIP fork for the extended implementation.) - - -
-Character Model:Paimon; -Pose Model:Hold Sign
- - diff --git a/__assets__/figs/gradio.jpg b/__assets__/figs/gradio.jpg deleted file mode 100644 index 19aea7c..0000000 Binary files a/__assets__/figs/gradio.jpg and /dev/null differ diff --git a/animatediff/.DS_Store b/animatediff/.DS_Store new file mode 100644 index 0000000..5fd8fa5 Binary files /dev/null and b/animatediff/.DS_Store differ diff --git a/animatediff/data/dataset.py b/animatediff/data/dataset.py deleted file mode 100644 index 3f6ec10..0000000 --- a/animatediff/data/dataset.py +++ /dev/null @@ -1,98 +0,0 @@ -import os, io, csv, math, random -import numpy as np -from einops import rearrange -from decord import VideoReader - -import torch -import torchvision.transforms as transforms -from torch.utils.data.dataset import Dataset -from animatediff.utils.util import zero_rank_print - - - -class WebVid10M(Dataset): - def __init__( - self, - csv_path, video_folder, - sample_size=256, sample_stride=4, sample_n_frames=16, - is_image=False, - ): - zero_rank_print(f"loading annotations from {csv_path} ...") - with open(csv_path, 'r') as csvfile: - self.dataset = list(csv.DictReader(csvfile)) - self.length = len(self.dataset) - zero_rank_print(f"data scale: {self.length}") - - self.video_folder = video_folder - self.sample_stride = sample_stride - self.sample_n_frames = sample_n_frames - self.is_image = is_image - - sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size) - self.pixel_transforms = transforms.Compose([ - transforms.RandomHorizontalFlip(), - transforms.Resize(sample_size[0]), - transforms.CenterCrop(sample_size), - transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True), - ]) - - def get_batch(self, idx): - video_dict = self.dataset[idx] - videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir'] - - video_dir = os.path.join(self.video_folder, f"{videoid}.mp4") - video_reader = VideoReader(video_dir) - video_length = len(video_reader) - - if not self.is_image: - clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1) - start_idx = random.randint(0, video_length - clip_length) - batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int) - else: - batch_index = [random.randint(0, video_length - 1)] - - pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous() - pixel_values = pixel_values / 255. 
- del video_reader - - if self.is_image: - pixel_values = pixel_values[0] - - return pixel_values, name - - def __len__(self): - return self.length - - def __getitem__(self, idx): - while True: - try: - pixel_values, name = self.get_batch(idx) - break - - except Exception as e: - idx = random.randint(0, self.length-1) - - pixel_values = self.pixel_transforms(pixel_values) - sample = dict(pixel_values=pixel_values, text=name) - return sample - - - -if __name__ == "__main__": - from animatediff.utils.util import save_videos_grid - - dataset = WebVid10M( - csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv", - video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val", - sample_size=256, - sample_stride=4, sample_n_frames=16, - is_image=True, - ) - import pdb - pdb.set_trace() - - dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,) - for idx, batch in enumerate(dataloader): - print(batch["pixel_values"].shape, len(batch["text"])) - # for i in range(batch["pixel_values"].shape[0]): - # save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True) diff --git a/animatediff/models/attention.py b/animatediff/models/attention.py deleted file mode 100644 index ad23583..0000000 --- a/animatediff/models/attention.py +++ /dev/null @@ -1,300 +0,0 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py - -from dataclasses import dataclass -from typing import Optional - -import torch -import torch.nn.functional as F -from torch import nn - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin -from diffusers.utils import BaseOutput -from diffusers.utils.import_utils import is_xformers_available -from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm - -from einops import rearrange, repeat -import pdb - -@dataclass -class Transformer3DModelOutput(BaseOutput): - sample: torch.FloatTensor - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None - - -class Transformer3DModel(ModelMixin, ConfigMixin): - @register_to_config - def __init__( - self, - num_attention_heads: int = 16, - attention_head_dim: int = 88, - in_channels: Optional[int] = None, - num_layers: int = 1, - dropout: float = 0.0, - norm_num_groups: int = 32, - cross_attention_dim: Optional[int] = None, - attention_bias: bool = False, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - use_linear_projection: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - ): - super().__init__() - self.use_linear_projection = use_linear_projection - self.num_attention_heads = num_attention_heads - self.attention_head_dim = attention_head_dim - inner_dim = num_attention_heads * attention_head_dim - - # Define input layers - self.in_channels = in_channels - - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - if use_linear_projection: - self.proj_in = nn.Linear(in_channels, inner_dim) - else: - self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) - - # Define transformers blocks - self.transformer_blocks = nn.ModuleList( - [ - BasicTransformerBlock( - inner_dim, - num_attention_heads, - attention_head_dim, - dropout=dropout, - 
cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - num_embeds_ada_norm=num_embeds_ada_norm, - attention_bias=attention_bias, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - ) - for d in range(num_layers) - ] - ) - - # 4. Define output layers - if use_linear_projection: - self.proj_out = nn.Linear(in_channels, inner_dim) - else: - self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): - # Input - assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." - video_length = hidden_states.shape[2] - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") - encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) - - batch, channel, height, weight = hidden_states.shape - residual = hidden_states - - hidden_states = self.norm(hidden_states) - if not self.use_linear_projection: - hidden_states = self.proj_in(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - else: - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - hidden_states = self.proj_in(hidden_states) - - # Blocks - for block in self.transformer_blocks: - hidden_states = block( - hidden_states, - encoder_hidden_states=encoder_hidden_states, - timestep=timestep, - video_length=video_length - ) - - # Output - if not self.use_linear_projection: - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) - hidden_states = self.proj_out(hidden_states) - else: - hidden_states = self.proj_out(hidden_states) - hidden_states = ( - hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() - ) - - output = hidden_states + residual - - output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) - if not return_dict: - return (output,) - - return Transformer3DModelOutput(sample=output) - - -class BasicTransformerBlock(nn.Module): - def __init__( - self, - dim: int, - num_attention_heads: int, - attention_head_dim: int, - dropout=0.0, - cross_attention_dim: Optional[int] = None, - activation_fn: str = "geglu", - num_embeds_ada_norm: Optional[int] = None, - attention_bias: bool = False, - only_cross_attention: bool = False, - upcast_attention: bool = False, - - unet_use_cross_frame_attention = None, - unet_use_temporal_attention = None, - ): - super().__init__() - self.only_cross_attention = only_cross_attention - self.use_ada_layer_norm = num_embeds_ada_norm is not None - self.unet_use_cross_frame_attention = unet_use_cross_frame_attention - self.unet_use_temporal_attention = unet_use_temporal_attention - - # SC-Attn - assert unet_use_cross_frame_attention is not None - if unet_use_cross_frame_attention: - self.attn1 = SparseCausalAttention2D( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - cross_attention_dim=cross_attention_dim if only_cross_attention else None, - upcast_attention=upcast_attention, - ) - else: - self.attn1 = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - 
dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - # Cross-Attn - if cross_attention_dim is not None: - self.attn2 = CrossAttention( - query_dim=dim, - cross_attention_dim=cross_attention_dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - else: - self.attn2 = None - - if cross_attention_dim is not None: - self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - else: - self.norm2 = None - - # Feed-forward - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) - self.norm3 = nn.LayerNorm(dim) - - # Temp-Attn - assert unet_use_temporal_attention is not None - if unet_use_temporal_attention: - self.attn_temp = CrossAttention( - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - ) - nn.init.zeros_(self.attn_temp.to_out[0].weight.data) - self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) - - def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool): - if not is_xformers_available(): - print("Here is how to install it") - raise ModuleNotFoundError( - "Refer to https://github.com/facebookresearch/xformers for more information on how to install" - " xformers", - name="xformers", - ) - elif not torch.cuda.is_available(): - raise ValueError( - "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" - " available for GPU " - ) - else: - try: - # Make sure we can run the memory efficient attention - _ = xformers.ops.memory_efficient_attention( - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - torch.randn((1, 2, 40), device="cuda"), - ) - except Exception as e: - raise e - self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - if self.attn2 is not None: - self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers - - def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): - # SparseCausal-Attention - norm_hidden_states = ( - self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) - ) - - # if self.only_cross_attention: - # hidden_states = ( - # self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states - # ) - # else: - # hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states - - # pdb.set_trace() - if self.unet_use_cross_frame_attention: - hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states - else: - hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states - - if self.attn2 is not None: - # Cross-Attention - norm_hidden_states = ( - self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) - ) - hidden_states = ( - self.attn2( - norm_hidden_states, 
encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask - ) - + hidden_states - ) - - # Feed-forward - hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states - - # Temporal-Attention - if self.unet_use_temporal_attention: - d = hidden_states.shape[1] - hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) - norm_hidden_states = ( - self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) - ) - hidden_states = self.attn_temp(norm_hidden_states) + hidden_states - hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) - - return hidden_states diff --git a/animatediff/models/motion_module.py b/animatediff/models/motion_module.py index 2359e71..2638a50 100644 --- a/animatediff/models/motion_module.py +++ b/animatediff/models/motion_module.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union import torch import numpy as np @@ -8,324 +8,418 @@ from torch import nn import torchvision from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin +from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import BaseOutput from diffusers.utils.import_utils import is_xformers_available -from diffusers.models.attention import CrossAttention, FeedForward +from diffusers.models.attention_processor import Attention +from diffusers.models.attention import FeedForward + +from animatediff.utils.util import zero_rank_print from einops import rearrange, repeat -import math +import math, pdb +import random def zero_module(module): - # Zero out the parameters of a module and return it. - for p in module.parameters(): - p.detach().zero_() - return module + # Zero out the parameters of a module and return it. 
+ for p in module.parameters(): + p.detach().zero_() + return module @dataclass class TemporalTransformer3DModelOutput(BaseOutput): - sample: torch.FloatTensor - - -if is_xformers_available(): - import xformers - import xformers.ops -else: - xformers = None + sample: torch.FloatTensor def get_motion_module( - in_channels, - motion_module_type: str, - motion_module_kwargs: dict + in_channels, + motion_module_type: str, + motion_module_kwargs: dict ): - if motion_module_type == "Vanilla": - return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) - else: - raise ValueError - + if motion_module_type == "Vanilla": + return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs) + elif motion_module_type == "Conv": + return ConvTemporalModule(in_channels=in_channels, **motion_module_kwargs) + else: + raise ValueError class VanillaTemporalModule(nn.Module): - def __init__( - self, - in_channels, - num_attention_heads = 8, - num_transformer_block = 2, - attention_block_types =( "Temporal_Self", "Temporal_Self" ), - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - temporal_attention_dim_div = 1, - zero_initialize = True, - ): - super().__init__() - - self.temporal_transformer = TemporalTransformer3DModel( - in_channels=in_channels, - num_attention_heads=num_attention_heads, - attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, - num_layers=num_transformer_block, - attention_block_types=attention_block_types, - cross_frame_attention_mode=cross_frame_attention_mode, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - - if zero_initialize: - self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + attention_block_types =( "Temporal_Self", ), + spatial_position_encoding = False, + temporal_position_encoding = True, + temporal_position_encoding_max_len = 32, + temporal_attention_dim_div = 1, + zero_initialize = True, + + causal_temporal_attention = False, + causal_temporal_attention_mask_type = "", + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + spatial_position_encoding = spatial_position_encoding, + causal_temporal_attention=causal_temporal_attention, + causal_temporal_attention_mask_type=causal_temporal_attention_mask_type, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) - def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None): - hidden_states = input_tensor - hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) + def forward(self, input_tensor, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = input_tensor + hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask) - output = hidden_states - return output + output = hidden_states + 
return output -class TemporalTransformer3DModel(nn.Module): - def __init__( - self, - in_channels, - num_attention_heads, - attention_head_dim, +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + num_layers, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + spatial_position_encoding = False, + + causal_temporal_attention = None, + causal_temporal_attention_mask_type = "", + ): + super().__init__() + assert causal_temporal_attention is not None + self.causal_temporal_attention = causal_temporal_attention - num_layers, - attention_block_types = ( "Temporal_Self", "Temporal_Self", ), - dropout = 0.0, - norm_num_groups = 32, - cross_attention_dim = 768, - activation_fn = "geglu", - attention_bias = False, - upcast_attention = False, - - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - ): - super().__init__() + assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "") + self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type + self.causal_temporal_attention_mask = None + self.spatial_position_encoding = spatial_position_encoding + inner_dim = num_attention_heads * attention_head_dim - inner_dim = num_attention_heads * attention_head_dim + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + if spatial_position_encoding: + self.pos_encoder_2d = PositionalEncoding2D(inner_dim) + - self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) - self.proj_in = nn.Linear(in_channels, inner_dim) + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def get_causal_temporal_attention_mask(self, hidden_states): + batch_size, sequence_length, dim = hidden_states.shape + + if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (batch_size, sequence_length, sequence_length): + zero_rank_print(f"build attn mask of type {self.causal_temporal_attention_mask_type}") + if self.causal_temporal_attention_mask_type == "causal": + # 1. 
vanilla causal mask + mask = torch.tril(torch.ones(sequence_length, sequence_length)) - self.transformer_blocks = nn.ModuleList( - [ - TemporalTransformerBlock( - dim=inner_dim, - num_attention_heads=num_attention_heads, - attention_head_dim=attention_head_dim, - attention_block_types=attention_block_types, - dropout=dropout, - norm_num_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - activation_fn=activation_fn, - attention_bias=attention_bias, - upcast_attention=upcast_attention, - cross_frame_attention_mode=cross_frame_attention_mode, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - for d in range(num_layers) - ] - ) - self.proj_out = nn.Linear(inner_dim, in_channels) - - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): - assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." - video_length = hidden_states.shape[2] - hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + elif self.causal_temporal_attention_mask_type == "2-seq": + # 2. 2-seq + mask = torch.zeros(sequence_length, sequence_length) + mask[:sequence_length // 2, :sequence_length // 2] = 1 + mask[-sequence_length // 2:, -sequence_length // 2:] = 1 + + elif self.causal_temporal_attention_mask_type == "0-prev": + # attn to the previous frame + indices = torch.arange(sequence_length) + indices_prev = indices - 1 + indices_prev[0] = 0 + mask = torch.zeros(sequence_length, sequence_length) + mask[:, 0] = 1. + mask[indices, indices_prev] = 1. - batch, channel, height, weight = hidden_states.shape - residual = hidden_states + elif self.causal_temporal_attention_mask_type == "0": + # only attn to first frame + mask = torch.zeros(sequence_length, sequence_length) + mask[:,0] = 1 - hidden_states = self.norm(hidden_states) - inner_dim = hidden_states.shape[1] - hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) - hidden_states = self.proj_in(hidden_states) + elif self.causal_temporal_attention_mask_type == "wo-self": + indices = torch.arange(sequence_length) + mask = torch.ones(sequence_length, sequence_length) + mask[indices, indices] = 0 - # Transformer Blocks - for block in self.transformer_blocks: - hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length) - - # output - hidden_states = self.proj_out(hidden_states) - hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + elif self.causal_temporal_attention_mask_type == "circle": + indices = torch.arange(sequence_length) + indices_prev = indices - 1 + indices_prev[0] = 0 - output = hidden_states + residual - output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) - - return output + mask = torch.eye(sequence_length) + mask[indices, indices_prev] = 1 + mask[0,-1] = 1 + else: raise ValueError + + # for sanity check + if dim == 320: zero_rank_print(mask) + + # generate attention mask fron binary values + mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0)) + mask = mask.unsqueeze(0) + mask = mask.repeat(batch_size, 1, 1) + + self.causal_temporal_attention_mask = mask.to(hidden_states.device) + + return self.causal_temporal_attention_mask + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + residual = hidden_states + assert hidden_states.dim() == 5, 
f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + height, width = hidden_states.shape[-2:] + + hidden_states = self.norm(hidden_states) + + hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c") + hidden_states = self.proj_in(hidden_states) + if self.spatial_position_encoding: + + video_length = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b h w) f c -> (b f) h w c", h=height, w=width) + pos_encoding = self.pos_encoder_2d(hidden_states) + pos_encoding = rearrange(pos_encoding, "(b f) h w c -> (b h w) f c", f = video_length) + hidden_states = rearrange(hidden_states, "(b f) h w c -> (b h w) f c", f=video_length) + + attention_mask = self.get_causal_temporal_attention_mask(hidden_states) if self.causal_temporal_attention else attention_mask + + # Transformer Blocks + for block in self.transformer_blocks: + if not self.spatial_position_encoding : + pos_encoding = None + + hidden_states = block(hidden_states, pos_encoding=pos_encoding, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask) + + hidden_states = self.proj_out(hidden_states) + + hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width) + + output = hidden_states + residual + # output = hidden_states + + return output class TemporalTransformerBlock(nn.Module): - def __init__( - self, - dim, - num_attention_heads, - attention_head_dim, - attention_block_types = ( "Temporal_Self", "Temporal_Self", ), - dropout = 0.0, - norm_num_groups = 32, - cross_attention_dim = 768, - activation_fn = "geglu", - attention_bias = False, - upcast_attention = False, - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - ): - super().__init__() + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + ): + super().__init__() - attention_blocks = [] - norms = [] - - for block_name in attention_block_types: - attention_blocks.append( - VersatileAttention( - attention_mode=block_name.split("_")[0], - cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, - - query_dim=dim, - heads=num_attention_heads, - dim_head=attention_head_dim, - dropout=dropout, - bias=attention_bias, - upcast_attention=upcast_attention, - - cross_frame_attention_mode=cross_frame_attention_mode, - temporal_position_encoding=temporal_position_encoding, - temporal_position_encoding_max_len=temporal_position_encoding_max_len, - ) - ) - norms.append(nn.LayerNorm(dim)) - - self.attention_blocks = nn.ModuleList(attention_blocks) - self.norms = nn.ModuleList(norms) - - self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) - self.ff_norm = nn.LayerNorm(dim) + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + TemporalSelfAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + temporal_position_encoding=temporal_position_encoding, + 
temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): - for attention_block, norm in zip(self.attention_blocks, self.norms): - norm_hidden_states = norm(hidden_states) - hidden_states = attention_block( - norm_hidden_states, - encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, - video_length=video_length, - ) + hidden_states - - hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states - - output = hidden_states - return output + def forward(self, hidden_states, pos_encoding=None, encoder_hidden_states=None, attention_mask=None): + for attention_block, norm in zip(self.attention_blocks, self.norms): + if pos_encoding is not None: + hidden_states += pos_encoding + norm_hidden_states = norm(hidden_states) + hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + hidden_states + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +def get_emb(sin_inp): + """ + Gets a base embedding for one dimension with sin and cos intertwined + """ + emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1) + return torch.flatten(emb, -2, -1) + +class PositionalEncoding2D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(PositionalEncoding2D, self).__init__() + self.org_channels = channels + channels = int(np.ceil(channels / 4) * 2) + self.channels = channels + inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + self.register_buffer("inv_freq", inv_freq) + self.register_buffer("cached_penc", None) + + def forward(self, tensor): + """ + :param tensor: A 4d tensor of size (batch_size, x, y, ch) + :return: Positional Encoding Matrix of size (batch_size, x, y, ch) + """ + if len(tensor.shape) != 4: + raise RuntimeError("The input tensor has to be 4d!") + + if self.cached_penc is not None and self.cached_penc.shape == tensor.shape: + return self.cached_penc + + self.cached_penc = None + batch_size, x, y, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) + pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + emb_x = get_emb(sin_inp_x).unsqueeze(1) + emb_y = get_emb(sin_inp_y) + emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type( + tensor.type() + ) + emb[:, :, : self.channels] = emb_x + emb[:, :, self.channels : 2 * self.channels] = emb_y + + self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1) + return self.cached_penc class PositionalEncoding(nn.Module): - def __init__( - self, - d_model, - dropout = 0., - max_len = 24 - ): - super().__init__() - self.dropout = nn.Dropout(p=dropout) - position = torch.arange(max_len).unsqueeze(1) - div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) - pe = torch.zeros(1, max_len, d_model) - pe[0, :, 0::2] = torch.sin(position * 
div_term) - pe[0, :, 1::2] = torch.cos(position * div_term) - self.register_buffer('pe', pe) + def __init__( + self, + d_model, + dropout = 0., + max_len = 32, + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) - def forward(self, x): - x = x + self.pe[:, :x.size(1)] - return self.dropout(x) + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return self.dropout(x) -class VersatileAttention(CrossAttention): - def __init__( - self, - attention_mode = None, - cross_frame_attention_mode = None, - temporal_position_encoding = False, - temporal_position_encoding_max_len = 24, - *args, **kwargs - ): - super().__init__(*args, **kwargs) - assert attention_mode == "Temporal" +class TemporalSelfAttention(Attention): + def __init__( + self, + attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 32, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" - self.attention_mode = attention_mode - self.is_cross_attention = kwargs["cross_attention_dim"] is not None - - self.pos_encoder = PositionalEncoding( - kwargs["query_dim"], - dropout=0., - max_len=temporal_position_encoding_max_len - ) if (temporal_position_encoding and attention_mode == "Temporal") else None + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + max_len=temporal_position_encoding_max_len + ) if temporal_position_encoding else None - def extra_repr(self): - return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + def set_use_memory_efficient_attention_xformers( + self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None + ): + # memory-efficient xformers attention is deliberately disabled for the motion module: + # it degrades results for reasons that are still unclear + # TODO: find the root cause and re-enable it here + pass - def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): - batch_size, sequence_length, _ = hidden_states.shape + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs): + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty - if self.attention_mode == "Temporal": - d = hidden_states.shape[1] - hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) - - if self.pos_encoder is not None: - hidden_states = self.pos_encoder(hidden_states) - - encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states - else: - raise NotImplementedError + # add the temporal position encoding (only when it was configured at init time) + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) - encoder_hidden_states = encoder_hidden_states + if hasattr(self.processor, "__call__"): + return self.processor.__call__( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) - if self.group_norm is not None: - hidden_states =
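The sinusoidal tables built by `PositionalEncoding` (and, per spatial axis, by `PositionalEncoding2D`) follow the standard interleaved sin/cos construction. A small standalone sketch with illustrative dimensions:
```
import math
import torch

d_model, max_len = 8, 32
position = torch.arange(max_len).unsqueeze(1)   # (max_len, 1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)    # even channels carry sin
pe[:, 1::2] = torch.cos(position * div_term)    # odd channels carry cos

print(pe[0])   # frame 0: sin(0)=0 and cos(0)=1, interleaved
print(pe[1])   # frame 1: a distinct, bounded code the attention layers can key on
```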
self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = self.to_q(hidden_states) - dim = query.shape[-1] - query = self.reshape_heads_to_batch_dim(query) - - if self.added_kv_proj_dim is not None: - raise NotImplementedError - - encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states - key = self.to_k(encoder_hidden_states) - value = self.to_v(encoder_hidden_states) - - key = self.reshape_heads_to_batch_dim(key) - value = self.reshape_heads_to_batch_dim(value) - - if attention_mask is not None: - if attention_mask.shape[-1] != query.shape[1]: - target_length = query.shape[1] - attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) - attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) - - # attention, what we cannot get enough of - if self._use_memory_efficient_attention_xformers: - hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) - # Some versions of xformers return output in fp32, cast it back to the dtype of the input - hidden_states = hidden_states.to(query.dtype) - else: - if self._slice_size is None or query.shape[0] // self._slice_size == 1: - hidden_states = self._attention(query, key, value, attention_mask) - else: - hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) - - # linear proj - hidden_states = self.to_out[0](hidden_states) - - # dropout - hidden_states = self.to_out[1](hidden_states) - - if self.attention_mode == "Temporal": - hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) - - return hidden_states + else: + return self.processor( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) diff --git a/animatediff/models/resnet.py b/animatediff/models/resnet.py deleted file mode 100644 index da80f17..0000000 --- a/animatediff/models/resnet.py +++ /dev/null @@ -1,217 +0,0 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from einops import rearrange - - -class InflatedConv3d(nn.Conv2d): - def forward(self, x): - video_length = x.shape[2] - - x = rearrange(x, "b c f h w -> (b f) c h w") - x = super().forward(x) - x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) - - return x - - -class InflatedGroupNorm(nn.GroupNorm): - def forward(self, x): - video_length = x.shape[2] - - x = rearrange(x, "b c f h w -> (b f) c h w") - x = super().forward(x) - x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) - - return x - - -class Upsample3D(nn.Module): - def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.use_conv_transpose = use_conv_transpose - self.name = name - - conv = None - if use_conv_transpose: - raise NotImplementedError - elif use_conv: - self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) - - def forward(self, hidden_states, output_size=None): - assert hidden_states.shape[1] == self.channels - - if self.use_conv_transpose: - raise NotImplementedError - - # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 - dtype = hidden_states.dtype - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(torch.float32) - - # 
upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 - if hidden_states.shape[0] >= 64: - hidden_states = hidden_states.contiguous() - - # if `output_size` is passed we force the interpolation output - # size and do not make use of `scale_factor=2` - if output_size is None: - hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") - else: - hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") - - # If the input is bfloat16, we cast back to bfloat16 - if dtype == torch.bfloat16: - hidden_states = hidden_states.to(dtype) - - # if self.use_conv: - # if self.name == "conv": - # hidden_states = self.conv(hidden_states) - # else: - # hidden_states = self.Conv2d_0(hidden_states) - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class Downsample3D(nn.Module): - def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): - super().__init__() - self.channels = channels - self.out_channels = out_channels or channels - self.use_conv = use_conv - self.padding = padding - stride = 2 - self.name = name - - if use_conv: - self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) - else: - raise NotImplementedError - - def forward(self, hidden_states): - assert hidden_states.shape[1] == self.channels - if self.use_conv and self.padding == 0: - raise NotImplementedError - - assert hidden_states.shape[1] == self.channels - hidden_states = self.conv(hidden_states) - - return hidden_states - - -class ResnetBlock3D(nn.Module): - def __init__( - self, - *, - in_channels, - out_channels=None, - conv_shortcut=False, - dropout=0.0, - temb_channels=512, - groups=32, - groups_out=None, - pre_norm=True, - eps=1e-6, - non_linearity="swish", - time_embedding_norm="default", - output_scale_factor=1.0, - use_in_shortcut=None, - use_inflated_groupnorm=None, - ): - super().__init__() - self.pre_norm = pre_norm - self.pre_norm = True - self.in_channels = in_channels - out_channels = in_channels if out_channels is None else out_channels - self.out_channels = out_channels - self.use_conv_shortcut = conv_shortcut - self.time_embedding_norm = time_embedding_norm - self.output_scale_factor = output_scale_factor - - if groups_out is None: - groups_out = groups - - assert use_inflated_groupnorm != None - if use_inflated_groupnorm: - self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - else: - self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) - - self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) - - if temb_channels is not None: - if self.time_embedding_norm == "default": - time_emb_proj_out_channels = out_channels - elif self.time_embedding_norm == "scale_shift": - time_emb_proj_out_channels = out_channels * 2 - else: - raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") - - self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) - else: - self.time_emb_proj = None - - if use_inflated_groupnorm: - self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - else: - self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) - - self.dropout = torch.nn.Dropout(dropout) - self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, 
padding=1) - - if non_linearity == "swish": - self.nonlinearity = lambda x: F.silu(x) - elif non_linearity == "mish": - self.nonlinearity = Mish() - elif non_linearity == "silu": - self.nonlinearity = nn.SiLU() - - self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut - - self.conv_shortcut = None - if self.use_in_shortcut: - self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) - - def forward(self, input_tensor, temb): - hidden_states = input_tensor - - hidden_states = self.norm1(hidden_states) - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.conv1(hidden_states) - - if temb is not None: - temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] - - if temb is not None and self.time_embedding_norm == "default": - hidden_states = hidden_states + temb - - hidden_states = self.norm2(hidden_states) - - if temb is not None and self.time_embedding_norm == "scale_shift": - scale, shift = torch.chunk(temb, 2, dim=1) - hidden_states = hidden_states * (1 + scale) + shift - - hidden_states = self.nonlinearity(hidden_states) - - hidden_states = self.dropout(hidden_states) - hidden_states = self.conv2(hidden_states) - - if self.conv_shortcut is not None: - input_tensor = self.conv_shortcut(input_tensor) - - output_tensor = (input_tensor + hidden_states) / self.output_scale_factor - - return output_tensor - - -class Mish(torch.nn.Module): - def forward(self, hidden_states): - return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/animatediff/models/unet.py b/animatediff/models/unet.py index 18aa955..0ee1aea 100644 --- a/animatediff/models/unet.py +++ b/animatediff/models/unet.py @@ -1,41 +1,157 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py - +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
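The deleted `resnet.py` implemented "inflated" layers: plain 2D operations applied to 5D video tensors by folding the frame axis into the batch. A minimal sketch of that pattern, mirroring the removed `InflatedConv3d` (shapes illustrative):
```
import torch
import torch.nn as nn
from einops import rearrange

class InflatedConv3d(nn.Conv2d):
    # A 2D convolution "inflated" to video: fold frames into the batch,
    # convolve each frame independently, then unfold again.
    def forward(self, x):                          # x: (b, c, f, h, w)
        video_length = x.shape[2]
        x = rearrange(x, "b c f h w -> (b f) c h w")
        x = super().forward(x)
        x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
        return x

conv = InflatedConv3d(4, 8, kernel_size=3, padding=1)
out = conv(torch.randn(2, 4, 16, 32, 32))
assert out.shape == (2, 8, 16, 32, 32)
```
On the SDXL branch this wrapper is gone; the UNet below keeps ordinary `nn.Conv2d` layers and performs the same `b c f h w <-> (b f) c h w` folding explicitly in its forward pass.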
from dataclasses import dataclass -from typing import List, Optional, Tuple, Union - +from typing import Any, Dict, List, Optional, Tuple, Union import os import json -import pdb + import torch import torch.nn as nn import torch.utils.checkpoint +from einops import rearrange, repeat from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.modeling_utils import ModelMixin +from diffusers.loaders import UNet2DConditionLoadersMixin, AttnProcsLayers from diffusers.utils import BaseOutput, logging -from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import AttentionProcessor, AttnProcessor, LoRAAttnProcessor +from diffusers.models.embeddings import ( + GaussianFourierProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + PositionNet, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin from .unet_blocks import ( - CrossAttnDownBlock3D, - CrossAttnUpBlock3D, - DownBlock3D, UNetMidBlock3DCrossAttn, - UpBlock3D, get_down_block, get_up_block, ) -from .resnet import InflatedConv3d, InflatedGroupNorm +from animatediff.utils.util import zero_rank_print logger = logging.get_logger(__name__) # pylint: disable=invalid-name @dataclass class UNet3DConditionOutput(BaseOutput): - sample: torch.FloatTensor + """ + The output of [`UNet3DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None -class UNet3DConditionModel(ModelMixin, ConfigMixin): +class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock3DCrossAttn"`): + Block type for middle of UNet; currently only `UNetMidBlock3DCrossAttn` is supported. If `None`, the + mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`): + The tuple of upsample blocks to use. 
+ only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers are skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. 
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_in` layer. + conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of `conv_out` layer. + projection_class_embeddings_input_dim (`int`, *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False` + otherwise. + """ + _supports_gradient_checkpointing = True @register_to_config @@ -46,114 +162,308 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): out_channels: int = 4, center_input_sample: bool = False, flip_sin_to_cos: bool = True, - freq_shift: int = 0, + freq_shift: int = 0, down_block_types: Tuple[str] = ( "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D", ), - mid_block_type: str = "UNetMidBlock3DCrossAttn", - up_block_types: Tuple[str] = ( - "UpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" - ), + mid_block_type: Optional[str] = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"), only_cross_attention: Union[bool, Tuple[bool]] = False, block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: int = 2, + layers_per_block: Union[int, Tuple[int]] = 2, downsample_padding: int = 1, mid_block_scale_factor: float = 1, act_fn: str = "silu", - norm_num_groups: int = 32, + norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, - cross_attention_dim: int = 1280, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, dual_cross_attention: bool = False, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", - - use_inflated_groupnorm=False, - - # Additional - use_motion_module = False, - motion_module_resolutions = ( 1,2,4,8 ), - motion_module_mid_block = False, - motion_module_decoder_only = False, - motion_module_type = None, - motion_module_kwargs = {}, - unet_use_cross_frame_attention = None, - unet_use_temporal_attention 
= None, + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + + # motion module + use_motion_module=False, + motion_module_resolutions = (1,2,4,8), + motion_module_mid_block = False, + motion_module_decoder_only = False, + motion_module_type=None, + motion_module_kwargs=None, ): super().__init__() - self.sample_size = sample_size - time_embed_dim = block_out_channels[0] * 4 + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." 
+ ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) # input - self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) # time - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1) + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
+ ) + else: + self.encoder_hid_proj = None # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) elif class_embed_type == "identity": self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image' or 'image_hint'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + self.down_blocks = nn.ModuleList([]) - self.mid_block = None self.up_blocks = nn.ModuleList([]) if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + only_cross_attention = [only_cross_attention] * len(down_block_types) + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + if isinstance(attention_head_dim, int): attention_head_dim = (attention_head_dim,) * len(down_block_types) + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + # down output_channel = block_out_channels[0] for i, down_block_type in enumerate(down_block_types): - res = 2 ** i input_channel = output_channel output_channel = block_out_channels[i] is_final_block = i == len(block_out_channels) - 1 - + res = 2 ** i down_block = get_down_block( down_block_type, - num_layers=layers_per_block, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_downsample=not is_final_block, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[i], + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], downsample_padding=downsample_padding, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -163,46 +473,48 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): # mid if mid_block_type == "UNetMidBlock3DCrossAttn": self.mid_block = UNetMidBlock3DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=block_out_channels[-1], - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, resnet_eps=norm_eps, resnet_act_fn=act_fn, output_scale_factor=mid_block_scale_factor, resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=attention_head_dim[-1], + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], resnet_groups=norm_num_groups, dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, use_motion_module=use_motion_module and motion_module_mid_block, motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) + elif mid_block_type is None: + self.mid_block = None else: raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the videos + + # count how many layers upsample the images self.num_upsamplers = 0 # up reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_attention_head_dim = list(reversed(attention_head_dim)) + 
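The channel bookkeeping in the down-block loop above, and in the up-block loop that follows, is easier to see in isolation. This standalone sketch prints the widths each block is constructed with, using the default `block_out_channels` and the same indexing (purely illustrative):
```
block_out_channels = (320, 640, 1280, 1280)

# Down path: block i maps the previous width onto block_out_channels[i].
output_channel = block_out_channels[0]
for i in range(len(block_out_channels)):
    input_channel, output_channel = output_channel, block_out_channels[i]
    is_final_block = i == len(block_out_channels) - 1
    print(f"down[{i}]: {input_channel} -> {output_channel}, downsample={not is_final_block}")

# Up path: widths reversed; each block also consumes encoder skip features.
reversed_channels = list(reversed(block_out_channels))
output_channel = reversed_channels[0]
for i in range(len(reversed_channels)):
    prev_output_channel = output_channel
    output_channel = reversed_channels[i]
    input_channel = reversed_channels[min(i + 1, len(block_out_channels) - 1)]
    is_final_block = i == len(block_out_channels) - 1
    print(f"up[{i}]: prev={prev_output_channel}, skip={input_channel}, out={output_channel}, upsample={not is_final_block}")
```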
reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block)) only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] for i, up_block_type in enumerate(up_block_types): - res = 2 ** (3 - i) is_final_block = i == len(block_out_channels) - 1 - + res = 2 ** (2 - i) prev_output_channel = output_channel output_channel = reversed_block_out_channels[i] input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - + # add upsample block for all BUT final layer if not is_final_block: add_upsample = True @@ -212,27 +524,28 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): up_block = get_up_block( up_block_type, - num_layers=layers_per_block + 1, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, prev_output_channel=prev_output_channel, - temb_channels=time_embed_dim, + temb_channels=blocks_time_embed_dim, add_upsample=add_upsample, resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim, - attn_num_head_channels=reversed_attention_head_dim[i], + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], dual_cross_attention=dual_cross_attention, use_linear_projection=use_linear_projection, only_cross_attention=only_cross_attention[i], upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - use_inflated_groupnorm=use_inflated_groupnorm, - + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, use_motion_module=use_motion_module and (res in motion_module_resolutions), motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, @@ -241,41 +554,215 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): prev_output_channel = output_channel # out - if use_inflated_groupnorm: - self.conv_norm_out = InflatedGroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + else: - self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) - self.conv_act = nn.SiLU() - self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type == "gated": + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len 
= cross_attention_dim[0] + self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim) + + def set_image_layer_lora(self, image_layer_lora_rank: int = 128): + lora_attn_procs = {} + for name in self.attn_processors.keys(): + zero_rank_print(f"(add lora) {name}") + cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=image_layer_lora_rank if image_layer_lora_rank > 16 else hidden_size // image_layer_lora_rank, + ) + self.set_attn_processor(lora_attn_procs) + + lora_layers = AttnProcsLayers(self.attn_processors) + zero_rank_print(f"(lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M") + del lora_layers + + def set_image_layer_lora_scale(self, lora_scale: float = 1.0): + for block in self.down_blocks: setattr(block, "lora_scale", lora_scale) + for block in self.up_blocks: setattr(block, "lora_scale", lora_scale) + setattr(self.mid_block, "lora_scale", lora_scale) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model, + indexed by their weight names. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "set_processor"): + if "motion_modules." not in name: + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], is_motion_module=False): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) if not is_motion_module else len(self.motion_module_attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if ((not is_motion_module) and ("motion_modules." not in name)) or (is_motion_module and ("motion_modules." 
in name)): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(AttnProcessor()) + + @property + def motion_module_attn_processors(self): + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + # keep only the processors that live inside the motion module + if hasattr(module, "set_processor"): + if "motion_modules." in name: + processors[f"{name}.processor"] = module.processor + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_motion_module_lora(self, motion_module_lora_rank: int = 256, motion_lora_resolution=(32, 64, 128)): + lora_attn_procs = {} + for name in self.motion_module_attn_processors.keys(): + print(f"(add motion lora) {name}") + + if name.startswith("mid_block"): + hidden_size = self.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(self.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = self.config.block_out_channels[block_id] + + lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=None, + rank=motion_module_lora_rank, + ) + self.set_attn_processor(lora_attn_procs, is_motion_module=True) + + lora_layers = AttnProcsLayers(self.motion_module_attn_processors) + print(f"(motion lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M") + del lora_layers + + def set_attention_slice(self, slice_size): r""" Enable sliced attention computation. - When this option is enabled, the attention module will split the input tensor in slices, to compute attention - in several steps. This is useful to save some memory in exchange for a small speed decrease. + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. Args: slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): - When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If - `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. 
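Both LoRA helpers above resolve an attention processor's hidden size from its qualified module name. The same lookup, distilled into a hypothetical standalone function:
```
def hidden_size_for(name: str, block_out_channels=(320, 640, 1280, 1280)) -> int:
    # Map a processor name such as "down_blocks.1.attentions...processor"
    # to the channel width of the block it lives in.
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name[len("down_blocks.")])
        return block_out_channels[block_id]
    raise ValueError(f"unexpected processor name: {name}")

assert hidden_size_for("down_blocks.0.attentions.0.processor") == 320
assert hidden_size_for("up_blocks.0.attentions.0.processor") == 1280
assert hidden_size_for("mid_block.attentions.0.processor") == 1280
```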
If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` must be a multiple of `slice_size`. """ sliceable_head_dims = [] - def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): if hasattr(module, "set_attention_slice"): sliceable_head_dims.append(module.sliceable_head_dim) for child in module.children(): - fn_recursive_retrieve_slicable_dims(child) + fn_recursive_retrieve_sliceable_dims(child) # retrieve number of attention layers for module in self.children(): - fn_recursive_retrieve_slicable_dims(module) + fn_recursive_retrieve_sliceable_dims(module) - num_slicable_layers = len(sliceable_head_dims) + num_sliceable_layers = len(sliceable_head_dims) if slice_size == "auto": # half the attention head size is usually a good trade-off between @@ -283,9 +770,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): slice_size = [dim // 2 for dim in sliceable_head_dims] elif slice_size == "max": # make smallest slice possible - slice_size = num_slicable_layers * [1] + slice_size = num_sliceable_layers * [1] - slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size if len(slice_size) != len(sliceable_head_dims): raise ValueError( @@ -314,7 +801,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): fn_recursive_set_attention_slice(module, reversed_slice_size) def _set_gradient_checkpointing(self, module, value=False): - if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value def forward( @@ -323,28 +810,57 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): timestep: Union[torch.Tensor, float, int], encoder_hidden_states: torch.Tensor, class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, return_dict: bool = True, ) -> Union[UNet3DConditionOutput, Tuple]: r""" + The [`UNet2DConditionModel`] forward method. + Args: - sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor - timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps - encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. 
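For SDXL-style checkpoints, the `added_cond_kwargs` documented here conventionally carries the pooled text embedding and six micro-conditioning ids; the forward pass below consumes both in its `"text_time"` branch. A hedged sketch of that dictionary (values illustrative, following the usual SDXL id ordering):
```
import torch

batch = 2
added_cond_kwargs = {
    # pooled embedding from SDXL's second text encoder
    "text_embeds": torch.randn(batch, 1280),
    # (original_h, original_w, crop_top, crop_left, target_h, target_w)
    "time_ids": torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch),
}
# In the forward pass, time_ids are flattened, sinusoidally embedded entry
# by entry, reshaped back per sample, concatenated with text_embeds, and
# projected by add_embedding into the time-embedding space.
```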
Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. Returns: [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When - returning a tuple, the first element is the sample tensor. + If `return_dict` is True, an [`~models.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. """ # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). # However, the upsampling interpolation output size can be forced to fit any upsampling size # on the fly if necessary. default_overall_up_factor = 2**self.num_upsamplers + # repeat the clip-level conditioning (timestep, text embeddings, SDXL size ids) once per frame: b -> (b f) + video_length = sample.shape[2] + timestep = repeat(timestep, "b -> (b f)", f=video_length) + encoder_hidden_states = repeat(encoder_hidden_states, "b c d -> (b f) c d", f=video_length) + added_cond_kwargs['time_ids'] = repeat(added_cond_kwargs['time_ids'], "b c -> (b f) c", f=video_length) + added_cond_kwargs['text_embeds'] = repeat(added_cond_kwargs['text_embeds'], "b c -> (b f) c", f=video_length) + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` forward_upsample_size = False upsample_size = None @@ -353,18 +869,35 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): logger.info("Forward upsample size to force interpolation output size.") forward_upsample_size = True - # prepare attention_mask + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 attention_mask = attention_mask.unsqueeze(1) - # center input if necessary + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. 
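The `repeat` calls above broadcast clip-level conditioning to every frame before the 5D latent is folded into `(b f) c h w` for the image layers. A toy illustration of the same einops patterns, with made-up sizes:
```
import torch
from einops import repeat

batch, frames = 2, 16
timestep = torch.tensor([999, 999])                      # one timestep per clip
encoder_hidden_states = torch.randn(batch, 77, 2048)     # per-clip text tokens

timestep = repeat(timestep, "b -> (b f)", f=frames)
encoder_hidden_states = repeat(encoder_hidden_states, "b c d -> (b f) c d", f=frames)

# Every frame of a clip now carries an identical copy of its conditioning.
assert timestep.shape == (batch * frames,)
assert encoder_hidden_states.shape == (batch * frames, 77, 2048)
```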
center input if necessary if self.config.center_input_sample: sample = 2 * sample - 1.0 - # time + # 1. time timesteps = timestep if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can # This would be a good case for the `match` statement (Python 3.10+) is_mps = sample.device.type == "mps" if isinstance(timestep, float): @@ -376,15 +909,17 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): timesteps = timesteps[None].to(sample.device) # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) + # timesteps were already repeated once per frame above, so no batch expansion is needed here t_emb = self.time_proj(timesteps) - # timesteps does not contain any weights and will always return f32 tensors + # `Timesteps` does not contain any weights and will always return f32 tensors # but time_embedding might actually be running in fp16. so we need to cast here. # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=self.dtype) - emb = self.time_embedding(t_emb) + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None if self.class_embedding is not None: if class_labels is None: @@ -393,33 +928,167 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): if self.config.class_embed_type == "timestep": class_labels = self.time_proj(class_labels) - class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) - emb = emb + class_emb + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) - # pre-process + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise 
ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            aug_emb = self.add_embedding(image_embs)
+        elif self.config.addition_embed_type == "image_hint":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+                )
+            image_embs = added_cond_kwargs.get("image_embeds")
+            hint = added_cond_kwargs.get("hint")
+            aug_emb, hint = self.add_embedding(image_embs, hint)
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+            # Kandinsky 2.1 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+            # Kandinsky 2.2 - style
+            if "image_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+                )
+            image_embeds = added_cond_kwargs.get("image_embeds")
+            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+        # 2. pre-process
+        video_length = sample.shape[2]
+        sample = rearrange(sample, "b c f h w -> (b f) c h w")
         sample = self.conv_in(sample)
+        sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
+
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
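As a companion to the SDXL `"text_time"` branch above: the six `time_ids` are sinusoidally embedded, flattened per sample, and concatenated onto the pooled text embedding before `add_embedding`. A rough standalone sketch, where the `sinusoidal` helper and all dimensions are stand-in assumptions for the model's `add_time_proj` module:

```python
import torch

batch, proj_dim, pooled_dim = 2, 256, 1280
time_ids = torch.randn(batch, 6)              # (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
text_embeds = torch.randn(batch, pooled_dim)  # pooled output of the second text encoder

def sinusoidal(x, dim=proj_dim):
    # stand-in for `self.add_time_proj` (a diffusers `Timesteps` module)
    half = dim // 2
    freqs = torch.exp(-torch.arange(half).float() * torch.log(torch.tensor(10000.0)) / (half - 1))
    args = x.float()[:, None] * freqs[None, :]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)

time_embeds = sinusoidal(time_ids.flatten())                # (b*6, proj_dim)
time_embeds = time_embeds.reshape(batch, -1)                # (b, 6*proj_dim)
add_embeds = torch.cat([text_embeds, time_embeds], dim=-1)  # (b, pooled_dim + 6*proj_dim)
assert add_embeds.shape == (batch, pooled_dim + 6 * proj_dim)
```

+        # 3.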
down + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None - # down down_block_res_samples = (sample,) for downsample_block in self.down_blocks: if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + sample, res_samples = downsample_block( hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, ) else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states) + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) down_block_res_samples += res_samples - # mid - sample = self.mid_block( - sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask - ) + if is_controlnet: + new_down_block_res_samples = () - # up + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_block_additional_residuals) > 0 + and sample.shape == down_block_additional_residuals[0].shape + ): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up for i, upsample_block in enumerate(self.up_blocks): is_final_block = i == len(self.up_blocks) - 1 @@ -437,24 +1106,32 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): temb=emb, res_hidden_states_tuple=res_samples, encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, upsample_size=upsample_size, attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, ) else: sample = upsample_block( - hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, encoder_hidden_states=encoder_hidden_states, + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size ) - # post-process - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) + + video_length = sample.shape[2] + sample = rearrange(sample, "b c f h w -> (b f) c h w") + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) sample = self.conv_out(sample) + sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: return (sample,) return UNet3DConditionOutput(sample=sample) - + @classmethod def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None): if subfolder is not None: @@ -468,24 +1145,29 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): config = json.load(f) config["_class_name"] = cls.__name__ config["down_block_types"] = [ + "DownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", - "CrossAttnDownBlock3D", - "DownBlock3D" + ] config["up_block_types"] = [ + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", "UpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D", - "CrossAttnUpBlock3D" ] - from diffusers.utils import WEIGHTS_NAME + config["mid_block_type"] = "UNetMidBlock3DCrossAttn" + from diffusers.utils import SAFETENSORS_WEIGHTS_NAME model = cls.from_config(config, **unet_additional_kwargs) - model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME) if not os.path.isfile(model_file): raise RuntimeError(f"{model_file} does not exist") - state_dict = torch.load(model_file, map_location="cpu") + + state_dict = {} + from safetensors import safe_open + with safe_open(model_file, framework='pt') as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) m, u = model.load_state_dict(state_dict, strict=False) print(f"### missing keys: {len(m)}; \n### unexpected keys: {len(u)};") @@ -494,4 +1176,4 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin): params = [p.numel() if "temporal" in n else 0 for n, p in model.named_parameters()] print(f"### Temporal Module Parameters: {sum(params) / 1e6} M") - return model + return model \ No newline at end of file diff --git a/animatediff/models/unet_blocks.py b/animatediff/models/unet_blocks.py index 711ad6c..c7e397a 100644 --- a/animatediff/models/unet_blocks.py +++ b/animatediff/models/unet_blocks.py @@ -1,13 +1,20 @@ -# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +from typing import Any, Dict, Optional, Tuple +import numpy as np import torch +import torch.nn.functional as F +from einops import rearrange from torch import nn -from .attention import Transformer3DModel -from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from diffusers.utils import is_torch_version, logging +from diffusers.models.activations import get_activation +from diffusers.models.attention import AdaGroupNorm +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.resnet import Downsample2D, FirDownsample2D, FirUpsample2D, KDownsample2D, KUpsample2D, ResnetBlock2D, Upsample2D +from diffusers.models.transformer_2d import Transformer2DModel from .motion_module import get_motion_module -import pdb +logger = logging.get_logger(__name__) # pylint: disable=invalid-name def get_down_block( down_block_type, @@ -18,7 +25,8 @@ def get_down_block( add_downsample, resnet_eps, resnet_act_fn, - attn_num_head_channels, + transformer_layers_per_block=1, + num_attention_heads=None, resnet_groups=None, cross_attention_dim=None, downsample_padding=None, @@ -27,16 +35,23 @@ def get_down_block( only_cross_attention=False, upcast_attention=False, resnet_time_scale_shift="default", - - 
unet_use_cross_frame_attention=None,
-    unet_use_temporal_attention=None,
-    use_inflated_groupnorm=None,
-
+    attention_type="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
+    attention_head_dim=None,
+    downsample_type=None,
     use_motion_module=None,
-
     motion_module_type=None,
     motion_module_kwargs=None,
 ):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warning(
+            f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
     down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
     if down_block_type == "DownBlock3D":
         return DownBlock3D(
@@ -50,18 +65,16 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             resnet_time_scale_shift=resnet_time_scale_shift,
-
-            use_inflated_groupnorm=use_inflated_groupnorm,
-
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
             motion_module_kwargs=motion_module_kwargs,
         )
     elif down_block_type == "CrossAttnDownBlock3D":
         if cross_attention_dim is None:
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
         return CrossAttnDownBlock3D(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -71,17 +84,13 @@ def get_down_block(
             resnet_groups=resnet_groups,
             downsample_padding=downsample_padding,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
-
-            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
-            unet_use_temporal_attention=unet_use_temporal_attention,
-            use_inflated_groupnorm=use_inflated_groupnorm,
-
+            attention_type=attention_type,
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
             motion_module_kwargs=motion_module_kwargs,
@@ -99,7 +108,8 @@ def get_up_block(
     up_block_type,
     in_channels,
     out_channels,
     prev_output_channel,
     temb_channels,
     add_upsample,
     resnet_eps,
     resnet_act_fn,
-    attn_num_head_channels,
+    transformer_layers_per_block=1,
+    num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
     dual_cross_attention=False,
@@ -107,15 +117,23 @@ def get_up_block(
     use_linear_projection=False,
     only_cross_attention=False,
     upcast_attention=False,
     resnet_time_scale_shift="default",
-
-    unet_use_cross_frame_attention=None,
-    unet_use_temporal_attention=None,
-    use_inflated_groupnorm=None,
-
+    attention_type="default",
+    resnet_skip_time_act=False,
+    resnet_out_scale_factor=1.0,
+    cross_attention_norm=None,
+    attention_head_dim=None,
+    upsample_type=None,
     use_motion_module=None,
     motion_module_type=None,
     motion_module_kwargs=None,
 ):
+    # If attn head dim is not defined, we default it to the number of heads
+    if attention_head_dim is None:
+        logger.warning(
+            f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+        )
+        attention_head_dim = num_attention_heads
+
     up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
     if up_block_type == "UpBlock3D":
         return UpBlock3D(
@@ -129,18 +147,16 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             resnet_time_scale_shift=resnet_time_scale_shift,
-
-            use_inflated_groupnorm=use_inflated_groupnorm,
-
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
             motion_module_kwargs=motion_module_kwargs,
         )
     elif up_block_type == "CrossAttnUpBlock3D":
         if cross_attention_dim is None:
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
         return CrossAttnUpBlock3D(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
@@ -150,23 +166,19 @@ def get_up_block(
             resnet_act_fn=resnet_act_fn,
             resnet_groups=resnet_groups,
             cross_attention_dim=cross_attention_dim,
-            attn_num_head_channels=attn_num_head_channels,
+            num_attention_heads=num_attention_heads,
             dual_cross_attention=dual_cross_attention,
             use_linear_projection=use_linear_projection,
             only_cross_attention=only_cross_attention,
             upcast_attention=upcast_attention,
             resnet_time_scale_shift=resnet_time_scale_shift,
-
-            unet_use_cross_frame_attention=unet_use_cross_frame_attention,
-            unet_use_temporal_attention=unet_use_temporal_attention,
-            use_inflated_groupnorm=use_inflated_groupnorm,
-
+            attention_type=attention_type,
             use_motion_module=use_motion_module,
             motion_module_type=motion_module_type,
             motion_module_kwargs=motion_module_kwargs,
         )
     raise ValueError(f"{up_block_type} does not exist.")


 class UNetMidBlock3DCrossAttn(nn.Module):
     def __init__(
@@ -175,36 +187,32 @@ class UNetMidBlock3DCrossAttn(nn.Module):
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        transformer_layers_per_block: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
         resnet_groups: int = 32,
         resnet_pre_norm: bool = True,
-        attn_num_head_channels=1,
+        num_attention_heads=1,
         output_scale_factor=1.0,
         cross_attention_dim=1280,
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
-
-        unet_use_cross_frame_attention=None,
-        unet_use_temporal_attention=None,
-        use_inflated_groupnorm=None,
-
+        attention_type="default",
         use_motion_module=None,
-
         motion_module_type=None,
         motion_module_kwargs=None,
     ):
         super().__init__()

         self.has_cross_attention = True
-        self.attn_num_head_channels = attn_num_head_channels
+        self.num_attention_heads = num_attention_heads
         resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)

         # there is always at least one resnet
         resnets = [
-            ResnetBlock3D(
+            ResnetBlock2D(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 temb_channels=temb_channels,
@@ -215,40 +223,46 @@ class UNetMidBlock3DCrossAttn(nn.Module):
                 non_linearity=resnet_act_fn,
                 output_scale_factor=output_scale_factor,
                 pre_norm=resnet_pre_norm,
-
-                use_inflated_groupnorm=use_inflated_groupnorm,
             )
         ]
         attentions = []
         motion_modules = []

         for _ in range(num_layers):
-            if dual_cross_attention:
-                raise NotImplementedError
-            attentions.append(
-                Transformer3DModel(
-                    attn_num_head_channels,
-                    in_channels // attn_num_head_channels,
-                    in_channels=in_channels,
-                    num_layers=1,
-
cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) ) - ) motion_modules.append( get_motion_module( in_channels=in_channels, - motion_module_type=motion_module_type, + motion_module_type=motion_module_type, motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) resnets.append( - ResnetBlock3D( + ResnetBlock2D( in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, @@ -259,8 +273,6 @@ class UNetMidBlock3DCrossAttn(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ) @@ -268,121 +280,22 @@ class UNetMidBlock3DCrossAttn(nn.Module): self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): - hidden_states = self.resnets[0](hidden_states, temb) - for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states - hidden_states = resnet(hidden_states, temb) - - return hidden_states - - -class CrossAttnDownBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - downsample_padding=1, - add_downsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - - use_motion_module=None, - - motion_module_type=None, - motion_module_kwargs=None, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - in_channels = in_channels if i == 0 else out_channels - resnets.append( - ResnetBlock3D( - in_channels=in_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - 
output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, - ) - ) - if dual_cross_attention: - raise NotImplementedError - attentions.append( - Transformer3DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - ) - ) - motion_modules.append( - get_motion_module( - in_channels=out_channels, - motion_module_type=motion_module_type, - motion_module_kwargs=motion_module_kwargs, - ) if use_motion_module else None - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_downsample: - self.downsamplers = nn.ModuleList( - [ - Downsample3D( - out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" - ) - ] - ) - else: - self.downsamplers = None - self.gradient_checkpointing = False - def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): - output_states = () - - for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + for attn, motion_module, resnet in zip(self.attentions, self.motion_modules, self.resnets[1:]): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): @@ -393,33 +306,46 @@ class CrossAttnDownBlock3D(nn.Module): return module(*inputs) return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), + + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( hidden_states, - encoder_hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, )[0] - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) - + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states, temb, + encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = 
rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - # add motion module + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = resnet(hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) - output_states += (hidden_states,) - - if self.downsamplers is not None: - for downsampler in self.downsamplers: - hidden_states = downsampler(hidden_states) - - output_states += (hidden_states,) - - return hidden_states, output_states - + return hidden_states class DownBlock3D(nn.Module): def __init__( @@ -437,9 +363,6 @@ class DownBlock3D(nn.Module): output_scale_factor=1.0, add_downsample=True, downsample_padding=1, - - use_inflated_groupnorm=None, - use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, @@ -451,7 +374,7 @@ class DownBlock3D(nn.Module): for i in range(num_layers): in_channels = in_channels if i == 0 else out_channels resnets.append( - ResnetBlock3D( + ResnetBlock2D( in_channels=in_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -462,8 +385,6 @@ class DownBlock3D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( @@ -473,14 +394,14 @@ class DownBlock3D(nn.Module): motion_module_kwargs=motion_module_kwargs, ) if use_motion_module else None ) - + self.resnets = nn.ModuleList(resnets) self.motion_modules = nn.ModuleList(motion_modules) if add_downsample: self.downsamplers = nn.ModuleList( [ - Downsample3D( + Downsample2D( out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" ) ] @@ -495,178 +416,45 @@ class DownBlock3D(nn.Module): for resnet, motion_module in zip(self.resnets, self.motion_modules): if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + 
create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states, temb, encoder_hidden_states, use_reentrant=False) if motion_module is not None else hidden_states else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = resnet(hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states) if motion_module is not None else hidden_states - # add motion module - hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states - - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) if self.downsamplers is not None: for downsampler in self.downsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = downsampler(hidden_states) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) - output_states += (hidden_states,) + output_states = output_states + (hidden_states,) return hidden_states, output_states -class CrossAttnUpBlock3D(nn.Module): - def __init__( - self, - in_channels: int, - out_channels: int, - prev_output_channel: int, - temb_channels: int, - dropout: float = 0.0, - num_layers: int = 1, - resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", - resnet_act_fn: str = "swish", - resnet_groups: int = 32, - resnet_pre_norm: bool = True, - attn_num_head_channels=1, - cross_attention_dim=1280, - output_scale_factor=1.0, - add_upsample=True, - dual_cross_attention=False, - use_linear_projection=False, - only_cross_attention=False, - upcast_attention=False, - - unet_use_cross_frame_attention=None, - unet_use_temporal_attention=None, - use_inflated_groupnorm=None, - - use_motion_module=None, - - motion_module_type=None, - motion_module_kwargs=None, - ): - super().__init__() - resnets = [] - attentions = [] - motion_modules = [] - - self.has_cross_attention = True - self.attn_num_head_channels = attn_num_head_channels - - for i in range(num_layers): - res_skip_channels = in_channels if (i == num_layers - 1) else out_channels - resnet_in_channels = prev_output_channel if i == 0 else out_channels - - resnets.append( - ResnetBlock3D( - in_channels=resnet_in_channels + res_skip_channels, - out_channels=out_channels, - temb_channels=temb_channels, - eps=resnet_eps, - groups=resnet_groups, - dropout=dropout, - time_embedding_norm=resnet_time_scale_shift, - non_linearity=resnet_act_fn, - output_scale_factor=output_scale_factor, - pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, - ) - ) - if dual_cross_attention: - raise NotImplementedError - attentions.append( - Transformer3DModel( - attn_num_head_channels, - out_channels // attn_num_head_channels, - in_channels=out_channels, - num_layers=1, - cross_attention_dim=cross_attention_dim, - norm_num_groups=resnet_groups, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention, - upcast_attention=upcast_attention, - - unet_use_cross_frame_attention=unet_use_cross_frame_attention, - unet_use_temporal_attention=unet_use_temporal_attention, - ) - ) - motion_modules.append( - get_motion_module( - in_channels=out_channels, - 
motion_module_type=motion_module_type, - motion_module_kwargs=motion_module_kwargs, - ) if use_motion_module else None - ) - - self.attentions = nn.ModuleList(attentions) - self.resnets = nn.ModuleList(resnets) - self.motion_modules = nn.ModuleList(motion_modules) - - if add_upsample: - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) - else: - self.upsamplers = None - - self.gradient_checkpointing = False - - def forward( - self, - hidden_states, - res_hidden_states_tuple, - temb=None, - encoder_hidden_states=None, - upsample_size=None, - attention_mask=None, - ): - for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): - # pop res hidden states - res_hidden_states = res_hidden_states_tuple[-1] - res_hidden_states_tuple = res_hidden_states_tuple[:-1] - hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - - if self.training and self.gradient_checkpointing: - - def create_custom_forward(module, return_dict=None): - def custom_forward(*inputs): - if return_dict is not None: - return module(*inputs, return_dict=return_dict) - else: - return module(*inputs) - - return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - hidden_states = torch.utils.checkpoint.checkpoint( - create_custom_forward(attn, return_dict=False), - hidden_states, - encoder_hidden_states, - )[0] - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) - - else: - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample - - # add motion module - hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states - - if self.upsamplers is not None: - for upsampler in self.upsamplers: - hidden_states = upsampler(hidden_states, upsample_size) - - return hidden_states - - class UpBlock3D(nn.Module): def __init__( self, @@ -683,9 +471,6 @@ class UpBlock3D(nn.Module): resnet_pre_norm: bool = True, output_scale_factor=1.0, add_upsample=True, - - use_inflated_groupnorm=None, - use_motion_module=None, motion_module_type=None, motion_module_kwargs=None, @@ -699,7 +484,7 @@ class UpBlock3D(nn.Module): resnet_in_channels = prev_output_channel if i == 0 else out_channels resnets.append( - ResnetBlock3D( + ResnetBlock2D( in_channels=resnet_in_channels + res_skip_channels, out_channels=out_channels, temb_channels=temb_channels, @@ -710,8 +495,6 @@ class UpBlock3D(nn.Module): non_linearity=resnet_act_fn, output_scale_factor=output_scale_factor, pre_norm=resnet_pre_norm, - - use_inflated_groupnorm=use_inflated_groupnorm, ) ) motion_modules.append( @@ -726,35 +509,405 @@ class UpBlock3D(nn.Module): self.motion_modules = nn.ModuleList(motion_modules) if add_upsample: - self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) else: self.upsamplers = None self.gradient_checkpointing = False - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,): - for resnet, motion_module in zip(self.resnets, self.motion_modules): + def forward(self, hidden_states, res_hidden_states_tuple, 
encoder_hidden_states=None, temb=None, upsample_size=None): + for (resnet, motion_module) in zip(self.resnets, self.motion_modules): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) if self.training and self.gradient_checkpointing: + def create_custom_forward(module): def custom_forward(*inputs): return module(*inputs) return custom_forward - - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) - if motion_module is not None: - hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states) + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), hidden_states, + temb, encoder_hidden_states, use_reentrant=False) if motion_module is not None else hidden_states else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = resnet(hidden_states, temb) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + 
Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals=None, + ): + output_states = () + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, temb, + encoder_hidden_states, use_reentrant=False) if motion_module is not None else hidden_states + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if 
motion_module is not None else hidden_states + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = downsampler(hidden_states) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + attention_type="default", + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + 
encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states, temb, encoder_hidden_states, + use_reentrant=False) if motion_module is not None else hidden_states + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + hidden_states = resnet(hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states) if motion_module is not None else hidden_states if self.upsamplers is not None: for upsampler in self.upsamplers: + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") hidden_states = upsampler(hidden_states, upsample_size) + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) - return hidden_states + return hidden_states \ No newline at end of file diff --git a/animatediff/pipelines/pipeline_animation.py b/animatediff/pipelines/pipeline_animation.py index 58f22d1..d07794c 100644 --- a/animatediff/pipelines/pipeline_animation.py +++ b/animatediff/pipelines/pipeline_animation.py @@ -1,200 +1,361 @@ -# Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import inspect
-from typing import Callable, List, Optional, Union
+import os
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+import torch
+import numpy as np
+from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 from dataclasses import dataclass
-import numpy as np
-import torch
-from tqdm import tqdm
-
-from diffusers.utils import is_accelerate_available
-from packaging import version
-from transformers import CLIPTextModel, CLIPTokenizer
-
-from diffusers.configuration_utils import FrozenDict
+from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from diffusers.models import AutoencoderKL
-from diffusers.pipeline_utils import DiffusionPipeline
-from diffusers.schedulers import (
-    DDIMScheduler,
-    DPMSolverMultistepScheduler,
-    EulerAncestralDiscreteScheduler,
-    EulerDiscreteScheduler,
-    LMSDiscreteScheduler,
-    PNDMScheduler,
+from animatediff.models.unet import UNet3DConditionModel
+from diffusers.models.attention_processor import (
+    AttnProcessor2_0,
+    LoRAAttnProcessor2_0,
+    LoRAXFormersAttnProcessor,
+    XFormersAttnProcessor,
+)
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+    is_accelerate_available,
+    is_accelerate_version,
+    logging,
+    randn_tensor,
+    replace_example_docstring,
+    BaseOutput,
 )
-from diffusers.utils import deprecate, logging, BaseOutput
-
 from einops import rearrange
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+
+
+@dataclass
+class AnimatePipelineOutput(BaseOutput):
+    """
+    Output class for the AnimateDiff SDXL animation pipeline.
+
+    Args:
+        videos (`torch.Tensor` or `np.ndarray`):
+            The denoised video(s) produced by the diffusion pipeline, as a tensor of shape `(batch_size, channels,
+            num_frames, height, width)` or an equivalent numpy array.
+    """
+
+    videos: Union[torch.Tensor, np.ndarray]

-from ..models.unet import UNet3DConditionModel

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLPipeline

-@dataclass
-class AnimationPipelineOutput(BaseOutput):
-    videos: Union[torch.Tensor, np.ndarray]

+        >>> pipe = StableDiffusionXLPipeline.from_pretrained(
+        ...     "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ... )
+        >>> pipe = pipe.to("cuda")
+
+        >>> prompt = "a photo of an astronaut riding a horse on mars"
+        >>> image = pipe(prompt).images[0]
+        ```
+"""

-class AnimationPipeline(DiffusionPipeline):
-    _optional_components = []
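For orientation, a small self-contained sketch of how the `rescale_noise_cfg` helper defined just below is typically combined with classifier-free guidance during denoising; the tensor shapes are illustrative assumptions, and the rescale computation is inlined so the snippet runs on its own:

```python
import torch

# illustrative (b, c, f, h, w) latent noise predictions for one denoising step
noise_pred_uncond = torch.randn(1, 4, 16, 128, 128)
noise_pred_text = torch.randn(1, 4, 16, 128, 128)
guidance_scale, guidance_rescale = 7.5, 0.7

# classifier-free guidance
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

# same computation as rescale_noise_cfg below: match the std of the text-only
# prediction, then blend back to avoid over-exposed, "plain looking" results
std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
std_cfg = noise_pred.std(dim=list(range(1, noise_pred.ndim)), keepdim=True)
noise_pred = guidance_rescale * (noise_pred * std_text / std_cfg) + (1 - guidance_rescale) * noise_pred
```

+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
+def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
+    """
+    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
+    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).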
See Section 3.4
+    """
+    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
+    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
+    # rescale the results from guidance (fixes overexposure)
+    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
+    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
+    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
+    return noise_cfg
+
+
+class AnimationPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
+    r"""
+    Pipeline for text-to-video generation using Stable Diffusion XL with an AnimateDiff motion module.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+    In addition, the pipeline inherits the following loading methods:
+        - *LoRA*: [`StableDiffusionXLPipeline.load_lora_weights`]
+        - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
+
+    as well as the following saving methods:
+        - *LoRA*: [`loaders.StableDiffusionXLPipeline.save_lora_weights`]
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+        text_encoder ([`CLIPTextModel`]):
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
+        tokenizer (`CLIPTokenizer`):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet3DConditionModel`]): Conditional U-Net architecture, inflated with motion modules, to denoise the encoded video latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded video latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+    """

     def __init__(
         self,
         vae: AutoencoderKL,
         text_encoder: CLIPTextModel,
+        text_encoder_2: CLIPTextModelWithProjection,
         tokenizer: CLIPTokenizer,
+        tokenizer_2: CLIPTokenizer,
         unet: UNet3DConditionModel,
-        scheduler: Union[
-            DDIMScheduler,
-            PNDMScheduler,
-            LMSDiscreteScheduler,
-            EulerDiscreteScheduler,
-            EulerAncestralDiscreteScheduler,
-            DPMSolverMultistepScheduler,
-        ],
+        scheduler: KarrasDiffusionSchedulers,
+        force_zeros_for_empty_prompt: bool = True,
+        add_watermarker: Optional[bool] = None,
     ):
         super().__init__()

-        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-            deprecation_message = (
-                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                f" should be set to 1 instead of {scheduler.config.steps_offset}. 
Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - self.register_modules( vae=vae, text_encoder=text_encoder, + text_encoder_2=text_encoder_2, tokenizer=tokenizer, + tokenizer_2=tokenizer_2, unet=unet, scheduler=scheduler, ) + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.default_sample_size = self.unet.config.sample_size + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. 
This is useful to save some memory and allow larger batch sizes. + """ self.vae.enable_slicing() + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ self.vae.disable_slicing() - def enable_sequential_cpu_offload(self, gpu_id=0): - if is_accelerate_available(): - from accelerate import cpu_offload + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows + processing larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. 
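A short usage sketch of the memory savers documented above, assuming a `pipeline` instance as in the previous sketch; toggles like these are what keep 1024x1024x16-frame inference within roughly 13GB of VRAM.

```python
# Trade a little speed for substantial VRAM headroom before generating.
pipeline.enable_vae_slicing()        # decode latent frames in slices
pipeline.enable_vae_tiling()         # tile the VAE over large spatial resolutions
pipeline.enable_model_cpu_offload()  # keep only the active sub-model on the GPU

# ... run generation here ...

pipeline.disable_vae_slicing()
pipeline.disable_vae_tiling()
```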
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook else: - raise ImportError("Please install accelerate via `pip install accelerate`") + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") device = torch.device(f"cuda:{gpu_id}") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: - if cpu_offloaded_model is not None: - cpu_offload(cpu_offloaded_model, device) + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) - - @property - def _execution_device(self): - if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"): - return self.device - for module in self.unet.modules(): - if ( - hasattr(module, "_hf_hook") - and hasattr(module._hf_hook, "execution_device") - and module._hf_hook.execution_device is not None - ): - return torch.device(module._hf_hook.execution_device) - return self.device - - def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt): - batch_size = len(prompt) if isinstance(prompt, list) else 1 - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", + model_sequence = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + model_sequence.extend([self.unet, self.vae]) - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids): - removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) + hook = None + for cpu_offloaded_model in model_sequence: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) + # We'll offload the last model manually. + self.final_offload_hook = hook + + def encode_prompt( + self, + prompt: str, + prompt_2: Optional[str] = None, + device: Optional[torch.device] = None, + num_videos_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[str] = None, + negative_prompt_2: Optional[str] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. 
If not defined, `prompt` is + used in both text-encoders. + device: (`torch.device`): + torch device + num_videos_per_prompt (`int`): + number of videos that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier-free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. 
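For reference, a hedged sketch of calling `encode_prompt` directly; the shapes in the comments assume SDXL's 77-token context, the 768+1280 concatenated hidden size, and one video per prompt.

```python
(
    prompt_embeds,                  # (1, 77, 2048): both encoders' hidden states, concatenated
    negative_prompt_embeds,         # (1, 77, 2048)
    pooled_prompt_embeds,           # (1, 1280): pooled output of text_encoder_2
    negative_pooled_prompt_embeds,  # (1, 1280)
) = pipeline.encode_prompt(
    prompt="a panda surfing, best quality",
    negative_prompt="low quality",
    num_videos_per_prompt=1,
    do_classifier_free_guidance=True,
)
```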
+ """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) else: - attention_mask = None + batch_size = prompt_embeds.shape[0] - text_embeddings = self.text_encoder( - text_input_ids.to(device), - attention_mask=attention_mask, + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] ) - text_embeddings = text_embeddings[0] - # duplicate text embeddings for each generation per prompt, using mps friendly method - bs_embed, seq_len, _ = text_embeddings.shape - text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1) - text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1) + if prompt_embeds is None: + prompt_2 = prompt_2 or prompt + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + prompts = [prompt, prompt_2] + for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) # get unconditional embeddings for classifier free guidance - if do_classifier_free_guidance: + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + uncond_tokens: List[str] - if negative_prompt is None: - uncond_tokens = [""] * batch_size - elif type(prompt) is not type(negative_prompt): + if prompt is not None and type(prompt) is not type(negative_prompt): raise TypeError( f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" f" {type(prompt)}." 
) elif isinstance(negative_prompt, str): - uncond_tokens = [negative_prompt] + uncond_tokens = [negative_prompt, negative_prompt_2] elif batch_size != len(negative_prompt): raise ValueError( f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" @@ -202,55 +363,58 @@ class AnimationPipeline(DiffusionPipeline): " the batch size of `prompt`." ) else: - uncond_tokens = negative_prompt + uncond_tokens = [negative_prompt, negative_prompt_2] - max_length = text_input_ids.shape[-1] - uncond_input = self.tokenizer( - uncond_tokens, - padding="max_length", - max_length=max_length, - truncation=True, - return_tensors="pt", - ) + negative_prompt_embeds_list = [] + for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer) - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = uncond_input.attention_mask.to(device) - else: - attention_mask = None + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) - uncond_embeddings = self.text_encoder( - uncond_input.input_ids.to(device), - attention_mask=attention_mask, - ) - uncond_embeddings = uncond_embeddings[0] + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are always interested only in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_videos_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = uncond_embeddings.shape[1] - uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1) - uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1) + seq_len = negative_prompt_embeds.shape[1] + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_videos_per_prompt).view( + bs_embed * num_videos_per_prompt, -1 + ) - # For classifier free guidance, we need to do two forward passes. 
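To see why the concatenation above yields a 2048-dim embedding, here is a standalone sketch of the dual-encoder logic using the stock SDXL components from the HuggingFace Hub; it mirrors the loop above outside the pipeline class.

```python
import torch
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

base = "stabilityai/stable-diffusion-xl-base-1.0"
tok1 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
tok2 = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer_2")
enc1 = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
enc2 = CLIPTextModelWithProjection.from_pretrained(base, subfolder="text_encoder_2")

prompt = "a panda surfing, best quality"
embeds = []
for tok, enc in ((tok1, enc1), (tok2, enc2)):
    ids = tok(prompt, padding="max_length", max_length=tok.model_max_length,
              truncation=True, return_tensors="pt").input_ids
    out = enc(ids, output_hidden_states=True)
    pooled = out[0]                       # overwritten each pass; only enc2's survives
    embeds.append(out.hidden_states[-2])  # penultimate hidden state, as in the pipeline

prompt_embeds = torch.cat(embeds, dim=-1)
print(prompt_embeds.shape)  # torch.Size([1, 77, 2048]) = 768 (enc1) + 1280 (enc2)
```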
- # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - text_embeddings = torch.cat([uncond_embeddings, text_embeddings]) - - return text_embeddings - - def decode_latents(self, latents): - video_length = latents.shape[2] - latents = 1 / 0.18215 * latents - latents = rearrange(latents, "b c f h w -> (b f) c h w") - # video = self.vae.decode(latents).sample - video = [] - for frame_idx in tqdm(range(latents.shape[0])): - video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample) - video = torch.cat(video) - video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length) - video = (video / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 - video = video.cpu().float().numpy() - return video + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. @@ -268,10 +432,20 @@ class AnimationPipeline(DiffusionPipeline): extra_step_kwargs["generator"] = generator return extra_step_kwargs - def check_inputs(self, prompt, height, width, callback_steps): - if not isinstance(prompt, str) and not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - + def check_inputs( + self, + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt=None, + negative_prompt_2=None, + prompt_embeds=None, + negative_prompt_embeds=None, + pooled_prompt_embeds=None, + negative_pooled_prompt_embeds=None, + ): if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -283,131 +457,393 @@ class AnimationPipeline(DiffusionPipeline): f" {type(callback_steps)}." ) - def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt_2 is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): + raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + elif negative_prompt_2 is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if prompt_embeds is not None and pooled_prompt_embeds is None: + raise ValueError( + "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." + ) + + if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None: + raise ValueError( + "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." + ) + + # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents (signature extended with a frame dimension) + def prepare_latents(self, batch_size, single_model_length, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, single_model_length, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
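A small sketch of the 5D video latent that this `prepare_latents` variant allocates, using the default 1024x1024, 16-frame settings; note that the `randn_tensor` import location varies across diffusers versions.

```python
import torch
from diffusers.utils.torch_utils import randn_tensor  # `diffusers.utils` in older releases

batch_size, num_channels_latents = 1, 4
single_model_length = 16  # frames; the `--L` flag in scripts.animate
height = width = 1024
vae_scale_factor = 8      # 2 ** (len(block_out_channels) - 1) for the SDXL VAE

shape = (batch_size, num_channels_latents, single_model_length,
         height // vae_scale_factor, width // vae_scale_factor)
latents = randn_tensor(shape, generator=torch.Generator().manual_seed(0))
print(latents.shape)  # torch.Size([1, 4, 16, 128, 128])
```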
) - if latents is None: - rand_device = "cpu" if device.type == "mps" else device - if isinstance(generator, list): - shape = shape - # shape = (1,) + shape[1:] - latents = [ - torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype) - for i in range(batch_size) - ] - latents = torch.cat(latents, dim=0).to(device) - else: - latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) else: - if latents.shape != shape: - raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( self, - prompt: Union[str, List[str]], - video_length: Optional[int], + prompt: Union[str, List[str]] = None, + prompt_2: Optional[Union[str, List[str]]] = None, + single_model_length: Optional[int] = 16, height: Optional[int] = None, width: Optional[int] = None, num_inference_steps: int = 50, - guidance_scale: float = 7.5, + denoising_end: Optional[float] = None, + guidance_scale: float = 5.0, negative_prompt: Optional[Union[str, List[str]]] = None, + negative_prompt_2: Optional[Union[str, List[str]]] = None, num_videos_per_prompt: Optional[int] = 1, eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, latents: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "tensor", + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, - callback_steps: 
Optional[int] = 1, - **kwargs, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + original_size: Optional[Tuple[int, int]] = None, + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Optional[Tuple[int, int]] = None, ): - # Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor + r""" + Function invoked when calling the pipeline for generation. - # Check inputs. Raise error if not correct - self.check_inputs(prompt, height, width, callback_steps) + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` + instead. + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders. + single_model_length (`int`, *optional*, defaults to 16): + The number of frames in the generated video. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + denoising_end (`float`, *optional*): + When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be + completed before it is intentionally prematurely terminated. As a result, the returned sample will + still retain a substantial amount of noise as determined by the discrete timesteps selected by the + scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a + "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image + Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output) + guidance_scale (`float`, *optional*, defaults to 5.0): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders. + num_videos_per_prompt (`int`, *optional*, defaults to 1): + The number of videos to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. 
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will be generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead + of a plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + guidance_rescale (`float`, *optional*, defaults to 0.0): + Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are + Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_rescale` is defined as `φ` in equation 16. of + [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). + Guidance rescale factor should fix overexposure when using zero terminal SNR. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled. + `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as + explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). 
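An editorial sketch of the micro-conditioning vector later assembled by `_get_add_time_ids` from `original_size` (above) and the `crops_coords_top_left`/`target_size` parameters described just below; the 256-per-scalar embedding width is the SDXL base config's `addition_time_embed_dim`, quoted here as an assumption.

```python
import torch

original_size = (1024, 1024)  # (height, width) the model should "believe" it saw
crops_coords_top_left = (0, 0)
target_size = (1024, 1024)

# Six scalars, in the order _get_add_time_ids concatenates them:
# (orig_h, orig_w, crop_top, crop_left, target_h, target_w)
add_time_ids = torch.tensor([list(original_size + crops_coords_top_left + target_size)])
print(add_time_ids)  # tensor([[1024, 1024, 0, 0, 1024, 1024]])

# Dimension check mirroring _get_add_time_ids, with assumed SDXL base numbers:
addition_time_embed_dim, projection_dim = 256, 1280
assert addition_time_embed_dim * 6 + projection_dim == 2816  # add_embedding.linear_1.in_features
```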
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position + `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting + `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of + [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + For most cases, `target_size` should be set to the desired height and width of the generated image. If + not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in + section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). - # Define call parameters - # batch_size = 1 if isinstance(prompt, str) else len(prompt) - batch_size = 1 - if latents is not None: - batch_size = latents.shape[0] - if isinstance(prompt, list): + Examples: + + Returns: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a + `tuple`. When returning a tuple, the first element is the generated video tensor. + """ + # 0. Default height and width to unet + height = height or self.default_sample_size * self.vae_scale_factor + width = width or self.default_sample_size * self.vae_scale_factor + + original_size = original_size or (height, width) + target_size = target_size or (height, width) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + prompt_2, + height, + width, + callback_steps, + negative_prompt, + negative_prompt_2, + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] device = self._execution_device + # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` # corresponds to doing no classifier free guidance. do_classifier_free_guidance = guidance_scale > 1.0 - # Encode input prompt - prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size - if negative_prompt is not None: - negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size - text_embeddings = self._encode_prompt( - prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt + # 3. 
Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt=prompt, + prompt_2=prompt_2, + device=device, + num_videos_per_prompt=num_videos_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + negative_prompt_2=negative_prompt_2, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + lora_scale=text_encoder_lora_scale, ) - # Prepare timesteps + # 4. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps - # Prepare latent variables - num_channels_latents = self.unet.in_channels + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( batch_size * num_videos_per_prompt, + single_model_length, num_channels_latents, - video_length, height, width, - text_embeddings.dtype, + prompt_embeds.dtype, device, generator, latents, ) - latents_dtype = latents.dtype - # Prepare extra step kwargs. + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - # Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + # 7. Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device) + + # 8. 
Denoising loop + num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) + + # 8.1 Apply denoising_end + if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1: + discrete_timestep_cutoff = int( + round( + self.scheduler.config.num_train_timesteps + - (denoising_end * self.scheduler.config.num_train_timesteps) + ) + ) + num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps))) + timesteps = timesteps[:num_inference_steps] + with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) # predict the noise residual - noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype) - # noise_pred = [] - # import pdb - # pdb.set_trace() - # for batch_idx in range(latent_model_input.shape[0]): - # noise_pred_single = self.unet(latent_model_input[batch_idx:batch_idx+1], t, encoder_hidden_states=text_embeddings[batch_idx:batch_idx+1]).sample.to(dtype=latents_dtype) - # noise_pred.append(noise_pred_single) - # noise_pred = torch.cat(noise_pred) + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + ts = torch.tensor([t], dtype=latent_model_input.dtype, device=latent_model_input.device) + if do_classifier_free_guidance: + ts = ts.repeat(2) + + noise_pred = self.unet( + latent_model_input, + ts, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] # call the callback, if provided if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): @@ -415,14 +851,98 @@ class AnimationPipeline(DiffusionPipeline): if callback is not None and i % callback_steps == 0: callback(i, t, latents) - # Post-processing - video = self.decode_latents(latents) + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float32 and latents.dtype == torch.float16: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) - # Convert to tensor - if output_type == "tensor": - video = torch.from_numpy(video) + if not output_type == "latent": + latents = rearrange(latents, "b c f h w -> (b f) c h w") + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + #image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + image = ((image + 1) / 2).clamp(0, 1) + video = rearrange(image, "(b f) c h w -> b c f h w", f=single_model_length).cpu() if not return_dict: - return video + return (video,) - return AnimationPipelineOutput(videos=video) + return AnimatePipelineOutput(videos=video) + + + # Override to properly handle the loading and unloading of the additional text encoder. + def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs): + # We could have accessed the unet config from `lora_state_dict()` too. We pass + # it here explicitly to be able to tell that it's coming from an SDXL + # pipeline. + + state_dict, network_alphas = self.lora_state_dict( + pretrained_model_name_or_path_or_dict, + unet_config=self.unet.config, + **kwargs, + ) + self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet) + + text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k} + if len(text_encoder_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder, + prefix="text_encoder", + lora_scale=self.lora_scale, + ) + + text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." 
in k} + if len(text_encoder_2_state_dict) > 0: + self.load_lora_into_text_encoder( + text_encoder_2_state_dict, + network_alphas=network_alphas, + text_encoder=self.text_encoder_2, + prefix="text_encoder_2", + lora_scale=self.lora_scale, + ) + + @classmethod + def save_lora_weights( + self, + save_directory: Union[str, os.PathLike], + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, + is_main_process: bool = True, + weight_name: str = None, + save_function: Callable = None, + safe_serialization: bool = True, + ): + state_dict = {} + + def pack_weights(layers, prefix): + layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers + layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()} + return layers_state_dict + + state_dict.update(pack_weights(unet_lora_layers, "unet")) + + if text_encoder_lora_layers and text_encoder_2_lora_layers: + state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder")) + state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2")) + + self.write_lora_layers( + state_dict=state_dict, + save_directory=save_directory, + is_main_process=is_main_process, + weight_name=weight_name, + save_function=save_function, + safe_serialization=safe_serialization, + ) + + def _remove_text_encoder_monkey_patch(self): + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder) + self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2) \ No newline at end of file diff --git a/animatediff/utils/convert_from_ckpt.py b/animatediff/utils/convert_from_ckpt.py deleted file mode 100644 index 9ee269d..0000000 --- a/animatediff/utils/convert_from_ckpt.py +++ /dev/null @@ -1,959 +0,0 @@ -# coding=utf-8 -# Copyright 2023 The HuggingFace Inc. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" Conversion script for the Stable Diffusion checkpoints.""" - -import re -from io import BytesIO -from typing import Optional - -import requests -import torch -from transformers import ( - AutoFeatureExtractor, - BertTokenizerFast, - CLIPImageProcessor, - CLIPTextModel, - CLIPTextModelWithProjection, - CLIPTokenizer, - CLIPVisionConfig, - CLIPVisionModelWithProjection, -) - -from diffusers.models import ( - AutoencoderKL, - PriorTransformer, - UNet2DConditionModel, -) -from diffusers.schedulers import ( - DDIMScheduler, - DDPMScheduler, - DPMSolverMultistepScheduler, - EulerAncestralDiscreteScheduler, - EulerDiscreteScheduler, - HeunDiscreteScheduler, - LMSDiscreteScheduler, - PNDMScheduler, - UnCLIPScheduler, -) -from diffusers.utils.import_utils import BACKENDS_MAPPING - - -def shave_segments(path, n_shave_prefix_segments=1): - """ - Removes segments. Positive values shave the first segments, negative shave the last segments. 
- """ - if n_shave_prefix_segments >= 0: - return ".".join(path.split(".")[n_shave_prefix_segments:]) - else: - return ".".join(path.split(".")[:n_shave_prefix_segments]) - - -def renew_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item.replace("in_layers.0", "norm1") - new_item = new_item.replace("in_layers.2", "conv1") - - new_item = new_item.replace("out_layers.0", "norm2") - new_item = new_item.replace("out_layers.3", "conv2") - - new_item = new_item.replace("emb_layers.1", "time_emb_proj") - new_item = new_item.replace("skip_connection", "conv_shortcut") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside resnets to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("nin_shortcut", "conv_shortcut") - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - # new_item = new_item.replace('norm.weight', 'group_norm.weight') - # new_item = new_item.replace('norm.bias', 'group_norm.bias') - - # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') - # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') - - # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): - """ - Updates paths inside attentions to the new naming scheme (local renaming) - """ - mapping = [] - for old_item in old_list: - new_item = old_item - - new_item = new_item.replace("norm.weight", "group_norm.weight") - new_item = new_item.replace("norm.bias", "group_norm.bias") - - new_item = new_item.replace("q.weight", "query.weight") - new_item = new_item.replace("q.bias", "query.bias") - - new_item = new_item.replace("k.weight", "key.weight") - new_item = new_item.replace("k.bias", "key.bias") - - new_item = new_item.replace("v.weight", "value.weight") - new_item = new_item.replace("v.bias", "value.bias") - - new_item = new_item.replace("proj_out.weight", "proj_attn.weight") - new_item = new_item.replace("proj_out.bias", "proj_attn.bias") - - new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) - - mapping.append({"old": old_item, "new": new_item}) - - return mapping - - -def assign_to_checkpoint( - paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None -): - """ - This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits - attention layers, and takes into account additional replacements that may arise. - - Assigns the weights to the new checkpoint. - """ - assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." - - # Splits the attention layers into three variables. 
- if attention_paths_to_split is not None: - for path, path_map in attention_paths_to_split.items(): - old_tensor = old_checkpoint[path] - channels = old_tensor.shape[0] // 3 - - target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) - - num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 - - old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) - query, key, value = old_tensor.split(channels // num_heads, dim=1) - - checkpoint[path_map["query"]] = query.reshape(target_shape) - checkpoint[path_map["key"]] = key.reshape(target_shape) - checkpoint[path_map["value"]] = value.reshape(target_shape) - - for path in paths: - new_path = path["new"] - - # These have already been assigned - if attention_paths_to_split is not None and new_path in attention_paths_to_split: - continue - - # Global renaming happens here - new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") - new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") - new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") - - if additional_replacements is not None: - for replacement in additional_replacements: - new_path = new_path.replace(replacement["old"], replacement["new"]) - - # proj_attn.weight has to be converted from conv 1D to linear - if "proj_attn.weight" in new_path: - checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0] - else: - checkpoint[new_path] = old_checkpoint[path["old"]] - - -def conv_attn_to_linear(checkpoint): - keys = list(checkpoint.keys()) - attn_keys = ["query.weight", "key.weight", "value.weight"] - for key in keys: - if ".".join(key.split(".")[-2:]) in attn_keys: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0, 0] - elif "proj_attn.weight" in key: - if checkpoint[key].ndim > 2: - checkpoint[key] = checkpoint[key][:, :, 0] - - -def create_unet_diffusers_config(original_config, image_size: int, controlnet=False): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - if controlnet: - unet_params = original_config.model.params.control_stage_config.params - else: - unet_params = original_config.model.params.unet_config.params - - vae_params = original_config.model.params.first_stage_config.params.ddconfig - - block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult] - - down_block_types = [] - resolution = 1 - for i in range(len(block_out_channels)): - block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D" - down_block_types.append(block_type) - if i != len(block_out_channels) - 1: - resolution *= 2 - - up_block_types = [] - for i in range(len(block_out_channels)): - block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D" - up_block_types.append(block_type) - resolution //= 2 - - vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1) - - head_dim = unet_params.num_heads if "num_heads" in unet_params else None - use_linear_projection = ( - unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False - ) - if use_linear_projection: - # stable diffusion 2-base-512 and 2-768 - if head_dim is None: - head_dim = [5, 10, 20, 20] - - class_embed_type = None - projection_class_embeddings_input_dim = None - - if "num_classes" in unet_params: - if unet_params.num_classes == "sequential": - class_embed_type = "projection" - assert "adm_in_channels" in unet_params - projection_class_embeddings_input_dim = unet_params.adm_in_channels - else: - raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}") - - config = { - "sample_size": image_size // vae_scale_factor, - "in_channels": unet_params.in_channels, - "down_block_types": tuple(down_block_types), - "block_out_channels": tuple(block_out_channels), - "layers_per_block": unet_params.num_res_blocks, - "cross_attention_dim": unet_params.context_dim, - "attention_head_dim": head_dim, - "use_linear_projection": use_linear_projection, - "class_embed_type": class_embed_type, - "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim, - } - - if not controlnet: - config["out_channels"] = unet_params.out_channels - config["up_block_types"] = tuple(up_block_types) - - return config - - -def create_vae_diffusers_config(original_config, image_size: int): - """ - Creates a config for the diffusers based on the config of the LDM model. 
- """ - vae_params = original_config.model.params.first_stage_config.params.ddconfig - _ = original_config.model.params.first_stage_config.params.embed_dim - - block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult] - down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) - up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) - - config = { - "sample_size": image_size, - "in_channels": vae_params.in_channels, - "out_channels": vae_params.out_ch, - "down_block_types": tuple(down_block_types), - "up_block_types": tuple(up_block_types), - "block_out_channels": tuple(block_out_channels), - "latent_channels": vae_params.z_channels, - "layers_per_block": vae_params.num_res_blocks, - } - return config - - -def create_diffusers_schedular(original_config): - schedular = DDIMScheduler( - num_train_timesteps=original_config.model.params.timesteps, - beta_start=original_config.model.params.linear_start, - beta_end=original_config.model.params.linear_end, - beta_schedule="scaled_linear", - ) - return schedular - - -def create_ldm_bert_config(original_config): - bert_params = original_config.model.parms.cond_stage_config.params - config = LDMBertConfig( - d_model=bert_params.n_embed, - encoder_layers=bert_params.n_layer, - encoder_ffn_dim=bert_params.n_embed * 4, - ) - return config - - -def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False): - """ - Takes a state dict and a config, and returns a converted checkpoint. - """ - - # extract state_dict for UNet - unet_state_dict = {} - keys = list(checkpoint.keys()) - - if controlnet: - unet_key = "control_model." - else: - unet_key = "model.diffusion_model." - - # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA - if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema: - print(f"Checkpoint {path} has both EMA and non-EMA weights.") - print( - "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA" - " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag." - ) - for key in keys: - if key.startswith("model.diffusion_model"): - flat_ema_key = "model_ema." + "".join(key.split(".")[1:]) - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key) - else: - if sum(k.startswith("model_ema") for k in keys) > 100: - print( - "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA" - " weights (usually better for inference), please make sure to add the `--extract_ema` flag." - ) - - for key in keys: - if key.startswith(unet_key): - unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) - - new_checkpoint = {} - - new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] - new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] - new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] - new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] - - if config["class_embed_type"] is None: - # No parameters to port - ... 
- elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection": - new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"] - new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"] - new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"] - new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"] - else: - raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}") - - new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] - new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] - - if not controlnet: - new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] - new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] - new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] - new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] - - # Retrieves the keys for the input blocks only - num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) - input_blocks = { - layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key] - for layer_id in range(num_input_blocks) - } - - # Retrieves the keys for the middle blocks only - num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) - middle_blocks = { - layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key] - for layer_id in range(num_middle_blocks) - } - - # Retrieves the keys for the output blocks only - num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) - output_blocks = { - layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key] - for layer_id in range(num_output_blocks) - } - - for i in range(1, num_input_blocks): - block_id = (i - 1) // (config["layers_per_block"] + 1) - layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1) - - resnets = [ - key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key - ] - attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key] - - if f"input_blocks.{i}.0.op.weight" in unet_state_dict: - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.weight" - ) - new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop( - f"input_blocks.{i}.0.op.bias" - ) - - paths = renew_resnet_paths(resnets) - meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - resnet_0 = middle_blocks[0] - attentions = middle_blocks[1] - resnet_1 = middle_blocks[2] - - resnet_0_paths = renew_resnet_paths(resnet_0) - assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config) - - resnet_1_paths = renew_resnet_paths(resnet_1) - assign_to_checkpoint(resnet_1_paths, 
new_checkpoint, unet_state_dict, config=config) - - attentions_paths = renew_attention_paths(attentions) - meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"} - assign_to_checkpoint( - attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - for i in range(num_output_blocks): - block_id = i // (config["layers_per_block"] + 1) - layer_in_block_id = i % (config["layers_per_block"] + 1) - output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]] - output_block_list = {} - - for layer in output_block_layers: - layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1) - if layer_id in output_block_list: - output_block_list[layer_id].append(layer_name) - else: - output_block_list[layer_id] = [layer_name] - - if len(output_block_list) > 1: - resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key] - attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key] - - resnet_0_paths = renew_resnet_paths(resnets) - paths = renew_resnet_paths(resnets) - - meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"} - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - - output_block_list = {k: sorted(v) for k, v in output_block_list.items()} - if ["conv.bias", "conv.weight"] in output_block_list.values(): - index = list(output_block_list.values()).index(["conv.bias", "conv.weight"]) - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.weight" - ] - new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[ - f"output_blocks.{i}.{index}.conv.bias" - ] - - # Clear attentions as they have been attributed above. 
- if len(attentions) == 2: - attentions = [] - - if len(attentions): - paths = renew_attention_paths(attentions) - meta_path = { - "old": f"output_blocks.{i}.1", - "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", - } - assign_to_checkpoint( - paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config - ) - else: - resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) - for path in resnet_0_paths: - old_path = ".".join(["output_blocks", str(i), path["old"]]) - new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) - - new_checkpoint[new_path] = unet_state_dict[old_path] - - if controlnet: - # conditioning embedding - - orig_index = 0 - - new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - orig_index += 2 - - diffusers_index = 0 - - while diffusers_index < 6: - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - diffusers_index += 1 - orig_index += 2 - - new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.weight" - ) - new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop( - f"input_hint_block.{orig_index}.bias" - ) - - # down blocks - for i in range(num_input_blocks): - new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight") - new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias") - - # mid block - new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight") - new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias") - - return new_checkpoint - - -def convert_ldm_vae_checkpoint(checkpoint, config): - # extract state dict for VAE - vae_state_dict = {} - vae_key = "first_stage_model." 
- keys = list(checkpoint.keys()) - for key in keys: - if key.startswith(vae_key): - vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) - - new_checkpoint = {} - - new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] - new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] - new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] - new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] - new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] - new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] - - new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] - new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] - new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] - new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] - new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] - new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] - - new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] - new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] - new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] - new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] - - # Retrieves the keys for the encoder down blocks only - num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) - down_blocks = { - layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) - } - - # Retrieves the keys for the decoder up blocks only - num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) - up_blocks = { - layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) - } - - for i in range(num_down_blocks): - resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] - - if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.weight" - ) - new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( - f"encoder.down.{i}.downsample.conv.bias" - ) - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, 
new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - - for i in range(num_up_blocks): - block_id = num_up_blocks - 1 - i - resnets = [ - key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key - ] - - if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.weight" - ] - new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ - f"decoder.up.{block_id}.upsample.conv.bias" - ] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] - num_mid_res_blocks = 2 - for i in range(1, num_mid_res_blocks + 1): - resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] - - paths = renew_vae_resnet_paths(resnets) - meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - - mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] - paths = renew_vae_attention_paths(mid_attentions) - meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} - assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) - conv_attn_to_linear(new_checkpoint) - return new_checkpoint - - -def convert_ldm_bert_checkpoint(checkpoint, config): - def _copy_attn_layer(hf_attn_layer, pt_attn_layer): - hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight - hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight - hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight - - hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight - hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias - - def _copy_linear(hf_linear, pt_linear): - hf_linear.weight = pt_linear.weight - hf_linear.bias = pt_linear.bias - - def _copy_layer(hf_layer, pt_layer): - # copy layer norms - _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0]) - _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0]) - - # copy attn - _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1]) - - # copy MLP - pt_mlp = pt_layer[1][1] - _copy_linear(hf_layer.fc1, pt_mlp.net[0][0]) - _copy_linear(hf_layer.fc2, pt_mlp.net[2]) - - def _copy_layers(hf_layers, pt_layers): - for i, hf_layer in enumerate(hf_layers): - if i != 0: - i += i - pt_layer = pt_layers[i : i + 2] - _copy_layer(hf_layer, pt_layer) - - hf_model = LDMBertModel(config).eval() - - # copy embeds - hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight - hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight - - # copy layer norm - _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm) - - # copy hidden layers - _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers) - - _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits) - - return hf_model - - -def convert_ldm_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14") - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key 
in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) - - return text_model - - -textenc_conversion_lst = [ - ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"), - ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"), - ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"), - ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"), -] -textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst} - -textenc_transformer_conversion_lst = [ - # (stable-diffusion, HF Diffusers) - ("resblocks.", "text_model.encoder.layers."), - ("ln_1", "layer_norm1"), - ("ln_2", "layer_norm2"), - (".c_fc.", ".fc1."), - (".c_proj.", ".fc2."), - (".attn", ".self_attn"), - ("ln_final.", "transformer.text_model.final_layer_norm."), - ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"), - ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"), -] -protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst} -textenc_pattern = re.compile("|".join(protected.keys())) - - -def convert_paint_by_example_checkpoint(checkpoint): - config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14") - model = PaintByExampleImageEncoder(config) - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - for key in keys: - if key.startswith("cond_stage_model.transformer"): - text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] - - # load clip vision - model.model.load_state_dict(text_model_dict) - - # load mapper - keys_mapper = { - k[len("cond_stage_model.mapper.res") :]: v - for k, v in checkpoint.items() - if k.startswith("cond_stage_model.mapper") - } - - MAPPING = { - "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"], - "attn.c_proj": ["attn1.to_out.0"], - "ln_1": ["norm1"], - "ln_2": ["norm3"], - "mlp.c_fc": ["ff.net.0.proj"], - "mlp.c_proj": ["ff.net.2"], - } - - mapped_weights = {} - for key, value in keys_mapper.items(): - prefix = key[: len("blocks.i")] - suffix = key.split(prefix)[-1].split(".")[-1] - name = key.split(prefix)[-1].split(suffix)[0][1:-1] - mapped_names = MAPPING[name] - - num_splits = len(mapped_names) - for i, mapped_name in enumerate(mapped_names): - new_name = ".".join([prefix, mapped_name, suffix]) - shape = value.shape[0] // num_splits - mapped_weights[new_name] = value[i * shape : (i + 1) * shape] - - model.mapper.load_state_dict(mapped_weights) - - # load final layer norm - model.final_layer_norm.load_state_dict( - { - "bias": checkpoint["cond_stage_model.final_ln.bias"], - "weight": checkpoint["cond_stage_model.final_ln.weight"], - } - ) - - # load final proj - model.proj_out.load_state_dict( - { - "bias": checkpoint["proj_out.bias"], - "weight": checkpoint["proj_out.weight"], - } - ) - - # load uncond vector - model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"]) - return model - - -def convert_open_clip_checkpoint(checkpoint): - text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder") - - keys = list(checkpoint.keys()) - - text_model_dict = {} - - if "cond_stage_model.model.text_projection" in checkpoint: - d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0]) - else: - d_model = 1024 - - 
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids") - - for key in keys: - if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer - continue - if key in textenc_conversion_map: - text_model_dict[textenc_conversion_map[key]] = checkpoint[key] - if key.startswith("cond_stage_model.model.transformer."): - new_key = key[len("cond_stage_model.model.transformer.") :] - if new_key.endswith(".in_proj_weight"): - new_key = new_key[: -len(".in_proj_weight")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :] - text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :] - text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :] - elif new_key.endswith(".in_proj_bias"): - new_key = new_key[: -len(".in_proj_bias")] - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model] - text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2] - text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :] - else: - new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key) - - text_model_dict[new_key] = checkpoint[key] - - text_model.load_state_dict(text_model_dict) - - return text_model - - -def stable_unclip_image_encoder(original_config): - """ - Returns the image processor and clip image encoder for the img2img unclip pipeline. - - We currently know of two types of stable unclip models which separately use the clip and the openclip image - encoders. - """ - - image_embedder_config = original_config.model.params.embedder_config - - sd_clip_image_embedder_class = image_embedder_config.target - sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1] - - if sd_clip_image_embedder_class == "ClipImageEmbedder": - clip_model_name = image_embedder_config.params.model - - if clip_model_name == "ViT-L/14": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14") - else: - raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}") - - elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder": - feature_extractor = CLIPImageProcessor() - image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K") - else: - raise NotImplementedError( - f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}" - ) - - return feature_extractor, image_encoder - - -def stable_unclip_image_noising_components( - original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None -): - """ - Returns the noising components for the img2img and txt2img unclip pipelines. - - Converts the stability noise augmentor into - 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats - 2. a `DDPMScheduler` for holding the noise schedule - - If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided. 
- """ - noise_aug_config = original_config.model.params.noise_aug_config - noise_aug_class = noise_aug_config.target - noise_aug_class = noise_aug_class.split(".")[-1] - - if noise_aug_class == "CLIPEmbeddingNoiseAugmentation": - noise_aug_config = noise_aug_config.params - embedding_dim = noise_aug_config.timestep_dim - max_noise_level = noise_aug_config.noise_schedule_config.timesteps - beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule - - image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim) - image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule) - - if "clip_stats_path" in noise_aug_config: - if clip_stats_path is None: - raise ValueError("This stable unclip config requires a `clip_stats_path`") - - clip_mean, clip_std = torch.load(clip_stats_path, map_location=device) - clip_mean = clip_mean[None, :] - clip_std = clip_std[None, :] - - clip_stats_state_dict = { - "mean": clip_mean, - "std": clip_std, - } - - image_normalizer.load_state_dict(clip_stats_state_dict) - else: - raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}") - - return image_normalizer, image_noising_scheduler - - -def convert_controlnet_checkpoint( - checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema -): - ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True) - ctrlnet_config["upcast_attention"] = upcast_attention - - ctrlnet_config.pop("sample_size") - - controlnet_model = ControlNetModel(**ctrlnet_config) - - converted_ctrl_checkpoint = convert_ldm_unet_checkpoint( - checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True - ) - - controlnet_model.load_state_dict(converted_ctrl_checkpoint) - - return controlnet_model diff --git a/animatediff/utils/convert_lora_safetensor_to_diffusers.py b/animatediff/utils/convert_lora_safetensor_to_diffusers.py index 7490e38..543f03b 100644 --- a/animatediff/utils/convert_lora_safetensor_to_diffusers.py +++ b/animatediff/utils/convert_lora_safetensor_to_diffusers.py @@ -5,7 +5,7 @@ # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, @@ -23,132 +23,112 @@ from safetensors.torch import load_file from diffusers import StableDiffusionPipeline import pdb - - -def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0): - # directly update weight in diffusers model - for key in state_dict: - # only process lora down key - if "up." 
in key: continue
-
-        up_key = key.replace(".down.", ".up.")
-        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
-        model_key = model_key.replace("to_out.", "to_out.0.")
-        layer_infos = model_key.split(".")[:-1]
-
-        curr_layer = pipeline.unet
-        while len(layer_infos) > 0:
-            temp_name = layer_infos.pop(0)
-            curr_layer = curr_layer.__getattr__(temp_name)
-
-        weight_down = state_dict[key]
-        weight_up = state_dict[up_key]
-        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
-
-    return pipeline
-
-
 def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
-    # load base model
-    # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+  # load base model
+  # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
-    # load LoRA weight from .safetensors
-    # state_dict = load_file(checkpoint_path)
+  # load LoRA weight from .safetensors
+  # state_dict = load_file(checkpoint_path)
-    visited = []
+  visited = []
+  # directly update weight in diffusers model
+  for key in state_dict:
+    # it is suggested to print out the key, it usually will be something like below
+    # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
-    # directly update weight in diffusers model
-    for key in state_dict:
-        # it is suggested to print out the key, it usually will be something like below
-        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+    # as we have set the alpha beforehand, so just skip
+    if ".alpha" in key or key in visited:
+      continue
-        # as we have set the alpha beforehand, so just skip
-        if ".alpha" in key or key in visited:
-            continue
+    if "te" in key:
+      if "lora_te1" in key:
+        LORA_PREFIX_TEXT_ENCODER = "lora_te1"
+      elif "lora_te2" in key:
+        LORA_PREFIX_TEXT_ENCODER = "lora_te2"
+      else:
+        LORA_PREFIX_TEXT_ENCODER = "lora_te"
+      layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+      # lora_te2 weights belong to the second (OpenCLIP) text encoder of SDXL pipelines
+      curr_layer = pipeline.text_encoder_2 if LORA_PREFIX_TEXT_ENCODER == "lora_te2" else pipeline.text_encoder
-    if "text" in key:
-        layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
-        curr_layer = pipeline.text_encoder
-    else:
-        layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
-        curr_layer = pipeline.unet
+    else:
+      layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
+      curr_layer = pipeline.unet
-    # find the target layer
-    temp_name = layer_infos.pop(0)
-    while len(layer_infos) > -1:
-        try:
-            curr_layer = curr_layer.__getattr__(temp_name)
-            if len(layer_infos) > 0:
-                temp_name = layer_infos.pop(0)
-            elif len(layer_infos) == 0:
-                break
-        except Exception:
-            if len(temp_name) > 0:
-                temp_name += "_" + layer_infos.pop(0)
-            else:
-                temp_name = layer_infos.pop(0)
+    # find the target layer
+    temp_name = layer_infos.pop(0)
+    while len(layer_infos) > -1:
+      try:
+        curr_layer = curr_layer.__getattr__(temp_name)
+        if len(layer_infos) > 0:
+          temp_name = layer_infos.pop(0)
+        elif len(layer_infos) == 0:
+          break
+      except Exception:
+        if len(temp_name) > 0:
+          temp_name += "_" + layer_infos.pop(0)
+        else:
+          temp_name = layer_infos.pop(0)
-    pair_keys = []
-    if "lora_down" in key:
-        pair_keys.append(key.replace("lora_down", "lora_up"))
-        pair_keys.append(key)
-    else:
-        pair_keys.append(key)
-        pair_keys.append(key.replace("lora_up", "lora_down"))
+    pair_keys = []
+    if "lora.down" in key:
+      pair_keys.append(key.replace("lora.down", "lora.up"))
+
pair_keys.append(key) + else: + pair_keys.append(key) + pair_keys.append(key.replace("lora.up", "lora.down")) - # update weight - if len(state_dict[pair_keys[0]].shape) == 4: - weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) - weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) - else: - weight_up = state_dict[pair_keys[0]].to(torch.float32) - weight_down = state_dict[pair_keys[1]].to(torch.float32) - curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) + # update weight + if len(state_dict[pair_keys[0]].shape) == 4: + weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32) + weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device) + else: + weight_up = state_dict[pair_keys[0]].to(torch.float32) + weight_down = state_dict[pair_keys[1]].to(torch.float32) + curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device) - # update visited list - for item in pair_keys: - visited.append(item) + # update visited list + for item in pair_keys: + visited.append(item) - return pipeline + return pipeline if __name__ == "__main__": - parser = argparse.ArgumentParser() + parser = argparse.ArgumentParser() - parser.add_argument( - "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." - ) - parser.add_argument( - "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." - ) - parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") - parser.add_argument( - "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" - ) - parser.add_argument( - "--lora_prefix_text_encoder", - default="lora_te", - type=str, - help="The prefix of text encoder weight in safetensors", - ) - parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") - parser.add_argument( - "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." - ) - parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)") + parser.add_argument( + "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format." + ) + parser.add_argument( + "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert." + ) + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.") + parser.add_argument( + "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors" + ) + parser.add_argument( + "--lora_prefix_text_encoder", + default="lora_te", + type=str, + help="The prefix of text encoder weight in safetensors", + ) + parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW") + parser.add_argument( + "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not." + ) + parser.add_argument("--device", type=str, help="Device to use (e.g. 
cpu, cuda:0, cuda:1, etc.)")
-    args = parser.parse_args()
+  args = parser.parse_args()
-    base_model_path = args.base_model_path
-    checkpoint_path = args.checkpoint_path
-    dump_path = args.dump_path
-    lora_prefix_unet = args.lora_prefix_unet
-    lora_prefix_text_encoder = args.lora_prefix_text_encoder
-    alpha = args.alpha
+  base_model_path = args.base_model_path
+  checkpoint_path = args.checkpoint_path
+  dump_path = args.dump_path
+  lora_prefix_unet = args.lora_prefix_unet
+  lora_prefix_text_encoder = args.lora_prefix_text_encoder
+  alpha = args.alpha
-    pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
+  # load the base pipeline and the LoRA weights, then merge them in place
+  pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+  lora_state_dict = load_file(checkpoint_path)
+  pipe = convert_lora(pipe, lora_state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
-    pipe = pipe.to(args.device)
-    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
+  pipe = pipe.to(args.device)
+  pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
diff --git a/animatediff/utils/util.py b/animatediff/utils/util.py
index 5393385..1fced56 100644
--- a/animatediff/utils/util.py
+++ b/animatediff/utils/util.py
@@ -3,19 +3,23 @@ import imageio
 import numpy as np
 from typing import Union
+import random

 import torch
-import torchvision
 import torch.distributed as dist
+import torchvision

-from safetensors import safe_open
+import diffusers
 from tqdm import tqdm
 from einops import rearrange

-from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
-from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora, convert_motion_lora_ckpt_to_diffusers
+from safetensors import safe_open
+from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline
+
+from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora


 def zero_rank_print(s):
-    if (not dist.is_initialized()) and (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)
+    if not isinstance(s, str): s = repr(s)
+    if (not dist.is_initialized()) or (dist.is_initialized() and dist.get_rank() == 0): print("### " + s)


 def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
@@ -91,67 +95,1714 @@ def ddim_inversion(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt
     ddim_latents = ddim_loop(pipeline, ddim_scheduler, video_latent, num_inv_steps, prompt)
     return ddim_latents

-def load_weights(
-    animation_pipeline,
-    # motion module
-    motion_module_path = "",
-    motion_module_lora_configs = [],
-    # image layers
-    dreambooth_model_path = "",
-    lora_model_path = "",
-    lora_alpha = 0.8,
-):
-    # 1.1 motion module
-    unet_state_dict = {}
-    if motion_module_path != "":
-        print(f"load motion module from {motion_module_path}")
-        motion_module_state_dict = torch.load(motion_module_path, map_location="cpu")
-        motion_module_state_dict = motion_module_state_dict["state_dict"] if "state_dict" in motion_module_state_dict else motion_module_state_dict
-        unet_state_dict.update({name: param for name, param in motion_module_state_dict.items() if "motion_modules."
in name}) - - missing, unexpected = animation_pipeline.unet.load_state_dict(unet_state_dict, strict=False) - assert len(unexpected) == 0 - del unet_state_dict - if dreambooth_model_path != "": - print(f"load dreambooth model from {dreambooth_model_path}") - if dreambooth_model_path.endswith(".safetensors"): - dreambooth_state_dict = {} - with safe_open(dreambooth_model_path, framework="pt", device="cpu") as f: - for key in f.keys(): - dreambooth_state_dict[key] = f.get_tensor(key) - elif dreambooth_model_path.endswith(".ckpt"): - dreambooth_state_dict = torch.load(dreambooth_model_path, map_location="cpu") - - # 1. vae - converted_vae_checkpoint = convert_ldm_vae_checkpoint(dreambooth_state_dict, animation_pipeline.vae.config) - animation_pipeline.vae.load_state_dict(converted_vae_checkpoint) - # 2. unet - converted_unet_checkpoint = convert_ldm_unet_checkpoint(dreambooth_state_dict, animation_pipeline.unet.config) - animation_pipeline.unet.load_state_dict(converted_unet_checkpoint, strict=False) - # 3. text_model - animation_pipeline.text_encoder = convert_ldm_clip_checkpoint(dreambooth_state_dict) - del dreambooth_state_dict +def load_weights(pipeline, motion_module_path, ckpt_path, lora_path, lora_alpha): + # Load ckpt + + if ckpt_path != "": + ( + text_model1, + text_model2, + vae, + unet, + logit_scale, + ckpt_info + ) = load_models_from_sdxl_checkpoint(MODEL_VERSION_SDXL_BASE_V1_0, ckpt_path, 'cpu') + - if lora_model_path != "": - print(f"load lora model from {lora_model_path}") - assert lora_model_path.endswith(".safetensors") + unet_state_dict = unet.state_dict() + pipeline.unet.load_state_dict(unet_state_dict, strict=False) + pipeline.vae = vae + pipeline.text_encoder = text_model1 + pipeline.text_encoder_2 = text_model2 + del unet + del unet_state_dict + del vae + del text_model1 + del text_model2 + print(f'Loading ckpt model from {ckpt_path}') + + # Load Motion Module + if motion_module_path != "": + motion_module_ckpt = torch.load(motion_module_path, map_location='cpu') + + motion_module_state_dict = {} + m_k = None + for k, v in motion_module_ckpt.items(): + if 'motion_module' in k and k in pipeline.unet.state_dict().keys(): + motion_module_state_dict[k] = v + m_k = k + elif 'motion_module' in k and k not in pipeline.unet.state_dict().keys(): + print(k) + + pipeline.unet.load_state_dict(motion_module_state_dict, strict=False) + del motion_module_ckpt + del motion_module_state_dict + print(f'Loading motion module from {motion_module_path}...') + + # Load LoRA + if lora_path != "": lora_state_dict = {} - with safe_open(lora_model_path, framework="pt", device="cpu") as f: - for key in f.keys(): - lora_state_dict[key] = f.get_tensor(key) - - animation_pipeline = convert_lora(animation_pipeline, lora_state_dict, alpha=lora_alpha) - del lora_state_dict + with safe_open(lora_path, framework='pt', device='cpu') as f: + for k in f.keys(): + lora_state_dict[k] = f.get_tensor(k) + for k, v in lora_state_dict.items(): + if 'lora.up' in k: + + down_key = k.replace('lora.up', 'lora.down') + if 'to_out' not in k: + original_key = k.replace('processor.', '').replace('_lora.up', '') + else: + original_key = k.replace('processor.', '').replace('_lora.up', '.0') + pipeline.unet.state_dict()[original_key] += lora_alpha * torch.mm(v, lora_state_dict[down_key]) + print(f'Loading lora model from {lora_path}') + + return pipeline - for motion_module_lora_config in motion_module_lora_configs: - path, alpha = motion_module_lora_config["path"], motion_module_lora_config["alpha"] - print(f"load motion 
LoRA from {path}") - motion_lora_state_dict = torch.load(path, map_location="cpu") - motion_lora_state_dict = motion_lora_state_dict["state_dict"] if "state_dict" in motion_lora_state_dict else motion_lora_state_dict +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts=0, caption_column='text', is_train=True): + prompt_embeds_list = [] + prompt_batch = batch[caption_column] - animation_pipeline = convert_motion_lora_ckpt_to_diffusers(animation_pipeline, motion_lora_state_dict, alpha) + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return {"prompt_embeds": prompt_embeds.cpu(), "pooled_prompt_embeds": pooled_prompt_embeds.cpu()} + + +MODEL_VERSION_SDXL_BASE_V1_0 = "sdxl_base_v1-0" +from safetensors.torch import load_file +from accelerate.utils.modeling import set_module_tensor_to_device +from animatediff.utils.xl_lora_util import SdxlUNet2DConditionModel +from transformers import CLIPTextModel, CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer +from packaging import version +from accelerate import init_empty_weights +import transformers +def is_safetensors(path): + return os.path.splitext(path)[1].lower() == ".safetensors" + +def load_models_from_sdxl_checkpoint(model_version, ckpt_path, map_location, dtype=None, unet_only=False): + # model_version is reserved for future use + # dtype is reserved for full_fp16/bf16 integration. 
Text Encoder will remain fp32, because it runs on CPU when caching
+
+    # Load the state dict
+    if is_safetensors(ckpt_path):
+        checkpoint = None
+        try:
+            state_dict = load_file(ckpt_path, device=map_location)
+        except Exception:
+            state_dict = load_file(ckpt_path)  # retry without a device to avoid an invalid-device error
+        epoch = None
+        global_step = None
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=map_location)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+            epoch = checkpoint.get("epoch", 0)
+            global_step = checkpoint.get("global_step", 0)
+        else:
+            state_dict = checkpoint
+            epoch = 0
+            global_step = 0
+        checkpoint = None
+
+    # U-Net
+    print("building U-Net")
+    with init_empty_weights():
+        unet = SdxlUNet2DConditionModel()
+
+    print("loading U-Net from checkpoint")
+    unet_sd = {}
+    for k in list(state_dict.keys()):
+        if k.startswith("model.diffusion_model."):
+            unet_sd[k.replace("model.diffusion_model.", "")] = state_dict.pop(k)
+    info = _load_state_dict_on_device(unet, unet_sd, device=map_location)
+    print("U-Net: ", info)
+    unet_sd = unet.state_dict()
+    du_unet_sd = convert_sdxl_unet_state_dict_to_diffusers(unet_sd)
+    from animatediff.models.unet import UNet3DConditionModel
+    diffusers_unet = UNet3DConditionModel(**DIFFUSERS_SDXL_UNET_CONFIG)
+    diffusers_unet.load_state_dict(du_unet_sd)
+
+    if unet_only:
+        return None, None, None, diffusers_unet, None, None
+
+    # Text Encoders
+    print("building text encoders")
+
+    # Text Encoder 1 is the same as Stability AI's SDXL
+    text_model1_cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=768,
+        intermediate_size=3072,
+        num_hidden_layers=12,
+        num_attention_heads=12,
+        max_position_embeddings=77,
+        hidden_act="quick_gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=768,
+        # torch_dtype="float32",
+        # transformers_version="4.25.0.dev0",
+    )
+    text_model1 = CLIPTextModel._from_config(text_model1_cfg)
+
+    # Text Encoder 2 is different from Stability AI's SDXL. SDXL uses open clip, but we use the model from HuggingFace.
+    # Note: the tokenizer from HuggingFace is different from SDXL's. We must use open clip's tokenizer.
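+    # (A hedged aside, not used by this loader: the matching tokenizers would be loaded
+    #  separately, e.g. CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14") for
+    #  encoder 1, and the open-clip-compatible tokenizer bundled with the SDXL base model,
+    #  CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="tokenizer_2"),
+    #  for encoder 2.)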
+    text_model2_cfg = CLIPTextConfig(
+        vocab_size=49408,
+        hidden_size=1280,
+        intermediate_size=5120,
+        num_hidden_layers=32,
+        num_attention_heads=20,
+        max_position_embeddings=77,
+        hidden_act="gelu",
+        layer_norm_eps=1e-05,
+        dropout=0.0,
+        attention_dropout=0.0,
+        initializer_range=0.02,
+        initializer_factor=1.0,
+        pad_token_id=1,
+        bos_token_id=0,
+        eos_token_id=2,
+        model_type="clip_text_model",
+        projection_dim=1280,
+        # torch_dtype="float32",
+        # transformers_version="4.25.0.dev0",
+    )
+    text_model2 = CLIPTextModelWithProjection(text_model2_cfg)
+
+    print("loading text encoders from checkpoint")
+    te1_sd = {}
+    te2_sd = {}
+    for k in list(state_dict.keys()):
+        if k.startswith("conditioner.embedders.0.transformer."):
+            te1_sd[k.replace("conditioner.embedders.0.transformer.", "")] = state_dict.pop(k)
+        elif k.startswith("conditioner.embedders.1.model."):
+            te2_sd[k] = state_dict.pop(k)
+
+    if version.parse(transformers.__version__) > version.parse("4.31"):
+        # After transformers==4.31, position_ids becomes a persistent=False buffer (so we mustn't supply it)
+        # https://github.com/huggingface/transformers/pull/24505
+        # https://github.com/mlfoundations/open_clip/pull/595
+        if 'text_model.embeddings.position_ids' in te1_sd:
+            del te1_sd['text_model.embeddings.position_ids']
+    info1 = text_model1.load_state_dict(te1_sd)
+    print("text encoder 1:", info1)
+
+    converted_sd, logit_scale = convert_sdxl_text_encoder_2_checkpoint(te2_sd, max_length=77)
+    info2 = text_model2.load_state_dict(converted_sd)
+    print("text encoder 2:", info2)
+
+    # prepare vae
+    print("building VAE")
+    vae_config = create_vae_diffusers_config()
+    vae = AutoencoderKL(**vae_config)  # .to(device)
+
+    print("loading VAE from checkpoint")
+    converted_vae_checkpoint = convert_ldm_vae_checkpoint(state_dict, vae_config)
+    info = vae.load_state_dict(converted_vae_checkpoint)
+    print("VAE:", info)
+
+    ckpt_info = (epoch, global_step) if epoch is not None else None
+    return text_model1, text_model2, vae, diffusers_unet, logit_scale, ckpt_info
+
+
+# load state_dict without allocating new tensors
+
+
+def shave_segments(path, n_shave_prefix_segments=1):
+    """
+    Removes segments. Positive values shave the first segments, negative shave the last segments.
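+    e.g. shave_segments("a.b.c.d", 1) -> "b.c.d"; shave_segments("a.b.c.d", -1) -> "a.b.c"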
+ """ + if n_shave_prefix_segments >= 0: + return ".".join(path.split(".")[n_shave_prefix_segments:]) + else: + return ".".join(path.split(".")[:n_shave_prefix_segments]) + + +def renew_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item.replace("in_layers.0", "norm1") + new_item = new_item.replace("in_layers.2", "conv1") + + new_item = new_item.replace("out_layers.0", "norm2") + new_item = new_item.replace("out_layers.3", "conv2") + + new_item = new_item.replace("emb_layers.1", "time_emb_proj") + new_item = new_item.replace("skip_connection", "conv_shortcut") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside resnets to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("nin_shortcut", "conv_shortcut") + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + # new_item = new_item.replace('norm.weight', 'group_norm.weight') + # new_item = new_item.replace('norm.bias', 'group_norm.bias') + + # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight') + # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias') + + # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0): + """ + Updates paths inside attentions to the new naming scheme (local renaming) + """ + mapping = [] + for old_item in old_list: + new_item = old_item + + new_item = new_item.replace("norm.weight", "group_norm.weight") + new_item = new_item.replace("norm.bias", "group_norm.bias") + + if diffusers.__version__ < "0.17.0": + new_item = new_item.replace("q.weight", "query.weight") + new_item = new_item.replace("q.bias", "query.bias") + + new_item = new_item.replace("k.weight", "key.weight") + new_item = new_item.replace("k.bias", "key.bias") + + new_item = new_item.replace("v.weight", "value.weight") + new_item = new_item.replace("v.bias", "value.bias") + + new_item = new_item.replace("proj_out.weight", "proj_attn.weight") + new_item = new_item.replace("proj_out.bias", "proj_attn.bias") + else: + new_item = new_item.replace("q.weight", "to_q.weight") + new_item = new_item.replace("q.bias", "to_q.bias") + + new_item = new_item.replace("k.weight", "to_k.weight") + new_item = new_item.replace("k.bias", "to_k.bias") + + new_item = new_item.replace("v.weight", "to_v.weight") + new_item = new_item.replace("v.bias", "to_v.bias") + + new_item = new_item.replace("proj_out.weight", "to_out.0.weight") + new_item = new_item.replace("proj_out.bias", "to_out.0.bias") + + new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments) + + mapping.append({"old": old_item, "new": new_item}) + + return mapping + + +def assign_to_checkpoint( + paths, checkpoint, 
old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None +): + """ + This does the final conversion step: take locally converted weights and apply a global renaming + to them. It splits attention layers, and takes into account additional replacements + that may arise. + + Assigns the weights to the new checkpoint. + """ + assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys." + + # Splits the attention layers into three variables. + if attention_paths_to_split is not None: + for path, path_map in attention_paths_to_split.items(): + old_tensor = old_checkpoint[path] + channels = old_tensor.shape[0] // 3 + + target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1) + + num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3 + + old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]) + query, key, value = old_tensor.split(channels // num_heads, dim=1) + + checkpoint[path_map["query"]] = query.reshape(target_shape) + checkpoint[path_map["key"]] = key.reshape(target_shape) + checkpoint[path_map["value"]] = value.reshape(target_shape) + + for path in paths: + new_path = path["new"] + + # These have already been assigned + if attention_paths_to_split is not None and new_path in attention_paths_to_split: + continue + + # Global renaming happens here + new_path = new_path.replace("middle_block.0", "mid_block.resnets.0") + new_path = new_path.replace("middle_block.1", "mid_block.attentions.0") + new_path = new_path.replace("middle_block.2", "mid_block.resnets.1") + + if additional_replacements is not None: + for replacement in additional_replacements: + new_path = new_path.replace(replacement["old"], replacement["new"]) + + # proj_attn.weight has to be converted from conv 1D to linear + reshaping = False + if diffusers.__version__ < "0.17.0": + if "proj_attn.weight" in new_path: + reshaping = True + else: + if ".attentions." in new_path and ".0.to_" in new_path and old_checkpoint[path["old"]].ndim > 2: + reshaping = True + + if reshaping: + checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0] + else: + checkpoint[new_path] = old_checkpoint[path["old"]] + + +def conv_attn_to_linear(checkpoint): + keys = list(checkpoint.keys()) + attn_keys = ["query.weight", "key.weight", "value.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in attn_keys: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0, 0] + elif "proj_attn.weight" in key: + if checkpoint[key].ndim > 2: + checkpoint[key] = checkpoint[key][:, :, 0] + + +def linear_transformer_to_conv(checkpoint): + keys = list(checkpoint.keys()) + tf_keys = ["proj_in.weight", "proj_out.weight"] + for key in keys: + if ".".join(key.split(".")[-2:]) in tf_keys: + if checkpoint[key].ndim == 2: + checkpoint[key] = checkpoint[key].unsqueeze(2).unsqueeze(2) + + +def convert_ldm_unet_checkpoint(v2, checkpoint, config): + """ + Takes a state dict and a config, and returns a converted checkpoint. + """ + + # extract state_dict for UNet + unet_state_dict = {} + unet_key = "model.diffusion_model." 
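+    # e.g. "model.diffusion_model.input_blocks.0.0.weight" -> "input_blocks.0.0.weight"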
+ keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(unet_key): + unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key) + + new_checkpoint = {} + + new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"] + new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"] + new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"] + new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"] + + new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"] + new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"] + + new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"] + new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"] + new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"] + new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"] + + # Retrieves the keys for the input blocks only + num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer}) + input_blocks = { + layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}." in key] for layer_id in range(num_input_blocks) + } + + # Retrieves the keys for the middle blocks only + num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer}) + middle_blocks = { + layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}." in key] for layer_id in range(num_middle_blocks) + } + + # Retrieves the keys for the output blocks only + num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer}) + output_blocks = { + layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}." 
in key] for layer_id in range(num_output_blocks)
+    }
+
+    for i in range(1, num_input_blocks):
+        block_id = (i - 1) // (config["layers_per_block"] + 1)
+        layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
+
+        resnets = [key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key]
+        attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
+
+        if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
+                f"input_blocks.{i}.0.op.weight"
+            )
+            new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(f"input_blocks.{i}.0.op.bias")
+
+        paths = renew_resnet_paths(resnets)
+        meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
+        assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+        if len(attentions):
+            paths = renew_attention_paths(attentions)
+            meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
+            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+    resnet_0 = middle_blocks[0]
+    attentions = middle_blocks[1]
+    resnet_1 = middle_blocks[2]
+
+    resnet_0_paths = renew_resnet_paths(resnet_0)
+    assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
+
+    resnet_1_paths = renew_resnet_paths(resnet_1)
+    assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
+
+    attentions_paths = renew_attention_paths(attentions)
+    meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
+    assign_to_checkpoint(attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+    for i in range(num_output_blocks):
+        block_id = i // (config["layers_per_block"] + 1)
+        layer_in_block_id = i % (config["layers_per_block"] + 1)
+        output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
+        output_block_list = {}
+
+        for layer in output_block_layers:
+            layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
+            if layer_id in output_block_list:
+                output_block_list[layer_id].append(layer_name)
+            else:
+                output_block_list[layer_id] = [layer_name]
+
+        if len(output_block_list) > 1:
+            resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
+            attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
+
+            resnet_0_paths = renew_resnet_paths(resnets)
+            paths = renew_resnet_paths(resnets)
+
+            meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
+            assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config)
+
+            # Original:
+            # if ["conv.weight", "conv.bias"] in output_block_list.values():
+            #     index = list(output_block_list.values()).index(["conv.weight", "conv.bias"])
+
+            # Make this independent of the order of bias and weight; there is probably a better way to do it
+            for l in output_block_list.values():
+                l.sort()
+
+            if ["conv.bias", "conv.weight"] in output_block_list.values():
+                index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
+                    f"output_blocks.{i}.{index}.conv.bias"
+                ]
+                new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight" + ] + + # Clear attentions as they have been attributed above. + if len(attentions) == 2: + attentions = [] + + if len(attentions): + paths = renew_attention_paths(attentions) + meta_path = { + "old": f"output_blocks.{i}.1", + "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}", + } + assign_to_checkpoint(paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config) + else: + resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1) + for path in resnet_0_paths: + old_path = ".".join(["output_blocks", str(i), path["old"]]) + new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]]) + + new_checkpoint[new_path] = unet_state_dict[old_path] + + # SDのv2では1*1のconv2dがlinearに変わっている + # 誤って Diffusers 側を conv2d のままにしてしまったので、変換必要 + if v2 and not config.get("use_linear_projection", False): + linear_transformer_to_conv(new_checkpoint) + + return new_checkpoint + + +def convert_ldm_vae_checkpoint(checkpoint, config): + # extract state dict for VAE + vae_state_dict = {} + vae_key = "first_stage_model." + keys = list(checkpoint.keys()) + for key in keys: + if key.startswith(vae_key): + vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key) + # if len(vae_state_dict) == 0: + # # 渡されたcheckpointは.ckptから読み込んだcheckpointではなくvaeのstate_dict + # vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = {layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)} + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = {layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)} + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if 
f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def create_unet_diffusers_config(v2, use_linear_projection_in_v2=False): + """ + Creates a config for the diffusers based on the config of the LDM model. 
+ """ + # unet_params = original_config.model.params.unet_config.params + + block_out_channels = [UNET_PARAMS_MODEL_CHANNELS * mult for mult in UNET_PARAMS_CHANNEL_MULT] + + down_block_types = [] + resolution = 1 + for i in range(len(block_out_channels)): + block_type = "CrossAttnDownBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "DownBlock2D" + down_block_types.append(block_type) + if i != len(block_out_channels) - 1: + resolution *= 2 + + up_block_types = [] + for i in range(len(block_out_channels)): + block_type = "CrossAttnUpBlock2D" if resolution in UNET_PARAMS_ATTENTION_RESOLUTIONS else "UpBlock2D" + up_block_types.append(block_type) + resolution //= 2 + + config = dict( + sample_size=UNET_PARAMS_IMAGE_SIZE, + in_channels=UNET_PARAMS_IN_CHANNELS, + out_channels=UNET_PARAMS_OUT_CHANNELS, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + layers_per_block=UNET_PARAMS_NUM_RES_BLOCKS, + cross_attention_dim=UNET_PARAMS_CONTEXT_DIM if not v2 else V2_UNET_PARAMS_CONTEXT_DIM, + attention_head_dim=UNET_PARAMS_NUM_HEADS if not v2 else V2_UNET_PARAMS_ATTENTION_HEAD_DIM, + # use_linear_projection=UNET_PARAMS_USE_LINEAR_PROJECTION if not v2 else V2_UNET_PARAMS_USE_LINEAR_PROJECTION, + ) + if v2 and use_linear_projection_in_v2: + config["use_linear_projection"] = True + + return config + + +def create_vae_diffusers_config(): + """ + Creates a config for the diffusers based on the config of the LDM model. + """ + # vae_params = original_config.model.params.first_stage_config.params.ddconfig + # _ = original_config.model.params.first_stage_config.params.embed_dim + block_out_channels = [VAE_PARAMS_CH * mult for mult in VAE_PARAMS_CH_MULT] + down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels) + up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels) + + config = dict( + sample_size=VAE_PARAMS_RESOLUTION, + in_channels=VAE_PARAMS_IN_CHANNELS, + out_channels=VAE_PARAMS_OUT_CH, + down_block_types=tuple(down_block_types), + up_block_types=tuple(up_block_types), + block_out_channels=tuple(block_out_channels), + latent_channels=VAE_PARAMS_Z_CHANNELS, + layers_per_block=VAE_PARAMS_NUM_RES_BLOCKS, + ) + return config + + +def convert_ldm_clip_checkpoint_v1(checkpoint): + keys = list(checkpoint.keys()) + text_model_dict = {} + for key in keys: + if key.startswith("cond_stage_model.transformer"): + text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key] + + # support checkpoint without position_ids (invalid checkpoint) + if "text_model.embeddings.position_ids" not in text_model_dict: + text_model_dict["text_model.embeddings.position_ids"] = torch.arange(77).unsqueeze(0) # 77 is the max length of the text + + return text_model_dict + + +def convert_ldm_clip_checkpoint_v2(checkpoint, max_length): + # 嫌になるくらい違うぞ! + def convert_key(key): + if not key.startswith("cond_stage_model"): + return None + + # common conversion + key = key.replace("cond_stage_model.model.transformer.", "text_model.encoder.") + key = key.replace("cond_stage_model.model.", "text_model.") + + if "resblocks" in key: + # resblocks conversion + key = key.replace(".resblocks.", ".layers.") + if ".ln_" in key: + key = key.replace(".ln_", ".layer_norm") + elif ".mlp." 
+            elif ".mlp." in key:
+                key = key.replace(".c_fc.", ".fc1.")
+                key = key.replace(".c_proj.", ".fc2.")
+            elif ".attn.out_proj" in key:
+                key = key.replace(".attn.out_proj.", ".self_attn.out_proj.")
+            elif ".attn.in_proj" in key:
+                key = None  # this one is special, handled separately below
+            else:
+                raise ValueError(f"unexpected key in SD: {key}")
+        elif ".positional_embedding" in key:
+            key = key.replace(".positional_embedding", ".embeddings.position_embedding.weight")
+        elif ".text_projection" in key:
+            key = None  # not used???
+        elif ".logit_scale" in key:
+            key = None  # not used???
+        elif ".token_embedding" in key:
+            key = key.replace(".token_embedding.weight", ".embeddings.token_embedding.weight")
+        elif ".ln_final" in key:
+            key = key.replace(".ln_final", ".final_layer_norm")
+        return key
+
+    keys = list(checkpoint.keys())
+    new_sd = {}
+    for key in keys:
+        # remove resblocks 23
+        if ".resblocks.23." in key:
+            continue
+        new_key = convert_key(key)
+        if new_key is None:
+            continue
+        new_sd[new_key] = checkpoint[key]
+
+    # convert attention
+    for key in keys:
+        if ".resblocks.23." in key:
+            continue
+        if ".resblocks" in key and ".attn.in_proj_" in key:
+            # split the fused in_proj into three projections
+            values = torch.chunk(checkpoint[key], 3)
+
+            key_suffix = ".weight" if "weight" in key else ".bias"
+            key_pfx = key.replace("cond_stage_model.model.transformer.resblocks.", "text_model.encoder.layers.")
+            key_pfx = key_pfx.replace("_weight", "")
+            key_pfx = key_pfx.replace("_bias", "")
+            key_pfx = key_pfx.replace(".attn.in_proj", ".self_attn.")
+            new_sd[key_pfx + "q_proj" + key_suffix] = values[0]
+            new_sd[key_pfx + "k_proj" + key_suffix] = values[1]
+            new_sd[key_pfx + "v_proj" + key_suffix] = values[2]
+
+    # rename or add position_ids
+    ANOTHER_POSITION_IDS_KEY = "text_model.encoder.text_model.embeddings.position_ids"
+    if ANOTHER_POSITION_IDS_KEY in new_sd:
+        # waifu diffusion v1.4
+        position_ids = new_sd[ANOTHER_POSITION_IDS_KEY]
+        del new_sd[ANOTHER_POSITION_IDS_KEY]
+    else:
+        position_ids = torch.Tensor([list(range(max_length))]).to(torch.int64)
+
+    new_sd["text_model.embeddings.position_ids"] = position_ids
+    return new_sd
+
+
+# endregion
+
+
+# region Diffusers -> StableDiffusion conversion code
+# copied from convert_diffusers_to_original_stable_diffusion and modified (ASL 2.0)
+
+
+def conv_transformer_to_linear(checkpoint):
+    keys = list(checkpoint.keys())
+    tf_keys = ["proj_in.weight", "proj_out.weight"]
+    for key in keys:
+        if ".".join(key.split(".")[-2:]) in tf_keys:
+            if checkpoint[key].ndim > 2:
+                checkpoint[key] = checkpoint[key][:, :, 0, 0]
+
+
+def convert_unet_state_dict_to_sd(v2, unet_state_dict):
+    unet_conversion_map = [
+        # (stable-diffusion, HF Diffusers)
+        ("time_embed.0.weight", "time_embedding.linear_1.weight"),
+        ("time_embed.0.bias", "time_embedding.linear_1.bias"),
+        ("time_embed.2.weight", "time_embedding.linear_2.weight"),
+        ("time_embed.2.bias", "time_embedding.linear_2.bias"),
+        ("input_blocks.0.0.weight", "conv_in.weight"),
+        ("input_blocks.0.0.bias", "conv_in.bias"),
+        ("out.0.weight", "conv_norm_out.weight"),
+        ("out.0.bias", "conv_norm_out.bias"),
+        ("out.2.weight", "conv_out.weight"),
+        ("out.2.bias", "conv_out.bias"),
+    ]
+
+    unet_conversion_map_resnet = [
+        # (stable-diffusion, HF Diffusers)
+        ("in_layers.0", "norm1"),
+        ("in_layers.2", "conv1"),
+        ("out_layers.0", "norm2"),
+        ("out_layers.3", "conv2"),
+        ("emb_layers.1", "time_emb_proj"),
+        ("skip_connection", "conv_shortcut"),
+    ]
+
+    unet_conversion_map_layer = []
+    for i in range(4):
+        # loop over downblocks/upblocks
+
+        for j in range(2):
+            # loop over resnets/attentions for downblocks
+            hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." + unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix)) + + if i < 3: + # no attention layers in down_blocks.3 + hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}." + sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1." + unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix)) + + for j in range(3): + # loop over resnets/attentions for upblocks + hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}." + sd_up_res_prefix = f"output_blocks.{3*i + j}.0." + unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix)) + + if i > 0: + # no attention layers in up_blocks.0 + hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}." + sd_up_atn_prefix = f"output_blocks.{3*i + j}.1." + unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix)) + + if i < 3: + # no downsample in down_blocks.3 + hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv." + sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op." + unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix)) + + # no upsample in up_blocks.3 + hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0." + sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}." + unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix)) + + hf_mid_atn_prefix = "mid_block.attentions.0." + sd_mid_atn_prefix = "middle_block.1." + unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix)) + + for j in range(2): + hf_mid_res_prefix = f"mid_block.resnets.{j}." + sd_mid_res_prefix = f"middle_block.{2*j}." + unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix)) + + # buyer beware: this is a *brittle* function, + # and correct output requires that all of these pieces interact in + # the exact order in which I have arranged them. + mapping = {k: k for k in unet_state_dict.keys()} + for sd_name, hf_name in unet_conversion_map: + mapping[hf_name] = sd_name + for k, v in mapping.items(): + if "resnets" in k: + for sd_part, hf_part in unet_conversion_map_resnet: + v = v.replace(hf_part, sd_part) + mapping[k] = v + for k, v in mapping.items(): + for sd_part, hf_part in unet_conversion_map_layer: + v = v.replace(hf_part, sd_part) + mapping[k] = v + new_state_dict = {v: unet_state_dict[k] for k, v in mapping.items()} + + if v2: + conv_transformer_to_linear(new_state_dict) + + return new_state_dict + + +def controlnet_conversion_map(): + unet_conversion_map = [ + ("time_embed.0.weight", "time_embedding.linear_1.weight"), + ("time_embed.0.bias", "time_embedding.linear_1.bias"), + ("time_embed.2.weight", "time_embedding.linear_2.weight"), + ("time_embed.2.bias", "time_embedding.linear_2.bias"), + ("input_blocks.0.0.weight", "conv_in.weight"), + ("input_blocks.0.0.bias", "conv_in.bias"), + ("middle_block_out.0.weight", "controlnet_mid_block.weight"), + ("middle_block_out.0.bias", "controlnet_mid_block.bias"), + ] + + unet_conversion_map_resnet = [ + ("in_layers.0", "norm1"), + ("in_layers.2", "conv1"), + ("out_layers.0", "norm2"), + ("out_layers.3", "conv2"), + ("emb_layers.1", "time_emb_proj"), + ("skip_connection", "conv_shortcut"), + ] + + unet_conversion_map_layer = [] + for i in range(4): + for j in range(2): + hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}." + sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0." 
+            unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
+
+            if i < 3:
+                hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
+                sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
+                unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
+
+        if i < 3:
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
+            sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
+            unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
+
+    hf_mid_atn_prefix = "mid_block.attentions.0."
+    sd_mid_atn_prefix = "middle_block.1."
+    unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
+
+    for j in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{j}."
+        sd_mid_res_prefix = f"middle_block.{2*j}."
+        unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    controlnet_cond_embedding_names = ["conv_in"] + [f"blocks.{i}" for i in range(6)] + ["conv_out"]
+    for i, hf_prefix in enumerate(controlnet_cond_embedding_names):
+        hf_prefix = f"controlnet_cond_embedding.{hf_prefix}."
+        sd_prefix = f"input_hint_block.{i*2}."
+        unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+    for i in range(12):
+        hf_prefix = f"controlnet_down_blocks.{i}."
+        sd_prefix = f"zero_convs.{i}.0."
+        unet_conversion_map_layer.append((sd_prefix, hf_prefix))
+
+    return unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer
+
+
+def convert_controlnet_state_dict_to_sd(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+    mapping = {k: k for k in controlnet_state_dict.keys()}
+    for sd_name, diffusers_name in unet_conversion_map:
+        mapping[diffusers_name] = sd_name
+    for k, v in mapping.items():
+        if "resnets" in k:
+            for sd_part, diffusers_part in unet_conversion_map_resnet:
+                v = v.replace(diffusers_part, sd_part)
+            mapping[k] = v
+    for k, v in mapping.items():
+        for sd_part, diffusers_part in unet_conversion_map_layer:
+            v = v.replace(diffusers_part, sd_part)
+        mapping[k] = v
+    new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
+
+
+def convert_controlnet_state_dict_to_diffusers(controlnet_state_dict):
+    unet_conversion_map, unet_conversion_map_resnet, unet_conversion_map_layer = controlnet_conversion_map()
+
+    mapping = {k: k for k in controlnet_state_dict.keys()}
+    for sd_name, diffusers_name in unet_conversion_map:
+        mapping[sd_name] = diffusers_name
+    for k, v in mapping.items():
+        for sd_part, diffusers_part in unet_conversion_map_layer:
+            v = v.replace(sd_part, diffusers_part)
+        mapping[k] = v
+    for k, v in mapping.items():
+        if "resnets" in v:
+            for sd_part, diffusers_part in unet_conversion_map_resnet:
+                v = v.replace(sd_part, diffusers_part)
+            mapping[k] = v
+    new_state_dict = {v: controlnet_state_dict[k] for k, v in mapping.items()}
+    return new_state_dict
+
+
+# ================#
+# VAE Conversion #
+# ================#
+
+
+def reshape_weight_for_sd(w):
+    # convert HF linear weights to SD conv2d weights
+    return w.reshape(*w.shape, 1, 1)
+
+
+def convert_vae_state_dict(vae_state_dict):
+    vae_conversion_map = [
+        # (stable-diffusion, HF Diffusers)
+        ("nin_shortcut", "conv_shortcut"),
+        ("norm_out", "conv_norm_out"),
+        ("mid.attn_1.", "mid_block.attentions.0."),
+    ]
+
+    for i in range(4):
+        # down_blocks have two resnets
+        for j in range(2):
+            hf_down_prefix = f"encoder.down_blocks.{i}.resnets.{j}."
+            sd_down_prefix = f"encoder.down.{i}.block.{j}."
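+            # e.g. Diffusers encoder.down_blocks.0.resnets.1 <-> SD encoder.down.0.block.1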
+            vae_conversion_map.append((sd_down_prefix, hf_down_prefix))
+
+        if i < 3:
+            hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0."
+            sd_downsample_prefix = f"down.{i}.downsample."
+            vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
+
+            hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
+            sd_upsample_prefix = f"up.{3-i}.upsample."
+            vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
+
+        # up_blocks have three resnets
+        # also, up blocks in hf are numbered in reverse from sd
+        for j in range(3):
+            hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
+            sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
+            vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
+
+    # this part accounts for mid blocks in both the encoder and the decoder
+    for i in range(2):
+        hf_mid_res_prefix = f"mid_block.resnets.{i}."
+        sd_mid_res_prefix = f"mid.block_{i+1}."
+        vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
+
+    # NOTE: plain string comparison of version numbers is unreliable (e.g. "0.9.0" < "0.17.0" is False),
+    # so use packaging for the check
+    from packaging.version import Version
+
+    if Version(diffusers.__version__) < Version("0.17.0"):
+        vae_conversion_map_attn = [
+            # (stable-diffusion, HF Diffusers)
+            ("norm.", "group_norm."),
+            ("q.", "query."),
+            ("k.", "key."),
+            ("v.", "value."),
+            ("proj_out.", "proj_attn."),
+        ]
+    else:
+        vae_conversion_map_attn = [
+            # (stable-diffusion, HF Diffusers)
+            ("norm.", "group_norm."),
+            ("q.", "to_q."),
+            ("k.", "to_k."),
+            ("v.", "to_v."),
+            ("proj_out.", "to_out.0."),
+        ]
+
+    mapping = {k: k for k in vae_state_dict.keys()}
+    for k, v in mapping.items():
+        for sd_part, hf_part in vae_conversion_map:
+            v = v.replace(hf_part, sd_part)
+        mapping[k] = v
+    for k, v in mapping.items():
+        if "attentions" in k:
+            for sd_part, hf_part in vae_conversion_map_attn:
+                v = v.replace(hf_part, sd_part)
+            mapping[k] = v
+    new_state_dict = {v: vae_state_dict[k] for k, v in mapping.items()}
+    weights_to_convert = ["q", "k", "v", "proj_out"]
+    for k, v in new_state_dict.items():
+        for weight_name in weights_to_convert:
+            if f"mid.attn_1.{weight_name}.weight" in k:
+                # print(f"Reshaping {k} for SD format: shape {v.shape} -> {v.shape} x 1 x 1")
+                new_state_dict[k] = reshape_weight_for_sd(v)
+
+    return new_state_dict
+
+
+# endregion
+
+# region custom model load/save helpers
+
+
+def is_safetensors(path):
+    return os.path.splitext(path)[1].lower() == ".safetensors"
+
+
+def load_checkpoint_with_text_encoder_conversion(ckpt_path, device="cpu"):
+    # support models that store the text encoder in a different layout (without 'text_model')
+    TEXT_ENCODER_KEY_REPLACEMENTS = [
+        ("cond_stage_model.transformer.embeddings.", "cond_stage_model.transformer.text_model.embeddings."),
+        ("cond_stage_model.transformer.encoder.", "cond_stage_model.transformer.text_model.encoder."),
+        ("cond_stage_model.transformer.final_layer_norm.", "cond_stage_model.transformer.text_model.final_layer_norm."),
+    ]
+
+    if is_safetensors(ckpt_path):
+        checkpoint = None
+        state_dict = load_file(ckpt_path)  # load_file(ckpt_path, device) may cause an error
+    else:
+        checkpoint = torch.load(ckpt_path, map_location=device)
+        if "state_dict" in checkpoint:
+            state_dict = checkpoint["state_dict"]
+        else:
+            state_dict = checkpoint
+            checkpoint = None
+
+    key_reps = []
+    for rep_from, rep_to in TEXT_ENCODER_KEY_REPLACEMENTS:
+        for key in state_dict.keys():
+            if key.startswith(rep_from):
+                new_key = rep_to + key[len(rep_from) :]
+                key_reps.append((key, new_key))
+
+    for key, new_key in key_reps:
+        state_dict[new_key] = state_dict[key]
+        del state_dict[key]
+
+    return checkpoint, state_dict
+
+
+def get_model_version_str_for_sd1_sd2(v2, v_parameterization):
+    # only for reference
+    version_str = "sd"
+    if v2:
+        version_str += "_v2"
+    else:
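+        # SD 1.x: yields "sd_v1" (with "_v" appended below for v-parameterization models)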
version_str += "_v1" + if v_parameterization: + version_str += "_v" + return version_str + + +def convert_text_encoder_state_dict_to_sd_v2(checkpoint, make_dummy_weights=False): + def convert_key(key): + # position_idsの除去 + if ".position_ids" in key: + return None + + # common + key = key.replace("text_model.encoder.", "transformer.") + key = key.replace("text_model.", "") + if "layers" in key: + # resblocks conversion + key = key.replace(".layers.", ".resblocks.") + if ".layer_norm" in key: + key = key.replace(".layer_norm", ".ln_") + elif ".mlp." in key: + key = key.replace(".fc1.", ".c_fc.") + key = key.replace(".fc2.", ".c_proj.") + elif ".self_attn.out_proj" in key: + key = key.replace(".self_attn.out_proj.", ".attn.out_proj.") + elif ".self_attn." in key: + key = None # 特殊なので後で処理する + else: + raise ValueError(f"unexpected key in DiffUsers model: {key}") + elif ".position_embedding" in key: + key = key.replace("embeddings.position_embedding.weight", "positional_embedding") + elif ".token_embedding" in key: + key = key.replace("embeddings.token_embedding.weight", "token_embedding.weight") + elif "final_layer_norm" in key: + key = key.replace("final_layer_norm", "ln_final") + return key + + keys = list(checkpoint.keys()) + new_sd = {} + for key in keys: + new_key = convert_key(key) + if new_key is None: + continue + new_sd[new_key] = checkpoint[key] + + # attnの変換 + for key in keys: + if "layers" in key and "q_proj" in key: + # 三つを結合 + key_q = key + key_k = key.replace("q_proj", "k_proj") + key_v = key.replace("q_proj", "v_proj") + + value_q = checkpoint[key_q] + value_k = checkpoint[key_k] + value_v = checkpoint[key_v] + value = torch.cat([value_q, value_k, value_v]) + + new_key = key.replace("text_model.encoder.layers.", "transformer.resblocks.") + new_key = new_key.replace(".self_attn.q_proj.", ".attn.in_proj_") + new_sd[new_key] = value + + # 最後の層などを捏造するか + if make_dummy_weights: + print("make dummy weights for resblock.23, text_projection and logit scale.") + keys = list(new_sd.keys()) + for key in keys: + if key.startswith("transformer.resblocks.22."): + new_sd[key.replace(".22.", ".23.")] = new_sd[key].clone() # copyしないとsafetensorsの保存で落ちる + + # Diffusersに含まれない重みを作っておく + new_sd["text_projection"] = torch.ones((1024, 1024), dtype=new_sd[keys[0]].dtype, device=new_sd[keys[0]].device) + new_sd["logit_scale"] = torch.tensor(1) + + return new_sd + + +def save_stable_diffusion_checkpoint( + v2, output_file, text_encoder, unet, ckpt_path, epochs, steps, metadata, save_dtype=None, vae=None +): + if ckpt_path is not None: + # epoch/stepを参照する。またVAEがメモリ上にないときなど、もう一度VAEを含めて読み込む + checkpoint, state_dict = load_checkpoint_with_text_encoder_conversion(ckpt_path) + if checkpoint is None: # safetensors または state_dictのckpt + checkpoint = {} + strict = False + else: + strict = True + if "state_dict" in state_dict: + del state_dict["state_dict"] + else: + # 新しく作る + assert vae is not None, "VAE is required to save a checkpoint without a given checkpoint" + checkpoint = {} + state_dict = {} + strict = False + + def update_sd(prefix, sd): + for k, v in sd.items(): + key = prefix + k + assert not strict or key in state_dict, f"Illegal key in save SD: {key}" + if save_dtype is not None: + v = v.detach().clone().to("cpu").to(save_dtype) + state_dict[key] = v + + # Convert the UNet model + unet_state_dict = convert_unet_state_dict_to_sd(v2, unet.state_dict()) + update_sd("model.diffusion_model.", unet_state_dict) + + # Convert the text encoder model + if v2: + make_dummy = ckpt_path is None # 
+        # (e.g. the final layer is fabricated by duplicating the previous one)
+        text_enc_dict = convert_text_encoder_state_dict_to_sd_v2(text_encoder.state_dict(), make_dummy)
+        update_sd("cond_stage_model.model.", text_enc_dict)
+    else:
+        text_enc_dict = text_encoder.state_dict()
+        update_sd("cond_stage_model.transformer.", text_enc_dict)
+
+    # Convert the VAE
+    if vae is not None:
+        vae_dict = convert_vae_state_dict(vae.state_dict())
+        update_sd("first_stage_model.", vae_dict)
+
+    # Put together new checkpoint
+    key_count = len(state_dict.keys())
+    new_ckpt = {"state_dict": state_dict}
+
+    # epoch and global_step are sometimes not int
+    try:
+        if "epoch" in checkpoint:
+            epochs += checkpoint["epoch"]
+        if "global_step" in checkpoint:
+            steps += checkpoint["global_step"]
+    except Exception:
+        pass
+
+    new_ckpt["epoch"] = epochs
+    new_ckpt["global_step"] = steps
+
+    if is_safetensors(output_file):
+        # TODO should non-tensor values be removed from the dict before saving?
+        save_file(state_dict, output_file, metadata)
+    else:
+        torch.save(new_ckpt, output_file)
+
+    return key_count
+
+
+def save_diffusers_checkpoint(v2, output_dir, text_encoder, unet, pretrained_model_name_or_path, vae=None, use_safetensors=False):
+    if pretrained_model_name_or_path is None:
+        # load default settings for v1/v2
+        if v2:
+            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V2
+        else:
+            pretrained_model_name_or_path = DIFFUSERS_REF_MODEL_ID_V1
+
+    scheduler = DDIMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
+    tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
+    if vae is None:
+        vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
+
+    pipeline = StableDiffusionPipeline(
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        scheduler=scheduler,
+        tokenizer=tokenizer,
+        safety_checker=None,
+        feature_extractor=None,
+        requires_safety_checker=None,
+    )
+    pipeline.save_pretrained(output_dir, safe_serialization=use_safetensors)
+
+
+VAE_PREFIX = "first_stage_model."
+
+
+def load_vae(vae_id, dtype):
+    print(f"load VAE: {vae_id}")
+    if os.path.isdir(vae_id) or not os.path.isfile(vae_id):
+        # Diffusers local/remote
+        try:
+            vae = AutoencoderKL.from_pretrained(vae_id, subfolder=None, torch_dtype=dtype)
+        except EnvironmentError as e:
+            print(f"exception occurred while loading the VAE: {e}")
+            print("retry with subfolder='vae'")
+            vae = AutoencoderKL.from_pretrained(vae_id, subfolder="vae", torch_dtype=dtype)
+        return vae
+
+    # local
+    vae_config = create_vae_diffusers_config()
+
+    if vae_id.endswith(".bin"):
+        # SD 1.5 VAE on Huggingface
+        converted_vae_checkpoint = torch.load(vae_id, map_location="cpu")
+    else:
+        # StableDiffusion
+        vae_model = load_file(vae_id, "cpu") if is_safetensors(vae_id) else torch.load(vae_id, map_location="cpu")
+        vae_sd = vae_model["state_dict"] if "state_dict" in vae_model else vae_model
+
+        # vae only or full model
+        full_model = False
+        for vae_key in vae_sd:
+            if vae_key.startswith(VAE_PREFIX):
+                full_model = True
+                break
+        if not full_model:
+            sd = {}
+            for key, value in vae_sd.items():
+                sd[VAE_PREFIX + key] = value
+            vae_sd = sd
+            del sd
+
+        # Convert the VAE model.
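+        # at this point vae_sd carries the "first_stage_model." prefix (added above for VAE-only files)
+        # that convert_ldm_vae_checkpoint expects and strips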
+        converted_vae_checkpoint = convert_ldm_vae_checkpoint(vae_sd, vae_config)
+
+    vae = AutoencoderKL(**vae_config)
+    vae.load_state_dict(converted_vae_checkpoint)
+    return vae
+
+
+# load state_dict without allocating new tensors
+def _load_state_dict_on_device(model, state_dict, device, dtype=None):
+    # dtype will use fp32 as default
+    missing_keys = list(model.state_dict().keys() - state_dict.keys())
+    unexpected_keys = list(state_dict.keys() - model.state_dict().keys())
+
+    # similar to model.load_state_dict()
+    if not missing_keys and not unexpected_keys:
+        for k in list(state_dict.keys()):
+            set_module_tensor_to_device(model, k, device, value=state_dict.pop(k), dtype=dtype)
+        return "