support sdxl

This commit is contained in:
limbo0000
2023-11-10 11:57:39 +08:00
parent 60dfd554c0
commit d6f459dbd6
111 changed files with 5620 additions and 3750 deletions

313
README.md
View File

@@ -1,121 +1,10 @@
# AnimateDiff
# A beta-version of motion module for SDXL
This repository is the official implementation of [AnimateDiff](https://arxiv.org/abs/2307.04725).
**[AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)**
</br>
Yuwei Guo,
Ceyuan Yang*,
Anyi Rao,
Yaohui Wang,
Yu Qiao,
Dahua Lin,
Bo Dai
<p style="font-size: 0.8em; margin-top: -1em">*Corresponding Author</p>
<!-- [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) -->
[![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://arxiv.org/abs/2307.04725)
[![Project Page](https://img.shields.io/badge/Project-Website-green)](https://animatediff.github.io/)
[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg)](https://openxlab.org.cn/apps/detail/Masbfca/AnimateDiff)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-yellow)](https://huggingface.co/spaces/guoyww/AnimateDiff)
## Next
One with better controllability and quality is coming soon. Stay tuned.
## Features
- **[2023/11/10]** Release the Motion Module (beta version) on SDXL, available at [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link
) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt
) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules). High resolution videos (i.e., 1024x1024x16 frames with various aspect ratios) could be produced **with/without** personalized models. Inference usually requires ~13GB VRAM and tuned hyperparameters (e.g., #sampling steps), depending on the chosen personalized models. Checkout to the branch `sdxl` for more details of the inference. More checkpoints with better-quality would be available soon. Stay tuned. Examples below are manually downsampled for fast loading.
<table class="center">
<tr style="line-height: 0">
<td width=50% style="border: none; text-align: center">Original SDXL</td>
<td width=30% style="border: none; text-align: center">Personalized SDXL</td>
<td width=20% style="border: none; text-align: center">Personalized SDXL</td>
</tr>
<tr>
<td width=50% style="border: none"><img src="__assets__/animations/motion_xl/01.gif"></td>
<td width=30% style="border: none"><img src="__assets__/animations/motion_xl/02.gif"></td>
<td width=20% style="border: none"><img src="__assets__/animations/motion_xl/03.gif"></td>
</tr>
</table>
- **[2023/09/25]** Release **MotionLoRA** and its model zoo, **enabling camera movement controls**! Please download the MotionLoRA models (**74 MB per model**, available at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules) ) and save them to the `models/MotionLoRA` folder. Example:
```
python -m scripts.animate --config configs/prompts/v2/5-RealisticVision-MotionLoRA.yaml
```
<table class="center">
<tr style="line-height: 0">
<td colspan="2" style="border: none; text-align: center">Zoom In</td>
<td colspan="2" style="border: none; text-align: center">Zoom Out</td>
<td colspan="2" style="border: none; text-align: center">Zoom Pan Left</td>
<td colspan="2" style="border: none; text-align: center">Zoom Pan Right</td>
</tr>
<tr>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/01.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/02.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/02.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/01.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/03.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/04.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/04.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/03.gif"></td>
</tr>
<tr style="line-height: 0">
<td colspan="2" style="border: none; text-align: center">Tilt Up</td>
<td colspan="2" style="border: none; text-align: center">Tilt Down</td>
<td colspan="2" style="border: none; text-align: center">Rolling Anti-Clockwise</td>
<td colspan="2" style="border: none; text-align: center">Rolling Clockwise</td>
</tr>
<tr>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/05.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/05.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/06.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/06.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/07.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/07.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_01/08.gif"></td>
<td style="border: none"><img src="__assets__/animations/motion_lora/model_02/08.gif"></td>
</tr>
</table>
- **[2023/09/10]** New Motion Module release! `mm_sd_v15_v2.ckpt` was trained on larger resolution & batch size, and gains noticeable quality improvements. Check it out at [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules) and use it with `configs/inference/inference-v2.yaml`. Example:
```
python -m scripts.animate --config configs/prompts/v2/5-RealisticVision.yaml
```
Here is a qualitative comparison between `mm_sd_v15.ckpt` (left) and `mm_sd_v15_v2.ckpt` (right):
<table class="center">
<tr>
<td><img src="__assets__/animations/compare/old_0.gif"></td>
<td><img src="__assets__/animations/compare/new_0.gif"></td>
<td><img src="__assets__/animations/compare/old_1.gif"></td>
<td><img src="__assets__/animations/compare/new_1.gif"></td>
<td><img src="__assets__/animations/compare/old_2.gif"></td>
<td><img src="__assets__/animations/compare/new_2.gif"></td>
<td><img src="__assets__/animations/compare/old_3.gif"></td>
<td><img src="__assets__/animations/compare/new_3.gif"></td>
</tr>
</table>
- GPU Memory Optimization, ~12GB VRAM to inference
## Quick Demo
User Interface developed by community:
- A1111 Extension [sd-webui-animatediff](https://github.com/continue-revolution/sd-webui-animatediff) (by [@continue-revolution](https://github.com/continue-revolution))
- ComfyUI Extension [ComfyUI-AnimateDiff-Evolved](https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved) (by [@Kosinkadink](https://github.com/Kosinkadink))
- Google Colab: [Colab](https://colab.research.google.com/github/camenduru/AnimateDiff-colab/blob/main/AnimateDiff_colab.ipynb) (by [@camenduru](https://github.com/camenduru))
We also create a Gradio demo to make AnimateDiff easier to use. To launch the demo, please run the following commands:
```
conda activate animatediff
python app.py
```
By default, the demo will run at `localhost:7860`.
<br><img src="__assets__/figs/gradio.jpg" style="width: 50em; margin-top: 1em">
Now you can generate high-resolution videos on SDXL **with/without** personalized models. Checkpoint with better quality would be available soon. Stay tuned.
## Somethings Important
- Generate videos with high-resolution (we provide recommended ones) as SDXL usually leads to worse quality for low-resolution images.
- Follow and slightly adjust the hyperparameters (e.g., #sampling steps, #guidance scale) of various personalized SDXL since these models are carefully tuned to various extent.
## Model Zoo
<details open>
@@ -123,86 +12,140 @@ By default, the demo will run at `localhost:7860`.
| Name | Parameter | Storage Space |
|----------------------|-----------|---------------|
| mm_sd_v14.ckpt | 417 M | 1.6 GB |
| mm_sd_v15.ckpt | 417 M | 1.6 GB |
| mm_sd_v15_v2.ckpt | 453 M | 1.7 GB |
| mm_sdxl_v10_beta.ckpt | 238 M | 0.9 GB |
</details>
<details open>
<summary>MotionLoRAs</summary>
<summary>Recommended Resolution</summary>
| Name | Parameter | Storage Space |
|--------------------------------------|-----------|---------------|
| v2_lora_ZoomIn.ckpt | 19 M | 74 MB |
| v2_lora_ZoomOut.ckpt | 19 M | 74 MB |
| v2_lora_PanLeft.ckpt | 19 M | 74 MB |
| v2_lora_PanRight.ckpt | 19 M | 74 MB |
| v2_lora_TiltUp.ckpt | 19 M | 74 MB |
| v2_lora_TiltDown.ckpt | 19 M | 74 MB |
| v2_lora_RollingClockwise.ckpt | 19 M | 74 MB |
| v2_lora_RollingAnticlockwise.ckpt | 19 M | 74 MB |
| Resolution | Aspect Ratio |
|----------------------|-----------|
| 768x1344 | 9:16 |
| 832x1216 | 2:3 |
| 1024x1024 | 1:1 |
| 1216x832 | 3:2 |
| 1344x768 | 16:9 |
</details>
## Common Issues
<details>
<summary>Installation</summary>
Please ensure the installation of [xformer](https://github.com/facebookresearch/xformers) that is applied to reduce the inference memory.
</details>
<details>
<summary>Various resolution or number of frames</summary>
Currently, we recommend users to generate animation with 16 frames and 512 resolution that are aligned with our training settings. Notably, various resolution/frames may affect the quality more or less.
</details>
<details>
<summary>How to use it without any coding</summary>
1) Get lora models: train lora model with [A1111](https://github.com/continue-revolution/sd-webui-animatediff) based on a collection of your own favorite images (e.g., tutorials [English](https://www.youtube.com/watch?v=mfaqqL5yOO4), [Japanese](https://www.youtube.com/watch?v=N1tXVR9lplM), [Chinese](https://www.bilibili.com/video/BV1fs4y1x7p2/))
or download Lora models from [Civitai](https://civitai.com/).
2) Animate lora models: using gradio interface or A1111
(e.g., tutorials [English](https://github.com/continue-revolution/sd-webui-animatediff), [Japanese](https://www.youtube.com/watch?v=zss3xbtvOWw), [Chinese](https://941ai.com/sd-animatediff-webui-1203.html))
3) Be creative togther with other techniques, such as, super resolution, frame interpolation, music generation, etc.
</details>
<details>
<summary>Animating a given image</summary>
We totally agree that animating a given image is an appealing feature, which we would try to support officially in future. For now, you may enjoy other efforts from the [talesofai](https://github.com/talesofai/AnimateDiff).
</details>
<details>
<summary>Contributions from community</summary>
Contributions are always welcome!! The <code>dev</code> branch is for community contributions. As for the main branch, we would like to align it with the original technical report :)
</details>
## Training and inference
Please refer to [ANIMATEDIFF](./__assets__/docs/animatediff.md) for the detailed setup.
## Gallery
We collect several generated results in [GALLERY](./__assets__/docs/gallery.md).
We demonstrate some results with our model. The GIFs below are **manually downsampled** after generation for fast loading.
**Original SDXL**
<table class="center">
<tr>
<td><img src="__assets__/animations/model_original/01.gif"></td>
<td><img src="__assets__/animations/model_original/02.gif"></td>
</tr>
</table>
**LoRA**
<table class="center">
<tr>
<td><img src="__assets__/animations/model_01/01.gif"></td>
<td><img src="__assets__/animations/model_01/02.gif"></td>
<td><img src="__assets__/animations/model_01/03.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/122606?modelVersionId=169718">DynaVision</a><p>
<table class="center">
<tr>
<td><img src="__assets__/animations/model_02/01.gif"></td>
<td><img src="__assets__/animations/model_02/02.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/112902/dreamshaper-xl10?modelVersionId=126688">DreamShaper</a><p>
<table class="center">
<tr>
<td><img src="__assets__/animations/model_03/01.gif"></td>
<td><img src="__assets__/animations/model_03/02.gif"></td>
<td><img src="__assets__/animations/model_03/03.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/128397/deepblue-xl?modelVersionId=189102">DeepBlue</a><p>
## Inference Example
Inference at recommended resolution of 16 frames usually requires ~13GB VRAM.
### Step-1: Prepare Environment
## BibTeX
```
@article{guo2023animatediff,
title={AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning},
author={Guo, Yuwei and Yang, Ceyuan and Rao, Anyi and Wang, Yaohui and Qiao, Yu and Lin, Dahua and Dai, Bo},
journal={arXiv preprint arXiv:2307.04725},
year={2023}
}
git clone https://github.com/guoyww/AnimateDiff.git
cd AnimateDiff
git checkout sdxl
conda env create -f environment.yaml
conda activate animatediff_xl
```
## Contact Us
**Yuwei Guo**: [guoyuwei@pjlab.org.cn](mailto:guoyuwei@pjlab.org.cn)
**Ceyuan Yang**: [yangceyuan@pjlab.org.cn](mailto:yangceyuan@pjlab.org.cn)
**Bo Dai**: [daibo@pjlab.org.cn](mailto:daibo@pjlab.org.cn)
### Step-2: Download Base T2I & Motion Module Checkpoints
We provide a beta version of motion module on SDXL. You can download the base model of SDXL 1.0 and Motion Module following instructions below.
```
git lfs install
git clone https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0 models/StableDiffusion/
## Acknowledgements
Codebase built upon [Tune-a-Video](https://github.com/showlab/Tune-A-Video).
bash download_bashscripts/0-MotionModule.sh
```
You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/file/d/1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX/view?usp=share_link
) / [HuggingFace](https://huggingface.co/guoyww/animatediff/blob/main/mm_sdxl_v10_beta.ckpt
) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules), then put them in `models/Motion_Module/` folder.
### Step-3: Download Personalized SDXL (you can skip this if generating videos on the original SDXL)
You may run the following bash scripts to download the LoRA checkpoint from CivitAI.
```
bash download_bashscripts/1-DynaVision.sh
bash download_bashscripts/2-DreamShaper.sh
bash download_bashscripts/3-DeepBlue.sh
```
### Step-4: Generate Videos
Run the following commands to generate videos of **original SDXL**.
```
python -m scripts.animate --exp_config configs/prompts/1-original_sdxl.yaml --H 1024 --W 1024 --L 16 --xformers
```
Run the following commands to generate videos of **personalized SDXL**. DO NOT skip Step-3.
```
python -m scripts.animate --config configs/prompts/2-DynaVision.yaml --H 1024 --W 1024 --L 16 --xformers
python -m scripts.animate --config configs/prompts/3-DreamShaper.yaml --H 1024 --W 1024 --L 16 --xformers
python -m scripts.animate --config configs/prompts/4-DeepBlue.yaml --H 1024 --W 1024 --L 16 --xformers
```
The results will automatically be saved to `samples/` folder.
## Customized Inference
To generate videos with a new Checkpoint/LoRA model, you may create a new config `.yaml` file in the following format:
```
motion_module_path: "models/Motion_Module/mm_sdxl_v10_beta.ckpt" # Specify the Motion Module
# We support 3 types of T2I models.
# 1. Checkpoint: a safetensors model contains UNet, Text_Encoders, VAE.
# 2. LoRA: a safetensors model contains only the LoRA modules.
# 3. You can convert the Checkpoint into a folder with the same structure as SDXL_1.0 base model.
ckpt_path: "YOUR_CKPT_PATH" # path to the checkpoint type model from CivitAI.
lora_path: "YOUR_LORA_PATH" # path to the LORA type model from CivitAI.
base_model_path: "YOUR_BASE_MODEL_PATH" # path to the folder converted from a checkpoint
steps: 50
guidance_scale: 8.5
seed: -1 # You can specify seed for each prompt.
prompt:
- "[positive prompt]"
n_prompt:
- "[negative prompt]"
```
Then run the following commands.
```
python -m scripts.animate --exp_config [path to the personalized config] --L [video frames] --H [Height of the videos] --W [Width of the videos] --xformers
```

Binary file not shown.

Before

Width:  |  Height:  |  Size: 972 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 591 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 766 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.1 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 862 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 639 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 677 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.0 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 386 KiB

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 357 KiB

After

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 352 KiB

After

Width:  |  Height:  |  Size: 1.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 223 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 339 KiB

After

Width:  |  Height:  |  Size: 710 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 410 KiB

After

Width:  |  Height:  |  Size: 817 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 429 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 454 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 280 KiB

After

Width:  |  Height:  |  Size: 1.6 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 485 KiB

After

Width:  |  Height:  |  Size: 2.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 287 KiB

After

Width:  |  Height:  |  Size: 2.5 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 293 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 230 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 239 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 353 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 386 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 351 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 379 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 307 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 258 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 221 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 213 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 184 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 396 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.8 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.2 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.3 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 103 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.9 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 2.0 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.7 MiB

View File

Before

Width:  |  Height:  |  Size: 741 KiB

After

Width:  |  Height:  |  Size: 741 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 1020 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 591 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 596 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 601 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 592 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 553 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 561 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 552 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 548 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 543 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 537 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 559 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 547 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 556 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 544 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 544 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 566 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.4 MiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 1.6 MiB

View File

@@ -1,112 +0,0 @@
# AnimateDiff: training and inference setup
## Setups for Inference
### Prepare Environment
***We updated our inference code with xformers and a sequential decoding trick. Now AnimateDiff takes only ~12GB VRAM to inference, and run on a single RTX3090 !!***
```
git clone https://github.com/guoyww/AnimateDiff.git
cd AnimateDiff
conda env create -f environment.yaml
conda activate animatediff
```
### Download Base T2I & Motion Module Checkpoints
We provide two versions of our Motion Module, which are trained on stable-diffusion-v1-4 and finetuned on v1-5 seperately.
It's recommanded to try both of them for best results.
```
git lfs install
git clone https://huggingface.co/runwayml/stable-diffusion-v1-5 models/StableDiffusion/
bash download_bashscripts/0-MotionModule.sh
```
You may also directly download the motion module checkpoints from [Google Drive](https://drive.google.com/drive/folders/1EqLC65eR1-W-sGD0Im7fkED6c8GkiNFI?usp=sharing) / [HuggingFace](https://huggingface.co/guoyww/animatediff) / [CivitAI](https://civitai.com/models/108836/animatediff-motion-modules), then put them in `models/Motion_Module/` folder.
### Prepare Personalize T2I
Here we provide inference configs for 6 demo T2I on CivitAI.
You may run the following bash scripts to download these checkpoints.
```
bash download_bashscripts/1-ToonYou.sh
bash download_bashscripts/2-Lyriel.sh
bash download_bashscripts/3-RcnzCartoon.sh
bash download_bashscripts/4-MajicMix.sh
bash download_bashscripts/5-RealisticVision.sh
bash download_bashscripts/6-Tusun.sh
bash download_bashscripts/7-FilmVelvia.sh
bash download_bashscripts/8-GhibliBackground.sh
```
### Inference
After downloading the above peronalized T2I checkpoints, run the following commands to generate animations. The results will automatically be saved to `samples/` folder.
```
python -m scripts.animate --config configs/prompts/1-ToonYou.yaml
python -m scripts.animate --config configs/prompts/2-Lyriel.yaml
python -m scripts.animate --config configs/prompts/3-RcnzCartoon.yaml
python -m scripts.animate --config configs/prompts/4-MajicMix.yaml
python -m scripts.animate --config configs/prompts/5-RealisticVision.yaml
python -m scripts.animate --config configs/prompts/6-Tusun.yaml
python -m scripts.animate --config configs/prompts/7-FilmVelvia.yaml
python -m scripts.animate --config configs/prompts/8-GhibliBackground.yaml
```
To generate animations with a new DreamBooth/LoRA model, you may create a new config `.yaml` file in the following format:
```
NewModel:
inference_config: "[path to motion module config file]"
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
motion_module_lora_configs:
- path: "[path to MotionLoRA model]"
alpha: 1.0
- ...
dreambooth_path: "[path to your DreamBooth model .safetensors file]"
lora_model_path: "[path to your LoRA model .safetensors file, leave it empty string if not needed]"
steps: 25
guidance_scale: 7.5
prompt:
- "[positive prompt]"
n_prompt:
- "[negative prompt]"
```
Then run the following commands:
```
python -m scripts.animate --config [path to the config file]
```
## Steps for Training
### Dataset
Before training, download the videos files and the `.csv` annotations of [WebVid10M](https://maxbain.com/webvid-dataset/) to the local mechine.
Note that our examplar training script requires all the videos to be saved in a single folder. You may change this by modifying `animatediff/data/dataset.py`.
### Configuration
After dataset preparations, update the below data paths in the config `.yaml` files in `configs/training/` folder:
```
train_data:
csv_path: [Replace with .csv Annotation File Path]
video_folder: [Replace with Video Folder Path]
sample_size: 256
```
Other training parameters (lr, epochs, validation settings, etc.) are also included in the config files.
### Training
To train motion modules
```
torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/training.yaml
```
To finetune the unet's image layers
```
torchrun --nnodes=1 --nproc_per_node=1 train.py --config configs/training/image_finetune.yaml
```

View File

@@ -1,93 +0,0 @@
# Gallery
Here we demonstrate several best results we found in our experiments.
<table class="center">
<tr>
<td><img src="../animations/model_01/01.gif"></td>
<td><img src="../animations/model_01/02.gif"></td>
<td><img src="../animations/model_01/03.gif"></td>
<td><img src="../animations/model_01/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/30240/toonyou">ToonYou</a></p>
<table>
<tr>
<td><img src="../animations/model_02/01.gif"></td>
<td><img src="../animations/model_02/02.gif"></td>
<td><img src="../animations/model_02/03.gif"></td>
<td><img src="../animations/model_02/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/4468/counterfeit-v30">Counterfeit V3.0</a></p>
<table>
<tr>
<td><img src="../animations/model_03/01.gif"></td>
<td><img src="../animations/model_03/02.gif"></td>
<td><img src="../animations/model_03/03.gif"></td>
<td><img src="../animations/model_03/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/4201/realistic-vision-v20">Realistic Vision V2.0</a></p>
<table>
<tr>
<td><img src="../animations/model_04/01.gif"></td>
<td><img src="../animations/model_04/02.gif"></td>
<td><img src="../animations/model_04/03.gif"></td>
<td><img src="../animations/model_04/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model <a href="https://civitai.com/models/43331/majicmix-realistic">majicMIX Realistic</a></p>
<table>
<tr>
<td><img src="../animations/model_05/01.gif"></td>
<td><img src="../animations/model_05/02.gif"></td>
<td><img src="../animations/model_05/03.gif"></td>
<td><img src="../animations/model_05/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/66347/rcnz-cartoon-3d">RCNZ Cartoon</a></p>
<table>
<tr>
<td><img src="../animations/model_06/01.gif"></td>
<td><img src="../animations/model_06/02.gif"></td>
<td><img src="../animations/model_06/03.gif"></td>
<td><img src="../animations/model_06/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">Model<a href="https://civitai.com/models/33208/filmgirl-film-grain-lora-and-loha">FilmVelvia</a></p>
#### Community Cases
Here are some samples contributed by the community artists. Create a Pull Request if you would like to show your results here😚.
<table>
<tr>
<td><img src="../animations/model_07/init.jpg"></td>
<td><img src="../animations/model_07/01.gif"></td>
<td><img src="../animations/model_07/02.gif"></td>
<td><img src="../animations/model_07/03.gif"></td>
<td><img src="../animations/model_07/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">
Character Model<a href="https://civitai.com/models/13237/genshen-impact-yoimiya">Yoimiya</a>
(with an initial reference image, see <a href="https://github.com/talesofai/AnimateDiff">WIP fork</a> for the extended implementation.)
<table>
<tr>
<td><img src="../animations/model_08/01.gif"></td>
<td><img src="../animations/model_08/02.gif"></td>
<td><img src="../animations/model_08/03.gif"></td>
<td><img src="../animations/model_08/04.gif"></td>
</tr>
</table>
<p style="margin-left: 2em; margin-top: -1em">
Character Model<a href="https://civitai.com/models/9850/paimon-genshin-impact">Paimon</a>;
Pose Model<a href="https://civitai.com/models/107295/or-holdingsign">Hold Sign</a></p>

Binary file not shown.

Before

Width:  |  Height:  |  Size: 370 KiB

BIN
animatediff/.DS_Store vendored Normal file

Binary file not shown.

View File

@@ -1,98 +0,0 @@
import os, io, csv, math, random
import numpy as np
from einops import rearrange
from decord import VideoReader
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from animatediff.utils.util import zero_rank_print
class WebVid10M(Dataset):
def __init__(
self,
csv_path, video_folder,
sample_size=256, sample_stride=4, sample_n_frames=16,
is_image=False,
):
zero_rank_print(f"loading annotations from {csv_path} ...")
with open(csv_path, 'r') as csvfile:
self.dataset = list(csv.DictReader(csvfile))
self.length = len(self.dataset)
zero_rank_print(f"data scale: {self.length}")
self.video_folder = video_folder
self.sample_stride = sample_stride
self.sample_n_frames = sample_n_frames
self.is_image = is_image
sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
self.pixel_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.Resize(sample_size[0]),
transforms.CenterCrop(sample_size),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])
def get_batch(self, idx):
video_dict = self.dataset[idx]
videoid, name, page_dir = video_dict['videoid'], video_dict['name'], video_dict['page_dir']
video_dir = os.path.join(self.video_folder, f"{videoid}.mp4")
video_reader = VideoReader(video_dir)
video_length = len(video_reader)
if not self.is_image:
clip_length = min(video_length, (self.sample_n_frames - 1) * self.sample_stride + 1)
start_idx = random.randint(0, video_length - clip_length)
batch_index = np.linspace(start_idx, start_idx + clip_length - 1, self.sample_n_frames, dtype=int)
else:
batch_index = [random.randint(0, video_length - 1)]
pixel_values = torch.from_numpy(video_reader.get_batch(batch_index).asnumpy()).permute(0, 3, 1, 2).contiguous()
pixel_values = pixel_values / 255.
del video_reader
if self.is_image:
pixel_values = pixel_values[0]
return pixel_values, name
def __len__(self):
return self.length
def __getitem__(self, idx):
while True:
try:
pixel_values, name = self.get_batch(idx)
break
except Exception as e:
idx = random.randint(0, self.length-1)
pixel_values = self.pixel_transforms(pixel_values)
sample = dict(pixel_values=pixel_values, text=name)
return sample
if __name__ == "__main__":
from animatediff.utils.util import save_videos_grid
dataset = WebVid10M(
csv_path="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv",
video_folder="/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val",
sample_size=256,
sample_stride=4, sample_n_frames=16,
is_image=True,
)
import pdb
pdb.set_trace()
dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16,)
for idx, batch in enumerate(dataloader):
print(batch["pixel_values"].shape, len(batch["text"]))
# for i in range(batch["pixel_values"].shape[0]):
# save_videos_grid(batch["pixel_values"][i:i+1].permute(0,2,1,3,4), os.path.join(".", f"{idx}-{i}.mp4"), rescale=True)

View File

@@ -1,300 +0,0 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward, AdaLayerNorm
from einops import rearrange, repeat
import pdb
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention=None,
unet_use_temporal_attention=None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
unet_use_cross_frame_attention=unet_use_cross_frame_attention,
unet_use_temporal_attention=unet_use_temporal_attention,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# Input
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
video_length=video_length
)
# Output
if not self.use_linear_projection:
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
unet_use_cross_frame_attention = None,
unet_use_temporal_attention = None,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.unet_use_cross_frame_attention = unet_use_cross_frame_attention
self.unet_use_temporal_attention = unet_use_temporal_attention
# SC-Attn
assert unet_use_cross_frame_attention is not None
if unet_use_cross_frame_attention:
self.attn1 = SparseCausalAttention2D(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=cross_attention_dim if only_cross_attention else None,
upcast_attention=upcast_attention,
)
else:
self.attn1 = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# Cross-Attn
if cross_attention_dim is not None:
self.attn2 = CrossAttention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
else:
self.norm2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
# Temp-Attn
assert unet_use_temporal_attention is not None
if unet_use_temporal_attention:
self.attn_temp = CrossAttention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
if not is_xformers_available():
print("Here is how to install it")
raise ModuleNotFoundError(
"Refer to https://github.com/facebookresearch/xformers for more information on how to install"
" xformers",
name="xformers",
)
elif not torch.cuda.is_available():
raise ValueError(
"torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
" available for GPU "
)
else:
try:
# Make sure we can run the memory efficient attention
_ = xformers.ops.memory_efficient_attention(
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
torch.randn((1, 2, 40), device="cuda"),
)
except Exception as e:
raise e
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
if self.attn2 is not None:
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
# self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
# SparseCausal-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
# if self.only_cross_attention:
# hidden_states = (
# self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
# )
# else:
# hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
# pdb.set_trace()
if self.unet_use_cross_frame_attention:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
else:
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.attn2 is not None:
# Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
# Temporal-Attention
if self.unet_use_temporal_attention:
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
norm_hidden_states = (
self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
)
hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states

View File

@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
from typing import Callable, List, Optional, Tuple, Union
import torch
import numpy as np
@@ -8,324 +8,418 @@ from torch import nn
import torchvision
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.modeling_utils import ModelMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.utils.import_utils import is_xformers_available
from diffusers.models.attention import CrossAttention, FeedForward
from diffusers.models.attention_processor import Attention
from diffusers.models.attention import FeedForward
from animatediff.utils.util import zero_rank_print
from einops import rearrange, repeat
import math
import math, pdb
import random
def zero_module(module):
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
# Zero out the parameters of a module and return it.
for p in module.parameters():
p.detach().zero_()
return module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
if is_xformers_available():
import xformers
import xformers.ops
else:
xformers = None
sample: torch.FloatTensor
def get_motion_module(
in_channels,
motion_module_type: str,
motion_module_kwargs: dict
in_channels,
motion_module_type: str,
motion_module_kwargs: dict
):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,)
else:
raise ValueError
if motion_module_type == "Vanilla":
return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)
elif motion_module_type == "Conv":
return ConvTemporalModule(in_channels=in_channels, **motion_module_kwargs)
else:
raise ValueError
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads = 8,
num_transformer_block = 2,
attention_block_types =( "Temporal_Self", "Temporal_Self" ),
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
temporal_attention_dim_div = 1,
zero_initialize = True,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def __init__(
self,
in_channels,
num_attention_heads = 8,
num_transformer_block = 2,
attention_block_types =( "Temporal_Self", ),
spatial_position_encoding = False,
temporal_position_encoding = True,
temporal_position_encoding_max_len = 32,
temporal_attention_dim_div = 1,
zero_initialize = True,
causal_temporal_attention = False,
causal_temporal_attention_mask_type = "",
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
spatial_position_encoding = spatial_position_encoding,
causal_temporal_attention=causal_temporal_attention,
causal_temporal_attention_mask_type=causal_temporal_attention_mask_type,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
def forward(self, input_tensor, temb=None, encoder_hidden_states=None, attention_mask=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 32,
spatial_position_encoding = False,
causal_temporal_attention = None,
causal_temporal_attention_mask_type = "",
):
super().__init__()
assert causal_temporal_attention is not None
self.causal_temporal_attention = causal_temporal_attention
num_layers,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
):
super().__init__()
assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "")
self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type
self.causal_temporal_attention_mask = None
self.spatial_position_encoding = spatial_position_encoding
inner_dim = num_attention_heads * attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
if spatial_position_encoding:
self.pos_encoder_2d = PositionalEncoding2D(inner_dim)
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def get_causal_temporal_attention_mask(self, hidden_states):
batch_size, sequence_length, dim = hidden_states.shape
if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (batch_size, sequence_length, sequence_length):
zero_rank_print(f"build attn mask of type {self.causal_temporal_attention_mask_type}")
if self.causal_temporal_attention_mask_type == "causal":
# 1. vanilla causal mask
mask = torch.tril(torch.ones(sequence_length, sequence_length))
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
elif self.causal_temporal_attention_mask_type == "2-seq":
# 2. 2-seq
mask = torch.zeros(sequence_length, sequence_length)
mask[:sequence_length // 2, :sequence_length // 2] = 1
mask[-sequence_length // 2:, -sequence_length // 2:] = 1
elif self.causal_temporal_attention_mask_type == "0-prev":
# attn to the previous frame
indices = torch.arange(sequence_length)
indices_prev = indices - 1
indices_prev[0] = 0
mask = torch.zeros(sequence_length, sequence_length)
mask[:, 0] = 1.
mask[indices, indices_prev] = 1.
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
elif self.causal_temporal_attention_mask_type == "0":
# only attn to first frame
mask = torch.zeros(sequence_length, sequence_length)
mask[:,0] = 1
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
elif self.causal_temporal_attention_mask_type == "wo-self":
indices = torch.arange(sequence_length)
mask = torch.ones(sequence_length, sequence_length)
mask[indices, indices] = 0
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
elif self.causal_temporal_attention_mask_type == "circle":
indices = torch.arange(sequence_length)
indices_prev = indices - 1
indices_prev[0] = 0
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
mask = torch.eye(sequence_length)
mask[indices, indices_prev] = 1
mask[0,-1] = 1
else: raise ValueError
# for sanity check
if dim == 320: zero_rank_print(mask)
# generate attention mask fron binary values
mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
mask = mask.unsqueeze(0)
mask = mask.repeat(batch_size, 1, 1)
self.causal_temporal_attention_mask = mask.to(hidden_states.device)
return self.causal_temporal_attention_mask
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
height, width = hidden_states.shape[-2:]
hidden_states = self.norm(hidden_states)
hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c")
hidden_states = self.proj_in(hidden_states)
if self.spatial_position_encoding:
video_length = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b h w) f c -> (b f) h w c", h=height, w=width)
pos_encoding = self.pos_encoder_2d(hidden_states)
pos_encoding = rearrange(pos_encoding, "(b f) h w c -> (b h w) f c", f = video_length)
hidden_states = rearrange(hidden_states, "(b f) h w c -> (b h w) f c", f=video_length)
attention_mask = self.get_causal_temporal_attention_mask(hidden_states) if self.causal_temporal_attention else attention_mask
# Transformer Blocks
for block in self.transformer_blocks:
if not self.spatial_position_encoding :
pos_encoding = None
hidden_states = block(hidden_states, pos_encoding=pos_encoding, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask)
hidden_states = self.proj_out(hidden_states)
hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width)
output = hidden_states + residual
# output = hidden_states
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
):
super().__init__()
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types = ( "Temporal_Self", "Temporal_Self", ),
dropout = 0.0,
norm_num_groups = 32,
cross_attention_dim = 768,
activation_fn = "geglu",
attention_bias = False,
upcast_attention = False,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 32,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
TemporalSelfAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
def forward(self, hidden_states, pos_encoding=None, encoder_hidden_states=None, attention_mask=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
if pos_encoding is not None:
hidden_states += pos_encoding
norm_hidden_states = norm(hidden_states)
hidden_states = attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
) + hidden_states
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
def get_emb(sin_inp):
"""
Gets a base embedding for one dimension with sin and cos intertwined
"""
emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
return torch.flatten(emb, -2, -1)
class PositionalEncoding2D(nn.Module):
def __init__(self, channels):
"""
:param channels: The last dimension of the tensor you want to apply pos emb to.
"""
super(PositionalEncoding2D, self).__init__()
self.org_channels = channels
channels = int(np.ceil(channels / 4) * 2)
self.channels = channels
inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
self.register_buffer("inv_freq", inv_freq)
self.register_buffer("cached_penc", None)
def forward(self, tensor):
"""
:param tensor: A 4d tensor of size (batch_size, x, y, ch)
:return: Positional Encoding Matrix of size (batch_size, x, y, ch)
"""
if len(tensor.shape) != 4:
raise RuntimeError("The input tensor has to be 4d!")
if self.cached_penc is not None and self.cached_penc.shape == tensor.shape:
return self.cached_penc
self.cached_penc = None
batch_size, x, y, orig_ch = tensor.shape
pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type())
pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type())
sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
emb_x = get_emb(sin_inp_x).unsqueeze(1)
emb_y = get_emb(sin_inp_y)
emb = torch.zeros((x, y, self.channels * 2), device=tensor.device).type(
tensor.type()
)
emb[:, :, : self.channels] = emb_x
emb[:, :, self.channels : 2 * self.channels] = emb_y
self.cached_penc = emb[None, :, :, :orig_ch].repeat(tensor.shape[0], 1, 1, 1)
return self.cached_penc
class PositionalEncoding(nn.Module):
def __init__(
self,
d_model,
dropout = 0.,
max_len = 24
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def __init__(
self,
d_model,
dropout = 0.,
max_len = 32,
):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer('pe', pe)
def forward(self, x):
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
def forward(self, x):
# if x.size(1) < 16:
# start_idx = random.randint(0, 12)
# else:
# start_idx = 0
x = x + self.pe[:, :x.size(1)]
return self.dropout(x)
class VersatileAttention(CrossAttention):
def __init__(
self,
attention_mode = None,
cross_frame_attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 24,
*args, **kwargs
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
class TemporalSelfAttention(Attention):
def __init__(
self,
attention_mode = None,
temporal_position_encoding = False,
temporal_position_encoding_max_len = 32,
*args, **kwargs
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs["cross_attention_dim"] is not None
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
dropout=0.,
max_len=temporal_position_encoding_max_len
) if (temporal_position_encoding and attention_mode == "Temporal") else None
self.pos_encoder = PositionalEncoding(
kwargs["query_dim"],
max_len=temporal_position_encoding_max_len
) if temporal_position_encoding else None
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
# disable motion module efficient xformers to avoid bad results, don't know why
# TODO: fix this bug
pass
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
batch_size, sequence_length, _ = hidden_states.shape
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
# The `Attention` class can call different attention processors / attention functions
# here we simply pass along all tensors to the selected processor class
# For standard processors that are defined here, `**cross_attention_kwargs` is empty
if self.attention_mode == "Temporal":
d = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
else:
raise NotImplementedError
# add position encoding
hidden_states = self.pos_encoder(hidden_states)
encoder_hidden_states = encoder_hidden_states
if hasattr(self.processor, "__call__"):
return self.processor.__call__(
self,
hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
if self.added_kv_proj_dim is not None:
raise NotImplementedError
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
return hidden_states
else:
return self.processor(
self,
hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
**cross_attention_kwargs,
)

View File

@@ -1,217 +0,0 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class InflatedGroupNorm(nn.GroupNorm):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
# Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# if self.use_conv:
# if self.name == "conv":
# hidden_states = self.conv(hidden_states)
# else:
# hidden_states = self.Conv2d_0(hidden_states)
hidden_states = self.conv(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
assert hidden_states.shape[1] == self.channels
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=None,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
assert use_inflated_groupnorm != None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,959 +0,0 @@
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Conversion script for the Stable Diffusion checkpoints."""
import re
from io import BytesIO
from typing import Optional
import requests
import torch
from transformers import (
AutoFeatureExtractor,
BertTokenizerFast,
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionConfig,
CLIPVisionModelWithProjection,
)
from diffusers.models import (
AutoencoderKL,
PriorTransformer,
UNet2DConditionModel,
)
from diffusers.schedulers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
HeunDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
UnCLIPScheduler,
)
from diffusers.utils.import_utils import BACKENDS_MAPPING
def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
"""
if n_shave_prefix_segments >= 0:
return ".".join(path.split(".")[n_shave_prefix_segments:])
else:
return ".".join(path.split(".")[:n_shave_prefix_segments])
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item.replace("in_layers.0", "norm1")
new_item = new_item.replace("in_layers.2", "conv1")
new_item = new_item.replace("out_layers.0", "norm2")
new_item = new_item.replace("out_layers.3", "conv2")
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
new_item = new_item.replace("skip_connection", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside resnets to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
"""
Updates paths inside attentions to the new naming scheme (local renaming)
"""
mapping = []
for old_item in old_list:
new_item = old_item
new_item = new_item.replace("norm.weight", "group_norm.weight")
new_item = new_item.replace("norm.bias", "group_norm.bias")
new_item = new_item.replace("q.weight", "query.weight")
new_item = new_item.replace("q.bias", "query.bias")
new_item = new_item.replace("k.weight", "key.weight")
new_item = new_item.replace("k.bias", "key.bias")
new_item = new_item.replace("v.weight", "value.weight")
new_item = new_item.replace("v.bias", "value.bias")
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
mapping.append({"old": old_item, "new": new_item})
return mapping
def assign_to_checkpoint(
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
):
"""
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
attention layers, and takes into account additional replacements that may arise.
Assigns the weights to the new checkpoint.
"""
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
# Splits the attention layers into three variables.
if attention_paths_to_split is not None:
for path, path_map in attention_paths_to_split.items():
old_tensor = old_checkpoint[path]
channels = old_tensor.shape[0] // 3
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
query, key, value = old_tensor.split(channels // num_heads, dim=1)
checkpoint[path_map["query"]] = query.reshape(target_shape)
checkpoint[path_map["key"]] = key.reshape(target_shape)
checkpoint[path_map["value"]] = value.reshape(target_shape)
for path in paths:
new_path = path["new"]
# These have already been assigned
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
continue
# Global renaming happens here
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
if additional_replacements is not None:
for replacement in additional_replacements:
new_path = new_path.replace(replacement["old"], replacement["new"])
# proj_attn.weight has to be converted from conv 1D to linear
if "proj_attn.weight" in new_path:
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
else:
checkpoint[new_path] = old_checkpoint[path["old"]]
def conv_attn_to_linear(checkpoint):
keys = list(checkpoint.keys())
attn_keys = ["query.weight", "key.weight", "value.weight"]
for key in keys:
if ".".join(key.split(".")[-2:]) in attn_keys:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0, 0]
elif "proj_attn.weight" in key:
if checkpoint[key].ndim > 2:
checkpoint[key] = checkpoint[key][:, :, 0]
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
if controlnet:
unet_params = original_config.model.params.control_stage_config.params
else:
unet_params = original_config.model.params.unet_config.params
vae_params = original_config.model.params.first_stage_config.params.ddconfig
block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
down_block_types = []
resolution = 1
for i in range(len(block_out_channels)):
block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
down_block_types.append(block_type)
if i != len(block_out_channels) - 1:
resolution *= 2
up_block_types = []
for i in range(len(block_out_channels)):
block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
up_block_types.append(block_type)
resolution //= 2
vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
head_dim = unet_params.num_heads if "num_heads" in unet_params else None
use_linear_projection = (
unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
)
if use_linear_projection:
# stable diffusion 2-base-512 and 2-768
if head_dim is None:
head_dim = [5, 10, 20, 20]
class_embed_type = None
projection_class_embeddings_input_dim = None
if "num_classes" in unet_params:
if unet_params.num_classes == "sequential":
class_embed_type = "projection"
assert "adm_in_channels" in unet_params
projection_class_embeddings_input_dim = unet_params.adm_in_channels
else:
raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
config = {
"sample_size": image_size // vae_scale_factor,
"in_channels": unet_params.in_channels,
"down_block_types": tuple(down_block_types),
"block_out_channels": tuple(block_out_channels),
"layers_per_block": unet_params.num_res_blocks,
"cross_attention_dim": unet_params.context_dim,
"attention_head_dim": head_dim,
"use_linear_projection": use_linear_projection,
"class_embed_type": class_embed_type,
"projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
}
if not controlnet:
config["out_channels"] = unet_params.out_channels
config["up_block_types"] = tuple(up_block_types)
return config
def create_vae_diffusers_config(original_config, image_size: int):
"""
Creates a config for the diffusers based on the config of the LDM model.
"""
vae_params = original_config.model.params.first_stage_config.params.ddconfig
_ = original_config.model.params.first_stage_config.params.embed_dim
block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
config = {
"sample_size": image_size,
"in_channels": vae_params.in_channels,
"out_channels": vae_params.out_ch,
"down_block_types": tuple(down_block_types),
"up_block_types": tuple(up_block_types),
"block_out_channels": tuple(block_out_channels),
"latent_channels": vae_params.z_channels,
"layers_per_block": vae_params.num_res_blocks,
}
return config
def create_diffusers_schedular(original_config):
schedular = DDIMScheduler(
num_train_timesteps=original_config.model.params.timesteps,
beta_start=original_config.model.params.linear_start,
beta_end=original_config.model.params.linear_end,
beta_schedule="scaled_linear",
)
return schedular
def create_ldm_bert_config(original_config):
bert_params = original_config.model.parms.cond_stage_config.params
config = LDMBertConfig(
d_model=bert_params.n_embed,
encoder_layers=bert_params.n_layer,
encoder_ffn_dim=bert_params.n_embed * 4,
)
return config
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
"""
Takes a state dict and a config, and returns a converted checkpoint.
"""
# extract state_dict for UNet
unet_state_dict = {}
keys = list(checkpoint.keys())
if controlnet:
unet_key = "control_model."
else:
unet_key = "model.diffusion_model."
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
print(
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
)
for key in keys:
if key.startswith("model.diffusion_model"):
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
else:
if sum(k.startswith("model_ema") for k in keys) > 100:
print(
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
)
for key in keys:
if key.startswith(unet_key):
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
new_checkpoint = {}
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
if config["class_embed_type"] is None:
# No parameters to port
...
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
else:
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
if not controlnet:
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
# Retrieves the keys for the input blocks only
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
input_blocks = {
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
for layer_id in range(num_input_blocks)
}
# Retrieves the keys for the middle blocks only
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
middle_blocks = {
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
for layer_id in range(num_middle_blocks)
}
# Retrieves the keys for the output blocks only
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
output_blocks = {
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
for layer_id in range(num_output_blocks)
}
for i in range(1, num_input_blocks):
block_id = (i - 1) // (config["layers_per_block"] + 1)
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
resnets = [
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
]
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.weight"
)
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
f"input_blocks.{i}.0.op.bias"
)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
resnet_0 = middle_blocks[0]
attentions = middle_blocks[1]
resnet_1 = middle_blocks[2]
resnet_0_paths = renew_resnet_paths(resnet_0)
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
resnet_1_paths = renew_resnet_paths(resnet_1)
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
attentions_paths = renew_attention_paths(attentions)
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
for i in range(num_output_blocks):
block_id = i // (config["layers_per_block"] + 1)
layer_in_block_id = i % (config["layers_per_block"] + 1)
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
output_block_list = {}
for layer in output_block_layers:
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
if layer_id in output_block_list:
output_block_list[layer_id].append(layer_name)
else:
output_block_list[layer_id] = [layer_name]
if len(output_block_list) > 1:
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
resnet_0_paths = renew_resnet_paths(resnets)
paths = renew_resnet_paths(resnets)
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
if ["conv.bias", "conv.weight"] in output_block_list.values():
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.weight"
]
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
f"output_blocks.{i}.{index}.conv.bias"
]
# Clear attentions as they have been attributed above.
if len(attentions) == 2:
attentions = []
if len(attentions):
paths = renew_attention_paths(attentions)
meta_path = {
"old": f"output_blocks.{i}.1",
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
}
assign_to_checkpoint(
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
)
else:
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
for path in resnet_0_paths:
old_path = ".".join(["output_blocks", str(i), path["old"]])
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
new_checkpoint[new_path] = unet_state_dict[old_path]
if controlnet:
# conditioning embedding
orig_index = 0
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
orig_index += 2
diffusers_index = 0
while diffusers_index < 6:
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
diffusers_index += 1
orig_index += 2
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.weight"
)
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
f"input_hint_block.{orig_index}.bias"
)
# down blocks
for i in range(num_input_blocks):
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
# mid block
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
return new_checkpoint
def convert_ldm_vae_checkpoint(checkpoint, config):
# extract state dict for VAE
vae_state_dict = {}
vae_key = "first_stage_model."
keys = list(checkpoint.keys())
for key in keys:
if key.startswith(vae_key):
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
new_checkpoint = {}
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
# Retrieves the keys for the encoder down blocks only
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
down_blocks = {
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
}
# Retrieves the keys for the decoder up blocks only
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
up_blocks = {
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
}
for i in range(num_down_blocks):
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.weight"
)
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
f"encoder.down.{i}.downsample.conv.bias"
)
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
for i in range(num_up_blocks):
block_id = num_up_blocks - 1 - i
resnets = [
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
]
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.weight"
]
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
f"decoder.up.{block_id}.upsample.conv.bias"
]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
num_mid_res_blocks = 2
for i in range(1, num_mid_res_blocks + 1):
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
paths = renew_vae_resnet_paths(resnets)
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
paths = renew_vae_attention_paths(mid_attentions)
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
conv_attn_to_linear(new_checkpoint)
return new_checkpoint
def convert_ldm_bert_checkpoint(checkpoint, config):
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
def _copy_linear(hf_linear, pt_linear):
hf_linear.weight = pt_linear.weight
hf_linear.bias = pt_linear.bias
def _copy_layer(hf_layer, pt_layer):
# copy layer norms
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
# copy attn
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
# copy MLP
pt_mlp = pt_layer[1][1]
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
def _copy_layers(hf_layers, pt_layers):
for i, hf_layer in enumerate(hf_layers):
if i != 0:
i += i
pt_layer = pt_layers[i : i + 2]
_copy_layer(hf_layer, pt_layer)
hf_model = LDMBertModel(config).eval()
# copy embeds
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
# copy layer norm
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
# copy hidden layers
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
return hf_model
def convert_ldm_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
text_model.load_state_dict(text_model_dict)
return text_model
textenc_conversion_lst = [
("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
textenc_transformer_conversion_lst = [
# (stable-diffusion, HF Diffusers)
("resblocks.", "text_model.encoder.layers."),
("ln_1", "layer_norm1"),
("ln_2", "layer_norm2"),
(".c_fc.", ".fc1."),
(".c_proj.", ".fc2."),
(".attn", ".self_attn"),
("ln_final.", "transformer.text_model.final_layer_norm."),
("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
def convert_paint_by_example_checkpoint(checkpoint):
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
model = PaintByExampleImageEncoder(config)
keys = list(checkpoint.keys())
text_model_dict = {}
for key in keys:
if key.startswith("cond_stage_model.transformer"):
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
# load clip vision
model.model.load_state_dict(text_model_dict)
# load mapper
keys_mapper = {
k[len("cond_stage_model.mapper.res") :]: v
for k, v in checkpoint.items()
if k.startswith("cond_stage_model.mapper")
}
MAPPING = {
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
"attn.c_proj": ["attn1.to_out.0"],
"ln_1": ["norm1"],
"ln_2": ["norm3"],
"mlp.c_fc": ["ff.net.0.proj"],
"mlp.c_proj": ["ff.net.2"],
}
mapped_weights = {}
for key, value in keys_mapper.items():
prefix = key[: len("blocks.i")]
suffix = key.split(prefix)[-1].split(".")[-1]
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
mapped_names = MAPPING[name]
num_splits = len(mapped_names)
for i, mapped_name in enumerate(mapped_names):
new_name = ".".join([prefix, mapped_name, suffix])
shape = value.shape[0] // num_splits
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
model.mapper.load_state_dict(mapped_weights)
# load final layer norm
model.final_layer_norm.load_state_dict(
{
"bias": checkpoint["cond_stage_model.final_ln.bias"],
"weight": checkpoint["cond_stage_model.final_ln.weight"],
}
)
# load final proj
model.proj_out.load_state_dict(
{
"bias": checkpoint["proj_out.bias"],
"weight": checkpoint["proj_out.weight"],
}
)
# load uncond vector
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
return model
def convert_open_clip_checkpoint(checkpoint):
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
keys = list(checkpoint.keys())
text_model_dict = {}
if "cond_stage_model.model.text_projection" in checkpoint:
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
else:
d_model = 1024
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
for key in keys:
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
continue
if key in textenc_conversion_map:
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
if key.startswith("cond_stage_model.model.transformer."):
new_key = key[len("cond_stage_model.model.transformer.") :]
if new_key.endswith(".in_proj_weight"):
new_key = new_key[: -len(".in_proj_weight")]
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
elif new_key.endswith(".in_proj_bias"):
new_key = new_key[: -len(".in_proj_bias")]
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
else:
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
text_model_dict[new_key] = checkpoint[key]
text_model.load_state_dict(text_model_dict)
return text_model
def stable_unclip_image_encoder(original_config):
"""
Returns the image processor and clip image encoder for the img2img unclip pipeline.
We currently know of two types of stable unclip models which separately use the clip and the openclip image
encoders.
"""
image_embedder_config = original_config.model.params.embedder_config
sd_clip_image_embedder_class = image_embedder_config.target
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
if sd_clip_image_embedder_class == "ClipImageEmbedder":
clip_model_name = image_embedder_config.params.model
if clip_model_name == "ViT-L/14":
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
else:
raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
feature_extractor = CLIPImageProcessor()
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
else:
raise NotImplementedError(
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
)
return feature_extractor, image_encoder
def stable_unclip_image_noising_components(
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
):
"""
Returns the noising components for the img2img and txt2img unclip pipelines.
Converts the stability noise augmentor into
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
2. a `DDPMScheduler` for holding the noise schedule
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
"""
noise_aug_config = original_config.model.params.noise_aug_config
noise_aug_class = noise_aug_config.target
noise_aug_class = noise_aug_class.split(".")[-1]
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
noise_aug_config = noise_aug_config.params
embedding_dim = noise_aug_config.timestep_dim
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
if "clip_stats_path" in noise_aug_config:
if clip_stats_path is None:
raise ValueError("This stable unclip config requires a `clip_stats_path`")
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
clip_mean = clip_mean[None, :]
clip_std = clip_std[None, :]
clip_stats_state_dict = {
"mean": clip_mean,
"std": clip_std,
}
image_normalizer.load_state_dict(clip_stats_state_dict)
else:
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
return image_normalizer, image_noising_scheduler
def convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
ctrlnet_config["upcast_attention"] = upcast_attention
ctrlnet_config.pop("sample_size")
controlnet_model = ControlNetModel(**ctrlnet_config)
converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
)
controlnet_model.load_state_dict(converted_ctrl_checkpoint)
return controlnet_model

View File

@@ -5,7 +5,7 @@
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
@@ -23,132 +23,112 @@ from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline
import pdb
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
# directly update weight in diffusers model
for key in state_dict:
# only process lora down key
if "up." in key: continue
up_key = key.replace(".down.", ".up.")
model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
model_key = model_key.replace("to_out.", "to_out.0.")
layer_infos = model_key.split(".")[:-1]
curr_layer = pipeline.unet
while len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
curr_layer = curr_layer.__getattr__(temp_name)
weight_down = state_dict[key]
weight_up = state_dict[up_key]
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
return pipeline
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
# load base model
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
# load base model
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
# load LoRA weight from .safetensors
# state_dict = load_file(checkpoint_path)
# load LoRA weight from .safetensors
# state_dict = load_file(checkpoint_path)
visited = []
visited = []
# directly update weight in diffusers model
for lora_name in state_dict:
# it is suggested to print out the key, it usually will be something like below
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
# directly update weight in diffusers model
for key in state_dict:
# it is suggested to print out the key, it usually will be something like below
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
# as we have set the alpha beforehand, so just skip
if ".alpha" in lora_name or lora_name in visited:
continue
# as we have set the alpha beforehand, so just skip
if ".alpha" in key or key in visited:
continue
if "te" in lora_name:
if "lora_te1" in key:
LORA_PREFIX_TEXT_ENCODER = "lora_te1"
elif "lora_te2" in key:
LORA_PREFIX_TEXT_ENCODER = "lora_te2"
else:
LORA_PREFIX_TEXT_ENCODER = "lora_te"
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
if "text" in key:
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
curr_layer = pipeline.text_encoder
else:
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
curr_layer = pipeline.unet
else:
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
curr_layer = pipeline.unet
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
# find the target layer
temp_name = layer_infos.pop(0)
while len(layer_infos) > -1:
try:
curr_layer = curr_layer.__getattr__(temp_name)
if len(layer_infos) > 0:
temp_name = layer_infos.pop(0)
elif len(layer_infos) == 0:
break
except Exception:
if len(temp_name) > 0:
temp_name += "_" + layer_infos.pop(0)
else:
temp_name = layer_infos.pop(0)
pair_keys = []
if "lora_down" in key:
pair_keys.append(key.replace("lora_down", "lora_up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora_up", "lora_down"))
pair_keys = []
if "lora.down" in key:
pair_keys.append(key.replace("lora.down", "lora.up"))
pair_keys.append(key)
else:
pair_keys.append(key)
pair_keys.append(key.replace("lora.up", "lora.down"))
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
# update weight
if len(state_dict[pair_keys[0]].shape) == 4:
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
else:
weight_up = state_dict[pair_keys[0]].to(torch.float32)
weight_down = state_dict[pair_keys[1]].to(torch.float32)
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
# update visited list
for item in pair_keys:
visited.append(item)
# update visited list
for item in pair_keys:
visited.append(item)
return pipeline
return pipeline
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument(
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
)
parser.add_argument(
"--lora_prefix_text_encoder",
default="lora_te",
type=str,
help="The prefix of text encoder weight in safetensors",
)
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
parser.add_argument(
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
)
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
parser.add_argument(
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
)
parser.add_argument(
"--lora_prefix_text_encoder",
default="lora_te",
type=str,
help="The prefix of text encoder weight in safetensors",
)
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
parser.add_argument(
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
)
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
args = parser.parse_args()
args = parser.parse_args()
base_model_path = args.base_model_path
checkpoint_path = args.checkpoint_path
dump_path = args.dump_path
lora_prefix_unet = args.lora_prefix_unet
lora_prefix_text_encoder = args.lora_prefix_text_encoder
alpha = args.alpha
base_model_path = args.base_model_path
checkpoint_path = args.checkpoint_path
dump_path = args.dump_path
lora_prefix_unet = args.lora_prefix_unet
lora_prefix_text_encoder = args.lora_prefix_text_encoder
alpha = args.alpha
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
pipe = convert(base_model_path, checkpoint_path, lora_prefix_unet, lora_prefix_text_encoder, alpha)
pipe = pipe.to(args.device)
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
pipe = pipe.to(args.device)
pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,27 +0,0 @@
unet_additional_kwargs:
use_inflated_groupnorm: true
unet_use_cross_frame_attention: false
unet_use_temporal_attention: false
use_motion_module: true
motion_module_resolutions:
- 1
- 2
- 4
- 8
motion_module_mid_block: true
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 32
temporal_attention_dim_div: 1
noise_scheduler_kwargs:
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "linear"

View File

@@ -1,6 +1,4 @@
unet_additional_kwargs:
unet_use_cross_frame_attention: false
unet_use_temporal_attention: false
use_motion_module: true
motion_module_resolutions:
- 1
@@ -8,7 +6,6 @@ unet_additional_kwargs:
- 4
- 8
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
@@ -17,10 +14,14 @@ unet_additional_kwargs:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_position_encoding_max_len: 32
temporal_attention_dim_div: 1
noise_scheduler_kwargs:
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "linear"
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.020
beta_schedule: "scaled_linear"

View File

@@ -1,23 +0,0 @@
ToonYou:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/toonyou_beta3.safetensors"
lora_model_path: ""
seed: [10788741199826055526, 6520604954829636163, 6519455744612555650, 16372571278361863751]
steps: 25
guidance_scale: 7.5
prompt:
- "best quality, masterpiece, 1girl, looking at viewer, blurry background, upper body, contemporary, dress"
- "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes,"
- "best quality, masterpiece, 1boy, formal, abstract, looking at viewer, masculine, marble pattern"
- "best quality, masterpiece, 1girl, cloudy sky, dandelion, contrapposto, alternate hairstyle,"
n_prompt:
- ""
- "badhandv4,easynegative,ng_deepnegative_v1_75t,verybadimagenegative_v1.3, bad-artist, bad_prompt_version2-neg, teeth"
- ""
- ""

View File

@@ -0,0 +1,14 @@
motion_module_path: Path to Motion Module
seed: -1
guidance_scale: 8.5
steps: 100
prompt:
- "A panda standing on a surfboard in the ocean in sunset, 4k, high resolution.Realistic, Cinematic, high resolution"
- "A disoriented astronaut, lost in a galaxy of swirling colors, floating in zero gravity, grasping at memories, poignant loneliness, stunning realism, cosmic chaos, emotional depth, 12K, hyperrealism, unforgettable, mixed media, celestial, dark, introspective"
n_prompt:
- ""
- ""

View File

@@ -0,0 +1,17 @@
base_model_path: Path to DynaVision
motion_module_path: Path to Motion Module
seed: -1
guidance_scale: 8.5
steps: 100
prompt:
- "cinematic color grading lighting vintage realistic film grain scratches celluloid analog cool shadows warm highlights soft focus actor directed cinematography technicolor , confused, looking around scared , + / A lanky blonde haired man on vacation enjoying the local party scene in Brisbane at dawn"
- "Pixel Art pixelated pixel pixel , extremely content happy smile , + / A fit pink haired woman at night in the city, holding a trick or treat bag, wearing a kitty ears halloween costume, (big eyes:1.3) by Dreamworks Studios"
- "anime artwork of [water|a galaxy|thunderstorm] inside a bottle"
n_prompt:
- ""
- ""
- "hand"

View File

@@ -1,23 +0,0 @@
Lyriel:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/lyriel_v16.safetensors"
lora_model_path: ""
seed: [10917152860782582783, 6399018107401806238, 15875751942533906793, 6653196880059936551]
steps: 25
guidance_scale: 7.5
prompt:
- "dark shot, epic realistic, portrait of halo, sunglasses, blue eyes, tartan scarf, white hair by atey ghailan, by greg rutkowski, by greg tocchini, by james gilleard, by joe fenton, by kaethe butcher, gradient yellow, black, brown and magenta color scheme, grunge aesthetic!!! graffiti tag wall background, art by greg rutkowski and artgerm, soft cinematic light, adobe lightroom, photolab, hdr, intricate, highly detailed, depth of field, faded, neutral colors, hdr, muted colors, hyperdetailed, artstation, cinematic, warm lights, dramatic light, intricate details, complex background, rutkowski, teal and orange"
- "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal"
- "dark theme, medieval portrait of a man sharp features, grim, cold stare, dark colors, Volumetric lighting, baroque oil painting by Greg Rutkowski, Artgerm, WLOP, Alphonse Mucha dynamic lighting hyperdetailed intricately detailed, hdr, muted colors, complex background, hyperrealism, hyperdetailed, amandine van ray"
- "As I have gone alone in there and with my treasures bold, I can keep my secret where and hint of riches new and old. Begin it where warm waters halt and take it in a canyon down, not far but too far to walk, put in below the home of brown."
n_prompt:
- "3d, cartoon, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, young, loli, elf, 3d, illustration"
- "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular"
- "dof, grayscale, black and white, bw, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular,badhandsv5-neg, By bad artist -neg 1, monochrome"
- "holding an item, cowboy, hat, cartoon, 3d, disfigured, bad art, deformed,extra limbs,close up,b&w, wierd colors, blurry, duplicate, morbid, mutilated, [out of frame], extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, ugly, blurry, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, out of frame, ugly, extra limbs, bad anatomy, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, mutated hands, fused fingers, too many fingers, long neck, Photoshop, video game, ugly, tiling, poorly drawn hands, poorly drawn feet, poorly drawn face, out of frame, mutation, mutated, extra limbs, extra legs, extra arms, disfigured, deformed, cross-eye, body out of frame, blurry, bad art, bad anatomy, 3d render"

View File

@@ -0,0 +1,16 @@
ckpt_path: Path to DreamShaper
motion_module_path: Path to Motion Module
seed: -1
guidance_scale: 10
steps: 100
prompt:
- "photo of a supercar, 8k uhd, high quality, road, sunset, motion blur, depth blur, cinematic, filmic image 4k, 8k. Natural sunlight, vibrant color, reflections"
- "night view, sea beach landscape, with a lighthouse, trending on artstation, by Noah Bradley, highly detailed, high quality, 4k HDR, path tracking, calm landscape, high consistency, soft lighting, <lora:xl_more_art-full_v1:0.3> in Serene grids, transcendent fields, quiet repetitions, sublime reduction in Brooding landscapes, epic scale, German myth, layered symbolic density"
n_prompt:
- "embedding:BadDream, embedding:UnrealisticDream"
- "(worst quality), (low quality), (normal quality), lowres, blurry, hand, signature, normal quality"

View File

@@ -1,23 +0,0 @@
RcnzCartoon:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/rcnzCartoon3d_v10.safetensors"
lora_model_path: ""
seed: [16931037867122267877, 2094308009433392066, 4292543217695451092, 15572665120852309890]
steps: 25
guidance_scale: 7.5
prompt:
- "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded"
- "close up Portrait photo of muscular bearded guy in a worn mech suit, light bokeh, intricate, steel metal [rust], elegant, sharp focus, photo by greg rutkowski, soft lighting, vibrant colors, masterpiece, streets, detailed face"
- "absurdres, photorealistic, masterpiece, a 30 year old man with gold framed, aviator reading glasses and a black hooded jacket and a beard, professional photo, a character portrait, altermodern, detailed eyes, detailed lips, detailed face, grey eyes"
- "a golden labrador, warm vibrant colours, natural lighting, dappled lighting, diffused lighting, absurdres, highres,k, uhd, hdr, rtx, unreal, octane render, RAW photo, photorealistic, global illumination, subsurface scattering"
n_prompt:
- "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
- "nude, cross eyed, tongue, open mouth, inside, 3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, red eyes, muscular"
- "easynegative, cartoon, anime, sketches, necklace, earrings worst quality, low quality, normal quality, bad anatomy, bad hands, shiny skin, error, missing fingers, extra digit, fewer digits, jpeg artifacts, signature, watermark, username, blurry, chubby, anorectic, bad eyes, old, wrinkled skin, red skin, photograph By bad artist -neg, big eyes, muscular face,"
- "beard, EasyNegative, lowres, chromatic aberration, depth of field, motion blur, blurry, bokeh, bad quality, worst quality, multiple arms, badhand"

View File

@@ -0,0 +1,20 @@
ckpt_path: Path to DeepBlud model
motion_module_path: Path to Motion Module
seed: -1
guidance_scale: 7
steps: 100
prompt:
- "highres,best quality,natural, A photo of a mummified Miffy, the iconic white bunny character, lying peacefully on a velvet cushion inside an intricately decorated glass display case, with its fur preserved in a pristine condition, its eyes closed peacefully, and its little paws gently crossed over each other, evoking a sense of tranquility and reverence.,cinematic photo official art, 8k wallpaper,ultra detailed, aesthetic quality,photorealistic,entangle,dynamic angle,the most beautiful form of chaos,elegant,a brutalist designed,vivid colours,romanticism,atmospheric . 35mm photograph, film, bokeh, professional, 4k, highly detailed, skin detail realistic, ultra realistic,Perspective"
- "highres,best quality,natural, Cat pressing the Delete key on the keyboard,cinematic photo official art, 8k wallpaper,ultra detailed, aesthetic quality,photorealistic,entangle,dynamic angle,the most beautiful form of chaos,elegant,a brutalist designed,vivid colours,romanticism,atmospheric . 35mm photograph, film, bokeh, professional, 4k, highly detailed, skin detail realistic, ultra realistic,Perspective"
- "highres,best quality,natural, A psychedelic world spreads out below, Witch kawaii girl floating in the sky Wearing a witch hat Frilled fantasy costume Blonde straight hair,cinematic photo official art,unity 8k wallpaper,ultra detailed,aesthetic,masterpiece,best quality,photorealistic,entangle,mandala,tangle,entangle,cstasy of flower,dynamic angle,the most beautiful form of chaos,elegant,a brutalist designed,vivid colours,romanticism,atmospheric . 35mm photograph, film, bokeh, professional, 4k, highly detailed, skin detail realistic, ultra realistic,"
n_prompt:
- "bad quality,worst quality"
- "bad quality,worst quality"
- "bad quality,worst quality"

View File

@@ -1,23 +0,0 @@
MajicMix:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v5Preview.safetensors"
lora_model_path: ""
seed: [1572448948722921032, 1099474677988590681, 6488833139725635347, 18339859844376517918]
steps: 25
guidance_scale: 7.5
prompt:
- "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic"
- "best quality, masterpiece, photorealistic, 1boy, 50 years old beard, dramatic lighting"
- "best quality, masterpiece, photorealistic, 1girl, light smile, shirt with collars, waist up, dramatic lighting, from below"
- "male, man, beard, bodybuilder, skinhead,cold face, tough guy, cowboyshot, tattoo, french windows, luxury hotel masterpiece, best quality, photorealistic"
n_prompt:
- "ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles"
- "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
- "nsfw, ng_deepnegative_v1_75t,badhandv4, worst quality, low quality, normal quality, lowres,watermark, monochrome"
- "nude, nsfw, ng_deepnegative_v1_75t, badhandv4, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, monochrome, grayscale watermark, moles, people"

View File

@@ -1,23 +0,0 @@
RealisticVision:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: [5658137986800322009, 12099779162349365895, 10499524853910852697, 16768009035333711932]
steps: 25
guidance_scale: 7.5
prompt:
- "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
- "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
- "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
n_prompt:
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"

View File

@@ -1,21 +0,0 @@
Tusun:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/moonfilm_reality20.safetensors"
lora_model_path: "models/DreamBooth_LoRA/TUSUN.safetensors"
lora_alpha: 0.6
seed: [10154078483724687116, 2664393535095473805, 4231566096207622938, 1713349740448094493]
steps: 25
guidance_scale: 7.5
prompt:
- "tusuncub with its mouth open, blurry, open mouth, fangs, photo background, looking at viewer, tongue, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
- "cute tusun with a blurry background, black background, simple background, signature, face, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
- "cut tusuncub walking in the snow, blurry, looking at viewer, depth of field, blurry background, full body, solo, cute and lovely, Beautiful and realistic eye details, perfect anatomy, Nonsense, pure background, Centered-Shot, realistic photo, photograph, 4k, hyper detailed, DSLR, 24 Megapixels, 8mm Lens, Full Frame, film grain, Global Illumination, studio Lighting, Award Winning Photography, diffuse reflection, ray tracing"
- "character design, cyberpunk tusun kitten wearing astronaut suit, sci-fic, realistic eye color and details, fluffy, big head, science fiction, communist ideology, Cyborg, fantasy, intense angle, soft lighting, photograph, 4k, hyper detailed, portrait wallpaper, realistic, photo-realistic, DSLR, 24 Megapixels, Full Frame, vibrant details, octane render, finely detail, best quality, incredibly absurdres, robotic parts, rim light, vibrant details, luxurious cyberpunk, hyperrealistic, cable electric wires, microchip, full body"
n_prompt:
- "worst quality, low quality, deformed, distorted, disfigured, bad eyes, bad anatomy, disconnected limbs, wrong body proportions, low quality, worst quality, text, watermark, signatre, logo, illustration, painting, cartoons, ugly, easy_negative"

View File

@@ -1,24 +0,0 @@
FilmVelvia:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/majicmixRealistic_v4.safetensors"
lora_model_path: "models/DreamBooth_LoRA/FilmVelvia2.safetensors"
lora_alpha: 0.6
seed: [358675358833372813, 3519455280971923743, 11684545350557985081, 8696855302100399877]
steps: 25
guidance_scale: 7.5
prompt:
- "a woman standing on the side of a road at night,girl, long hair, motor vehicle, car, looking at viewer, ground vehicle, night, hands in pockets, blurry background, coat, black hair, parted lips, bokeh, jacket, brown hair, outdoors, red lips, upper body, artist name"
- ", dark shot,0mm, portrait quality of a arab man worker,boy, wasteland that stands out vividly against the background of the desert, barren landscape, closeup, moles skin, soft light, sharp, exposure blend, medium shot, bokeh, hdr, high contrast, cinematic, teal and orange5, muted colors, dim colors, soothing tones, low saturation, hyperdetailed, noir"
- "fashion photography portrait of 1girl, offshoulder, fluffy short hair, soft light, rim light, beautiful shadow, low key, photorealistic, raw photo, natural skin texture, realistic eye and face details, hyperrealism, ultra high res, 4K, Best quality, masterpiece, necklace, cleavage, in the dark"
- "In this lighthearted portrait, a woman is dressed as a fierce warrior, armed with an arsenal of paintbrushes and palette knives. Her war paint is composed of thick, vibrant strokes of color, and her armor is made of paint tubes and paint-splattered canvases. She stands victoriously atop a mountain of conquered blank canvases, with a beautiful, colorful landscape behind her, symbolizing the power of art and creativity. bust Portrait, close-up, Bright and transparent scene lighting, "
n_prompt:
- "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
- "cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
- "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"
- "wrong white balance, dark, cartoon, anime, sketches,worst quality, low quality, deformed, distorted, disfigured, bad eyes, wrong lips, weird mouth, bad teeth, mutated hands and fingers, bad anatomy, wrong anatomy, amputation, extra limb, missing limb, floating limbs, disconnected limbs, mutation, ugly, disgusting, bad_pictures, negative_hand-neg"

View File

@@ -1,21 +0,0 @@
GhibliBackground:
motion_module:
- "models/Motion_Module/mm_sd_v14.ckpt"
- "models/Motion_Module/mm_sd_v15.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/CounterfeitV30_25.safetensors"
lora_model_path: "models/DreamBooth_LoRA/lora_Ghibli_n3.safetensors"
lora_alpha: 1.0
seed: [8775748474469046618, 5893874876080607656, 11911465742147695752, 12437784838692000640]
steps: 25
guidance_scale: 7.5
prompt:
- "best quality,single build,architecture, blue_sky, building,cloudy_sky, day, fantasy, fence, field, house, build,architecture,landscape, moss, outdoors, overgrown, path, river, road, rock, scenery, sky, sword, tower, tree, waterfall"
- "black_border, building, city, day, fantasy, ice, landscape, letterboxed, mountain, ocean, outdoors, planet, scenery, ship, snow, snowing, water, watercraft, waterfall, winter"
- ",mysterious sea area, fantasy,build,concept"
- "Tomb Raider,Scenography,Old building"
n_prompt:
- "easynegative,bad_construction,bad_structure,bad_wail,bad_windows,blurry,cloned_window,cropped,deformed,disfigured,error,extra_windows,extra_chimney,extra_door,extra_structure,extra_frame,fewer_digits,fused_structure,gross_proportions,jpeg_artifacts,long_roof,low_quality,structure_limbs,missing_windows,missing_doors,missing_roofs,mutated_structure,mutation,normal_quality,out_of_frame,owres,poorly_drawn_structure,poorly_drawn_house,signature,text,too_many_windows,ugly,username,uta,watermark,worst_quality"

View File

@@ -1,189 +0,0 @@
ZoomIn:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_ZoomIn.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
ZoomOut:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_ZoomOut.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
PanLeft:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_PanLeft.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
PanRight:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_PanRight.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
TiltUp:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_TiltUp.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
TiltDown:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_TiltDown.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
RollingAnticlockwise:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_RollingAnticlockwise.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
RollingClockwise:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
motion_module_lora_configs:
- path: "models/MotionLoRA/v2_lora_RollingClockwise.ckpt"
alpha: 1.0
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: 45987230
steps: 25
guidance_scale: 7.5
prompt:
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
n_prompt:
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"

View File

@@ -1,23 +0,0 @@
RealisticVision:
inference_config: "configs/inference/inference-v2.yaml"
motion_module:
- "models/Motion_Module/mm_sd_v15_v2.ckpt"
dreambooth_path: "models/DreamBooth_LoRA/realisticVisionV20_v20.safetensors"
lora_model_path: ""
seed: [13100322578370451493, 14752961627088720670, 9329399085567825781, 16987697414827649302]
steps: 25
guidance_scale: 7.5
prompt:
- "b&w photo of 42 y.o man in black clothes, bald, face, half body, body, high detailed skin, skin pores, coastline, overcast weather, wind, waves, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
- "close up photo of a rabbit, forest, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot"
- "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3"
- "night, b&w photo of old house, post apocalypse, forest, storm weather, wind, rocks, 8k uhd, dslr, soft lighting, high quality, film grain"
n_prompt:
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
- "semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"
- "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, art, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation"

View File

@@ -1,48 +0,0 @@
image_finetune: true
output_dir: "outputs"
pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
noise_scheduler_kwargs:
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "scaled_linear"
steps_offset: 1
clip_sample: false
train_data:
csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
sample_size: 256
validation_data:
prompts:
- "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
- "A drone view of celebration with Christma tree and fireworks, starry sky - background."
- "Robot dancing in times square."
- "Pacific coast, carmel by the sea ocean and waves."
num_inference_steps: 25
guidance_scale: 8.
trainable_modules:
- "."
unet_checkpoint_path: ""
learning_rate: 1.e-5
train_batch_size: 50
max_train_epoch: -1
max_train_steps: 100
checkpointing_epochs: -1
checkpointing_steps: 60
validation_steps: 5000
validation_steps_tuple: [2, 50]
global_seed: 42
mixed_precision_training: true
enable_xformers_memory_efficient_attention: True
is_debug: False

View File

@@ -1,66 +0,0 @@
image_finetune: false
output_dir: "outputs"
pretrained_model_path: "models/StableDiffusion/stable-diffusion-v1-5"
unet_additional_kwargs:
use_motion_module : true
motion_module_resolutions : [ 1,2,4,8 ]
unet_use_cross_frame_attention : false
unet_use_temporal_attention : false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads : 8
num_transformer_block : 1
attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
temporal_position_encoding : true
temporal_position_encoding_max_len : 24
temporal_attention_dim_div : 1
zero_initialize : true
noise_scheduler_kwargs:
num_train_timesteps: 1000
beta_start: 0.00085
beta_end: 0.012
beta_schedule: "linear"
steps_offset: 1
clip_sample: false
train_data:
csv_path: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/results_2M_val.csv"
video_folder: "/mnt/petrelfs/guoyuwei/projects/datasets/webvid/2M_val"
sample_size: 256
sample_stride: 4
sample_n_frames: 16
validation_data:
prompts:
- "Snow rocky mountains peaks canyon. Snow blanketed rocky mountains surround and shadow deep canyons."
- "A drone view of celebration with Christma tree and fireworks, starry sky - background."
- "Robot dancing in times square."
- "Pacific coast, carmel by the sea ocean and waves."
num_inference_steps: 25
guidance_scale: 8.
trainable_modules:
- "motion_modules."
unet_checkpoint_path: ""
learning_rate: 1.e-4
train_batch_size: 4
max_train_epoch: -1
max_train_steps: 100
checkpointing_epochs: -1
checkpointing_steps: 60
validation_steps: 5000
validation_steps_tuple: [2, 50]
global_seed: 42
mixed_precision_training: true
enable_xformers_memory_efficient_attention: True
is_debug: False

View File

@@ -1,2 +1 @@
gdown 1RqkQuGPaCO5sGZ6V6KZ-jUWmsRu48Kdq -O models/Motion_Module/
gdown 1ql0g_Ys4UCz2RnokYlBjyOYPbttbIpbu -O models/Motion_Module/
gdown 1EK_D9hDOPfJdK4z8YDB8JYvPracNx2SX -O models/Motion_Module/

View File

@@ -0,0 +1,2 @@
#!/bin/bash
wget https://civitai.com/api/download/models/169718 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate

View File

@@ -1,2 +0,0 @@
#!/bin/bash
wget https://civitai.com/api/download/models/78775 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate

View File

@@ -0,0 +1,2 @@
#!/bin/bash
wget https://civitai.com/api/download/models/126688 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate

View File

@@ -1,2 +0,0 @@
#!/bin/bash
wget https://civitai.com/api/download/models/72396 -P models/DreamBooth_LoRA/ --content-disposition --no-check-certificate

Some files were not shown because too many files have changed in this diff Show More