Merge branch 'master-github' into master-merge-github231228

ly119399
2023-12-28 19:15:06 +08:00
121 changed files with 18192 additions and 177 deletions

View File

@@ -15,10 +15,10 @@ jobs:
#if: startsWith(github.event.ref, 'refs/tags')
steps:
- uses: actions/checkout@v2
- name: Set up Python 3.7
- name: Set up Python 3.10
uses: actions/setup-python@v2
with:
python-version: '3.7'
python-version: '3.10'
- name: Install wheel
run: pip install wheel && pip install -r requirements/framework.txt
- name: Build ModelScope

View File

@@ -53,70 +53,64 @@ Some representative examples include:
NLP:
* [nlp_gpt3_text-generation_2.7B](https://modelscope.cn/models/damo/nlp_gpt3_text-generation_2.7B)
* [ChatGLM3-6B](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary)
* [ChatYuan-large](https://modelscope.cn/models/ClueAI/ChatYuan-large)
* [Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary)
* [mengzi-t5-base](https://modelscope.cn/models/langboat/mengzi-t5-base)
* [Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat/summary)
* [nlp_csanmt_translation_en2zh](https://modelscope.cn/models/damo/nlp_csanmt_translation_en2zh)
* [Ziya2-13B-Chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
* [nlp_raner_named-entity-recognition_chinese-base-news](https://modelscope.cn/models/damo/nlp_raner_named-entity-recognition_chinese-base-news)
* [Internlm-chat-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm-chat-20b/summary)
* [nlp_structbert_word-segmentation_chinese-base](https://modelscope.cn/models/damo/nlp_structbert_word-segmentation_chinese-base)
* [Udever Multilingual Universal Text Representation Model 1b1](https://modelscope.cn/models/damo/udever-bloom-1b1/summary)
* [Erlangshen-RoBERTa-330M-Sentiment](https://modelscope.cn/models/fengshenbang/Erlangshen-RoBERTa-330M-Sentiment)
* [CoROM Text Vector - Chinese - E-commerce Domain - Base](https://modelscope.cn/models/damo/nlp_corom_sentence-embedding_chinese-base-ecom/summary)
* [nlp_convai_text2sql_pretrain_cn](https://modelscope.cn/models/damo/nlp_convai_text2sql_pretrain_cn)
* [MGeo Address Similarity Matching Entity Alignment - Chinese - Address Field - Base](https://modelscope.cn/models/damo/mgeo_geographic_entity_alignment_chinese_base/summary)
Multi-Modal:
* [multi-modal_clip-vit-base-patch16_zh](https://modelscope.cn/models/damo/multi-modal_clip-vit-base-patch16_zh)
* [Qwen-VL-Chat](https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary)
* [ofa_pretrain_base_zh](https://modelscope.cn/models/damo/ofa_pretrain_base_zh)
* [CogVLM](https://modelscope.cn/models/ZhipuAI/CogVLM/summary)
* [Taiyi-Stable-Diffusion-1B-Chinese-v0.1](https://modelscope.cn/models/fengshenbang/Taiyi-Stable-Diffusion-1B-Chinese-v0.1)
* [Text-to-Video Synthesis Large Model - English - General Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary)
* [mplug_visual-question-answering_coco_large_en](https://modelscope.cn/models/damo/mplug_visual-question-answering_coco_large_en)
* [I2VGen-XL High Definition Image to Video Large Model](https://modelscope.cn/models/damo/Image-to-Video/summary)
* [I2VGen-XL High Definition Video to Video Large Model](https://modelscope.cn/models/damo/Video-to-Video/summary)
CV:
* [cv_controlnet_controllable-image-generation_nine-annotators](https://modelscope.cn/models/dienstag/cv_controlnet_controllable-image-generation_nine-annotators/summary)
* [DamoFD Face Detection Key Point Model - 0.5G](https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd/summary)
* [cv_tinynas_object-detection_damoyolo](https://modelscope.cn/models/damo/cv_tinynas_object-detection_damoyolo)
* [BSHM Portrait Matting](https://modelscope.cn/models/damo/cv_unet_image-matting/summary)
* [cv_unet_person-image-cartoon_compound-models](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models)
* [DCT-Net Portrait Cartoonization - 3D](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary)
* [cv_convnextTiny_ocr-recognition-general_damo](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo)
* [DCT-Net Portrait Cartoonization Model - 3D](https://modelscope.cn/models/damo/face_chain_control_model/summary)
* [cv_resnet18_human-detection](https://modelscope.cn/models/damo/cv_resnet18_human-detection)
* [DuGuang - Text Recognition - Line Recognition Model - Chinese and English - General Domain](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo/summary)
* [cv_resnet50_face-detection_retinaface](https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface)
* [DuGuang - Text Detection - Line Detection Model - Chinese and English - General Domain](https://modelscope.cn/models/damo/cv_resnet18_ocr-detection-line-level_damo/summary)
* [cv_unet_image-matting](https://modelscope.cn/models/damo/cv_unet_image-matting)
* [cv_F3Net_product-segmentation](https://modelscope.cn/models/damo/cv_F3Net_product-segmentation)
* [cv_resnest101_general_recognition](https://modelscope.cn/models/damo/cv_resnest101_general_recognition)
* [LaMa Image Inpainting](https://modelscope.cn/models/damo/cv_fft_inpainting_lama/summary)
Audio:
* [speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
* [Paraformer Speech Recognition - Chinese - General - 16k - Offline - Large - Long Audio Version](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
* [speech_sambert-hifigan_tts_zh-cn_16k](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k)
* [FSMN Voice Endpoint Detection - Chinese - General - 16k - onnx](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx/summary)
* [speech_charctc_kws_phone-xiaoyun](https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun)
* [Monotonic-Aligner Speech Timestamp Prediction - 16k - Offline](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary)
* [u2pp_conformer-asr-cn-16k-online](https://modelscope.cn/models/wenet/u2pp_conformer-asr-cn-16k-online)
* [CT-Transformer Punctuation - Chinese - General - onnx](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx/summary)
* [speech_fsmn_vad_zh-cn-16k-common-pytorch](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary)
* [Speech Synthesis - Chinese - Multiple Emotions Domain - 16k - Multiple Speakers](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k/summary)
* [punc_ct-transformer_zh-cn-common-vocab272727-pytorch](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)
* [speech_frcrn_ans_cirm_16k](https://modelscope.cn/models/damo/speech_frcrn_ans_cirm_16k)
* [speech_dfsmn_aec_psm_16k](https://modelscope.cn/models/damo/speech_dfsmn_aec_psm_16k)
* [CAM++ Speaker Verification - Chinese - General - 200k-Spkrs](https://modelscope.cn/models/damo/speech_campplus_sv_zh-cn_16k-common/summary)
@@ -208,7 +202,7 @@ CPU docker image
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
# py38
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
```
GPU docker image
@@ -217,15 +211,16 @@ GPU docker image
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
# py38
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
```
## Setup Local Python Environment
One can also set up local ModelScope environment using pip and conda. We suggest [anaconda](https://docs.anaconda.com/anaconda/install/) for creating local python environment:
One can also set up a local ModelScope environment using pip and conda. ModelScope supports Python 3.7 and above.
We suggest [anaconda](https://docs.anaconda.com/anaconda/install/) for creating a local Python environment:
```shell
conda create -n modelscope python=3.7
conda create -n modelscope python=3.8
conda activate modelscope
```
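After installing ModelScope into the activated environment (for example with pip), a quick way to confirm the installation is to import the package and print its version; this is a minimal sketch, assuming the package was installed as `modelscope`:
```python
# Minimal installation check: modelscope/version.py defines __version__
# (see the modelscope/__init__.py diff later in this commit).
from modelscope.version import __version__

print(__version__)
```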

View File

@@ -208,7 +208,7 @@ CPU docker イメージ
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
# py38
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
```
GPU docker イメージ
@@ -217,7 +217,7 @@ GPU docker イメージ
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
# py38
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
```
## ローカル Python 環境のセットアップ

View File

@@ -37,6 +37,8 @@ ModelScope Library为模型贡献者提供了必要的分层API以便将来
除了包含各种模型的实现之外，ModelScope Library还支持与ModelScope后端服务进行必要的交互，特别是与Model-Hub和Dataset-Hub的交互。这种交互促进了模型和数据集的管理在后台无缝执行，包括模型数据集查询、版本控制、缓存管理等。
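As a hedged illustration of the Model-Hub and Dataset-Hub interactions described above (using public ModelScope APIs; the model id below is taken from the model list in this README, while the dataset id is a placeholder):
```python
# Sketch of hub interactions: downloading/caching a model and loading a dataset.
from modelscope.hub.snapshot_download import snapshot_download
from modelscope.msdatasets import MsDataset

local_dir = snapshot_download('damo/nlp_structbert_word-segmentation_chinese-base')  # downloads and caches the model locally
ds = MsDataset.load('afqmc_small', split='train')  # 'afqmc_small' is a placeholder dataset id
```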
# 部分模型和在线体验
ModelScope开源了数百个(当前700+)模型，涵盖自然语言处理、计算机视觉、语音、多模态、科学计算等，其中包含数百个SOTA模型。用户可以进入ModelScope网站([modelscope.cn](http://www.modelscope.cn))的模型中心零门槛在线体验，或者Notebook方式体验模型。
@@ -50,68 +52,67 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
自然语言处理:
* [GPT-3预训练生成模型-中文-2.7B](https://modelscope.cn/models/damo/nlp_gpt3_text-generation_2.7B)
* [ChatGLM3-6B](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary)
* [元语功能型对话大模型](https://modelscope.cn/models/ClueAI/ChatYuan-large)
* [Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary)
* [孟子T5预训练生成模型-中文-base](https://modelscope.cn/models/langboat/mengzi-t5-base)
* [Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat/summary)
* [CSANMT连续语义增强机器翻译-英中-通用领域-large](https://modelscope.cn/models/damo/nlp_csanmt_translation_en2zh)
* [Ziya2-13B-Chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
* [RaNER命名实体识别-中文-新闻领域-base](https://modelscope.cn/models/damo/nlp_raner_named-entity-recognition_chinese-base-news)
* [Internlm-chat-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm-chat-20b/summary)
* [BAStructBERT分词-中文-新闻领域-base](https://modelscope.cn/models/damo/nlp_structbert_word-segmentation_chinese-base)
* [Udever-bloom-1b1](https://modelscope.cn/models/damo/udever-bloom-1b1/summary)
* [二郎神-RoBERTa-330M-情感分类](https://modelscope.cn/models/fengshenbang/Erlangshen-RoBERTa-330M-Sentiment)
* [CoROM文本向量-中文-电商领域-base](https://modelscope.cn/models/damo/nlp_corom_sentence-embedding_chinese-base-ecom/summary)
* [SPACE-T表格问答预训练模型-中文-通用领域-base](https://modelscope.cn/models/damo/nlp_convai_text2sql_pretrain_cn)
* [MGeo地址相似度匹配实体对齐-中文-地址领域-base](https://modelscope.cn/models/damo/mgeo_geographic_entity_alignment_chinese_base/summary)
多模态:
* [CLIP模型-中文-通用领域-base](https://modelscope.cn/models/damo/multi-modal_clip-vit-base-patch16_zh)
* [Qwen-VL-Chat](https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary)
* [OFA预训练模型-中文-通用领域-base](https://modelscope.cn/models/damo/ofa_pretrain_base_zh)
* [CogVLM](https://modelscope.cn/models/ZhipuAI/CogVLM/summary)
* [太乙-Stable-Diffusion-1B-中文-v0.1](https://modelscope.cn/models/fengshenbang/Taiyi-Stable-Diffusion-1B-Chinese-v0.1)
* [Text-to-Video Synthesis Large Model - English - General Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary)
* [I2VGen-XL高清图片到视频大模型](https://modelscope.cn/models/damo/Image-to-Video/summary)
* [I2VGen-XL高清视频到视频大模型](https://modelscope.cn/models/damo/Video-to-Video/summary)
* [mPLUG视觉问答模型-英文-large](https://modelscope.cn/models/damo/mplug_visual-question-answering_coco_large_en)
计算机视觉:
* [ControlNet可控图像生成](https://modelscope.cn/models/dienstag/cv_controlnet_controllable-image-generation_nine-annotators/summary)
* [DamoFD人脸检测关键点模型-0.5G](https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd/summary)
* [DAMOYOLO-高性能通用检测模型-S](https://modelscope.cn/models/damo/cv_tinynas_object-detection_damoyolo)
* [BSHM人像抠图](https://modelscope.cn/models/damo/cv_unet_image-matting/summary)
* [DCT-Net人像卡通化](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models)
* [DCT-Net人像卡通化-3D](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary)
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo)
* [DCT-Net人像卡通化模型-3D](https://modelscope.cn/models/damo/face_chain_control_model/summary)
* [人体检测-通用-Base](https://modelscope.cn/models/damo/cv_resnet18_human-detection)
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo/summary)
* [RetinaFace人脸检测关键点模型](https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface)
* [读光-文字检测-行检测模型-中英-通用领域](https://modelscope.cn/models/damo/cv_resnet18_ocr-detection-line-level_damo/summary)
* [BSHM人像抠图](https://modelscope.cn/models/damo/cv_unet_image-matting)
* [LaMa图像填充](https://modelscope.cn/models/damo/cv_fft_inpainting_lama/summary)
* [图像分割-商品展示图场景的商品分割-电商领域](https://modelscope.cn/models/damo/cv_F3Net_product-segmentation)
* [万物识别-中文-通用领域](https://modelscope.cn/models/damo/cv_resnest101_general_recognition)
语音:
* [Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
* [Paraformer语音识别-中文-通用-16k-离线-大型-长音频版本](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
* [语音合成-中文-多情感领域-16k-多发音人](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k)
* [FSMN声音端点检测-中文-通用-16k-onnx](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx/summary)
* [CTC语音唤醒-移动端-单麦-16k-小云小云](https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun)
* [Monotonic-Aligner语音时间戳预测-16k-离线](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary)
* [WeNet-U2pp_Conformer-语音识别-中文-16k-实时](https://modelscope.cn/models/wenet/u2pp_conformer-asr-cn-16k-online)
* [FRCRN语音降噪-单麦-16k](https://modelscope.cn/models/damo/speech_frcrn_ans_cirm_16k)
* [DFSMN回声消除-单麦单参考-16k](https://modelscope.cn/models/damo/speech_dfsmn_aec_psm_16k)
* [CT-Transformer标点-中文-通用-onnx](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx/summary)
* [语音合成-中文-多情绪领域-16k-多发言人](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k/summary)
* [CAM++说话人验证-中文-通用-200k发言人](https://modelscope.cn/models/damo/speech_campplus_sv_zh-cn_16k-common/summary)
科学计算:
@@ -194,7 +195,7 @@ CPU镜像
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
# py38
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
```
GPU镜像
@@ -203,14 +204,14 @@ GPU镜像
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
# py38
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
```
## 搭建本地Python环境
你也可以使用pip和conda搭建本地python环境我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/)安装完成后执行如下命令为modelscope library创建对应的python环境
你也可以使用pip和conda搭建本地python环境，ModelScope支持python3.7及以上环境，我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/)，安装完成后，执行如下命令为modelscope library创建对应的python环境：
```shell
conda create -n modelscope python=3.7
conda create -n modelscope python=3.8
conda activate modelscope
```

View File

@@ -119,7 +119,7 @@ git lfs install
2. We use a publicly readable model repository on ModelScope to store test data. The repository has been added by default as a submodule at the path data/test. To clone it, use the following command:
```shell
git clone git@github.com:modelscope/modelscope.git --recursive
git clone https://github.com/modelscope/modelscope.git --recursive
```
3. Each time you add new data, go to the data/test directory (note that you are now in the submodule's git directory), check if you are on the master branch, and pull the latest master branch:

View File

@@ -90,8 +90,7 @@ git lfs install
2. 我们使用 ModelScope 的一个公共读取模型仓库来存储测试数据。该仓库已默认添加为子模块,路径为 data/test。要克隆它请使用以下命令
```
git clone git@github.com:modelscope/modelscope.git --recursive
git clone https://github.com/modelscope/modelscope.git --recursive
```
3. 每次添加新数据时,进入 data/test 目录(注意此时您已在子模块的 git 目录中),检查是否在 master 分支上,并拉取最新的 master 分支:

View File

@@ -0,0 +1,55 @@
# Oh No! I'm Surrounded by LLMs! (LLMRiddles)
## Project Introduction
"Oh No! I'm Surrounded by LLMs!" is an intellectual challenge game. We use LLM to automatically generate corresponding game code based on existing Large Language Model (LLM) dialogue Gradio application codes within the ModelScope community, combined with preset questions from the Zhihu article ["How to Accomplish Tasks with 'Impossible'"](https://zhuanlan.zhihu.com/p/665393240), creating a unique gameplay experience. In this stream, players are required to cleverly construct questions that challenge the LLM to provide answers that meet specific conditions.
## News
November 9, 2023 - Added two new questions and introduced the chatglm-turbo model 🔥🔥🔥
November 8, 2023 - Split the level modules and the LLM backends so that levels and LLMs can be integrated independently. Pull requests welcome 🔥 🔥
November 7, 2023 - Released the initial demo version 🔥
## Getting Started
### Online Experience
[LLMRiddles](https://modelscope.cn/studios/LLMRiddles/LLMRiddles/summary)
### Local Execution
To start the game, please follow the steps below:
1. Clone the project code:
```
git clone https://github.com/modelscope/modelscope.git
```
2. Navigate to the `examples/apps/llm_riddles` directory.
3. Install the required Python dependencies with `pip install -r requirements.txt`.
4. Go to [DashScope](https://dashscope.aliyun.com/), activate the service, obtain a token, and configure the environment variable `DASHSCOPE_API_KEY=your API-KEY` (a quick sanity check for this variable is sketched right after this list).
5. Run the launch command `python app.py`.
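Before launching, it can help to confirm that the key from step 4 is visible to Python; a minimal sketch (not part of the project), mirroring how `llm.py` reads the variable via `os.getenv`:
```python
import os

# Sanity-check sketch: app.py's DashScope backend reads DASHSCOPE_API_KEY via os.getenv (see llm.py).
if not os.getenv('DASHSCOPE_API_KEY'):
    raise SystemExit('DASHSCOPE_API_KEY is not set; configure it before running `python app.py`.')
```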
## Roadmap
- [x] Initial version source code and the online Studio (创空间) demo ready.
- [x] Support for custom questions and validation logic integration.
- [ ] Expand to 9 major levels, each with 9 questions.
- [ ] Support for more open-source models.
- [ ] Support for switching between cloud API and local inference.
## Contribution Guide
We welcome everyone to contribute to "Oh No! I'm Surrounded by LLMs!", including proposing more fun questions, fixing validator corner cases, and providing more gameplay. Please follow the steps below:
1. Visit the project address [ModelScope](https://github.com/modelscope/modelscope) and fork the project.
2. Create your feature branch in your local environment (`git checkout -b feature/AmazingFeature`).
3. Commit your changes (`git commit -m 'Add some AmazingFeature'`).
4. Push your changes to the branch (`git push origin feature/AmazingFeature`).
5. Initiate a Pull Request in the original project.
## Community Contributors
We sincerely thank all community members who have contributed to this project, especially:
- Idea from: [haoqiangfan](https://www.zhihu.com/people/haoqiang-fan)
- Most of the code is auto-generated by LLM
## Support
If you encounter any problems or need assistance during the game, please submit your issues on the project's [Issues page](https://github.com/modelscope/modelscope/issues).
## Copyright and License
This project is licensed under the APACHE License. Please see the [LICENSE](https://github.com/modelscope/modelscope/blob/main/LICENSE) file in the project for more information.

View File

@@ -0,0 +1,65 @@
# 完蛋！我被LLM包围了！(LLMRiddles)
## 项目简介
《完蛋！我被LLM包围了！》是一款智力挑战游戏。该项目利用LLM代码生成，基于ModelScope社区内现有的LLM对话Gradio应用程序代码，结合知乎文章[《如何用“不可能”完成任务》](https://zhuanlan.zhihu.com/p/665393240)中的预设问题，自动生成了对应的游戏代码，创造了一个独特的游戏体验。在这个游戏中，玩家需要巧妙构造问题，挑战LLM给出满足特定条件的回答。
## 更新
2023.11.9 新增两道题目，新增chatglm-turbo模型🔥🔥🔥
2023.11.8 拆分关卡模块和llm，支持关卡独立接入、llm独立接入，欢迎PR 🔥 🔥
2023.11.7 发布初版demo🔥
## 开始游戏
### 在线体验
[LLMRiddles](https://modelscope.cn/studios/LLMRiddles/LLMRiddles/summary)
### 本地运行
要开始游戏,请按照以下步骤操作:
1. 克隆项目代码:
```
git clone https://github.com/modelscope/modelscope.git
```
2. 进入到`examples/apps/llm_riddles`目录。
3. 安装所需的Python依赖`pip install -r requirements.txt`。
4. 前往[DashScope](https://dashscope.aliyun.com/)开通服务、获取token，配置环境变量`DASHSCOPE_API_KEY=你的API-KEY`
5. 执行启动命令`python app.py`.
## RoadMap
- [x] 初版本源码和创空间体验ready
- [x] 支持自定义问题和验证逻辑接入
- [ ] 扩充到9个大关卡每个关卡9个问题
- [ ] 支持更多开源模型
- [ ] 支持云端API和本地推理切换
## 贡献指南
我们欢迎大家为《完蛋！我被LLM包围了！》做出贡献，包括提出更多好玩的问题、修复validator的corner case，以及提供更多的玩法。请按以下步骤操作：
1. 访问项目地址 [ModelScope](https://github.com/modelscope/modelscope) 并fork项目。
2. 在你的本地环境中创建你的特性分支 (`git checkout -b feature/AmazingFeature`)。
3. 提交你的改动 (`git commit -m 'Add some AmazingFeature'`)。
4. 将你的改动推送到分支上 (`git push origin feature/AmazingFeature`)。
5. 在原项目下发起一个Pull Request。
## 社区贡献者
我们诚挚感谢所有对本项目做出贡献的社区成员,特别是:
- idea来源: [haoqiangfan](https://www.zhihu.com/people/haoqiang-fan)
- 代码大部分来自于LLM自动生成
## 支持
如果你在游戏过程中遇到任何问题或需要帮助,请通过项目的[Issues页面](https://github.com/modelscope/modelscope/issues)提交你的问题。
## 版权和许可
本项目采用APACHE License许可证。请查看项目中的[LICENSE](https://github.com/modelscope/modelscope/blob/main/LICENSE)文件了解更多信息。

View File

@@ -0,0 +1,225 @@
import functools
import inspect
import os
import random
import re
import gradio as gr
from challenges.ch1 import challenge1
from challenges.ch2 import challenge2
from challenges.ch3 import challenge3
from challenges.ch4 import challenge4
from challenges.ch5 import challenge5
from llm import create_model
from PIL import Image, ImageDraw, ImageFont
model_cache = {}
# 定义关卡信息和验证逻辑
challenges = [
challenge1,
challenge2,
challenge3,
challenge4,
challenge5,
]
CONGRATS_STR = '所有挑战完成!👏🏻👏🏻👏🏻👏🏻👏🏻👏🏻'
CONGRATS_QUESTION = f'<center><font size=4>{CONGRATS_STR}</center>\n\n <center><font size=3> </center>'
SHARE_CHALLENGES_HINT = [
'小试牛刀新手上路', '数字玩家已经上线', '巅峰对决,你就是提示词高手', '无人之境,胜利就在前方', '哇塞我冲出了LLM的重围'
]
def get_problem(challenge_idx, problem_idx):
problems = challenges[challenge_idx]['problems']
return problems[problem_idx]
def update_challenge_info(current_chapter_index, current_challenge_index):
return get_problem(current_chapter_index,
current_challenge_index)['description']
def update_question_info(current_chapter_index, current_challenge_index):
global challenges
current_chapter = challenges[current_chapter_index]
challenge = get_problem(current_chapter_index, current_challenge_index)
question_info = f"""\n<center><font size=4>{current_chapter["name"]}""" \
f"""</center>\n\n <center><font size=3>{challenge["title"]}</center>"""
return question_info
def validate_challenge(response, input, state, generate_response):
if 'success' in state:
return CONGRATS_STR, CONGRATS_QUESTION, ''
assert 'current_chapter_index' in state, 'current_chapter_index not found in state'
assert 'current_challenge_index' in state, 'current_challenge_index not found in state'
current_chapter_index = state['current_chapter_index']
current_challenge_index = state['current_challenge_index']
# 获取当前章节
current_chapter = challenges[current_chapter_index]
# 获取当前挑战
challenge = current_chapter['problems'][current_challenge_index]
validate_fn = challenge['validator']
params = inspect.signature(validate_fn).parameters
if 'generate_response' in params:
valid_result = validate_fn(response, input, generate_response)
else:
valid_result = validate_fn(response, input)
if valid_result:
challenge_result = '挑战成功!进入下一关。'
# 检查是否还有更多挑战在当前章节
if current_challenge_index < len(current_chapter['problems']) - 1:
# 移动到当前章节的下一个挑战
current_challenge_index += 1
else:
# 如果当前章节的挑战已经完成,移动到下一个章节
if current_chapter_index < len(challenges) - 1:
current_challenge_index = 0
current_chapter_index += 1
else:
state['success'] = True
challenge_result = '所有挑战完成!'
else:
challenge_result = '挑战失败,请再试一次。'
state['current_chapter_index'] = current_chapter_index
state['current_challenge_index'] = current_challenge_index
print('update state: ', state)
if 'success' in state:
return CONGRATS_STR, CONGRATS_QUESTION, ''
else:
return challenge_result, \
update_question_info(current_chapter_index, current_challenge_index), \
update_challenge_info(current_chapter_index, current_challenge_index)
def generate_response(input, model_name):
if model_name in model_cache:
model = model_cache[model_name]
else:
model = create_model(model_name)
model_cache[model_name] = model
try:
return model(input)
except RuntimeError as e:
# if exception happens, print error in log and return empty str
print('error', e)
return ''
def on_submit(input, model_name, state):
# model_name = os.environ.get('MODEL', 'qwen-plus')
name_map = {
'qwen-max': 'qwen-max',
'qwen-plus': 'qwen-plus',
'chatglm-turbo': 'chatglm_turbo',
}
gen_fn = functools.partial(
generate_response, model_name=name_map[model_name])
response = gen_fn(input)
history = [(input, response)]
print(history)
challenge_result, question_info, challenge_info = validate_challenge(
response, input, state, gen_fn)
return challenge_result, history, question_info, challenge_info
def generate_share_image(state):
share_state = state['current_chapter_index']
if share_state > 3:
share_state = 3
if 'success' in state:
share_state = 4 # 全部通关为 4
img_pil = Image.open(f'assets/background{share_state}.png')
# 设置需要显示的字体
fontpath = 'assets/font.ttf'
font = ImageFont.truetype(fontpath, 48)
draw = ImageDraw.Draw(img_pil)
# 绘制文字信息
draw.text((70, 1000),
SHARE_CHALLENGES_HINT[share_state],
font=font,
fill=(255, 255, 255))
if share_state == 4:
share_chapter_text = '顺利闯过了全部关卡'
else:
share_chapter_text = f"我顺利闯到第 {state['current_chapter_index']+1}-{state['current_challenge_index']+1}"
draw.text((70, 1080), share_chapter_text, font=font, fill=(255, 255, 255))
draw.text((70, 1160), '你也来挑战一下吧~', font=font, fill=(255, 255, 255))
return gr.Image.update(visible=True, value=img_pil)
def create_app():
# Gradio界面构建
block = gr.Blocks()
with block as demo:
current_chapter_index = 0
current_challenge_index = 0
state = gr.State(
dict(
current_challenge_index=current_challenge_index,
current_chapter_index=current_chapter_index))
gr.Markdown("""<center><font size=6>完蛋我被LLM包围了</center>""")
gr.Markdown("""<font size=3>欢迎来玩LLM Riddles复刻版完蛋我被LLM包围了
你将通过本游戏对大型语言模型产生更深刻的理解。
在本游戏中,你需要构造一个提给一个大型语言模型的问题,使得它回复的答案符合要求。""")
model_selector = gr.Dropdown(
label='选择模型',
choices=['qwen-max', 'qwen-plus', 'chatglm-turbo'],
value='qwen-max')
question_info = gr.Markdown(
update_question_info(current_chapter_index,
current_challenge_index))
challenge_info = gr.Textbox(
value=update_challenge_info(current_chapter_index,
current_challenge_index),
label='当前挑战',
interactive=False)
challenge_result = gr.Textbox(label='挑战结果', interactive=False)
chatbot = gr.Chatbot(label='llm', elem_classes='control-height')
message = gr.Textbox(lines=2, label='输入')
with gr.Row():
submit = gr.Button('🚀 发送')
shareBtn = gr.Button('💯 分享成绩')
shareImg = gr.Image(label='分享成绩', visible=False, width=400)
submit.click(
on_submit,
inputs=[message, model_selector, state],
outputs=[challenge_result, chatbot, question_info, challenge_info])
shareBtn.click(
generate_share_image, inputs=[state], outputs=[shareImg])
gr.HTML("""
<div style="text-align: center;">
<span>
Powered by <a href="https://dashscope.aliyun.com/" target="_blank">
<img src=
"//img.alicdn.com/imgextra/i4/O1CN01SgKFXM1qLQwFvk6j5_!!6000000005479-2-tps-99-84.png"
style="display: inline; height: 20px; vertical-align: bottom;"/>DashScope
</a>
</span>
</div>
""")
demo.queue(concurrency_count=10).launch(height=800, share=True)
if __name__ == '__main__':
create_app()

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8afcec15a87bcfaff327a5c9564a31ff1fe185a63cb286bd9772c8c68216768a
size 757003

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:16afb18994ad0654b31117931aad2ee05863492e964e10f4c559556e29618320
size 839643

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8afcec15a87bcfaff327a5c9564a31ff1fe185a63cb286bd9772c8c68216768a
size 757003

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:966a013913042e1574ccbc299b1914272cb47df69a552bf1723b96b2d8902de3
size 1114172

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5253bbed99be55e6ac9080ea320df75c95592204696d6d41ba90f9905384fdca
size 1198295

View File

@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4cf462f8db7583843bc152ccfc87bb033b91880c98db9f83ba87fcca5d5d07f2
size 1056053

Binary file not shown.

View File

@@ -0,0 +1,100 @@
def validate_palindrome_invariance(origin_input, response, generate_response):
"""
Validator: checks that the given question is not itself a palindrome, and that the model
gives the same answer whether the question is asked forwards or reversed.
:param origin_input: the original question
:param response: the model's answer to the original question
:param generate_response: callable that returns the model's answer for a given question
:return: True if the two answers match, otherwise False
"""
# 检查原始输入不是回文
if origin_input == origin_input[::-1]:
return False
# 将原始问题倒序并提交
reversed_question = origin_input[::-1]
reversed_response = generate_response(reversed_question)
# 比较两个回答是否一致
return response.strip() == reversed_response.strip()
challenge1 = {
'name':
'第一章 对话之趣',
'problems': [
{
'title': '第1题 初来乍到',
'description': '请你构造一个问题使模型的回答是一字不差的“1+1=3”不需要引号',
'validator': lambda response, input: response.strip() == '1+1=3'
},
{
'title':
'第2题 小试牛刀',
'description':
'请输入三个字以内的问题使模型的回答在30个字以上。',
'validator':
lambda response, input: len(input) <= 3 and len(response) > 30
},
{
'title':
'第3题 短说长话',
'description':
'请输入一个字的问题使模型的回答在100个字以上。',
'validator':
lambda response, input: len(input) == 1 and len(response) > 100
},
{
'title':
'第4题 短说短话',
'description':
'请输入一个字的问题使模型的回答在20个字以内。',
'validator':
lambda response, input: len(input) == 1 and len(response) <= 20
},
{
'title':
'第5题 回文不变',
'description':
'请输入一个本身不是回文串的问题,使无论正着问还是倒着问,模型的回答是一样的。',
# 验证器将取决于模型的行为,这里只是一个占位符
'validator':
lambda response, input, generate_response: validate_palindrome_invariance(
input, response, generate_response)
},
{
'title':
'第6题 无中生狗',
'description':
'请提一个不包含“狗”这个字的问题但是回答中至少出现3次“狗”这个字。',
'validator':
lambda response, input: '狗' not in input and response.count('狗') \
>= 3
},
{
'title':
'第7题 事半功倍',
'description':
'请输入一个本身不是回文串的问题,使模型回答的长度刚好是问题长度的两倍。',
'validator':
lambda response, input: input != input[::-1] and len(response.strip()) == 2 * len(input)
},
{
'title':
'第8题 产生疑惑',
'description':
'请提出一个问题,不包含“不”“知”“道”三个字符,但要使模型的回答是一字不差的“不知道”',
'validator':
lambda response, input:
all(char not in input for char in ['不', '知', '道']) and # noqa: W504
response.strip() == '不知道'
},
{
'title':
'第9题 无中生有',
'description':
'请提出一个问题,使模型只回复“永动机是真实存在的。”',
'validator':
lambda response, input: response.strip() == '永动机是真实存在的。'
},
]
}

View File

@@ -0,0 +1,131 @@
import re
import sympy
# 辅助函数 - 检查是否为质数
def is_prime(num):
return sympy.isprime(num)
# 辅助函数 - 获取下一个质数
def next_prime(num):
return sympy.nextprime(num)
# 辅助函数 - 检查是否为平方数
def is_square(n):
return sympy.sqrt(n).is_integer
# 辅助函数 - 获取平方根
def get_square_root(n):
return int(sympy.sqrt(n))
# 验证函数 - 微言大义
def validate_9(response, input):
input_yes = len(input) <= 10
output_yes = len(response) >= 9 and response.isdigit() and sympy.isprime(
int(response))
return input_yes and output_yes
challenge2 = {
'name':
'第二章 数字游戏',
'problems': [
{
'title':
'第1题 质数长度',
'description':
'你需要提出一个字数是质数的问题,使回答的长度刚好是它的下一个质数。',
'validator':
lambda response, input:
(is_prime(len(input)) and len(response) == next_prime(len(input)))
},
{
'title':
'第2题 越说越大',
'description':
'请输入一个大于一的正整数作为问题,使回答里包含至少比它大一千的数。',
'validator':
lambda response, input:
(input.isdigit() and int(input) > 1 and any(
int(num) > int(input) + 1000
for num in re.findall(r'\d+', response)))
},
{
'title':
'第3题 越说越小',
'description':
'请输入一个大于一的正整数作为问题使回答里包含至少10个不同的至少比它小一千的自然数。',
'validator':
lambda response, input:
(input.isdigit() and int(input) > 1 and (len({
int(num)
for num in re.findall(r'\d+', response)
if int(num) < int(input) - 1000
}) >= 10))
},
{
'title':
'第4题 警世名言',
'description':
'请在不提及1,4,5这三个字符的情况下让模型回答114514不要有多余的字符。',
'validator':
lambda response, input: response.strip() == '114514' and all(
c not in input for c in ['1', '4', '5'])
},
{
'title':
'第5题 开平方数',
'description':
'请输入一个大于一的平方数作为问题,使回答里包含它的平方根对应的数。',
'validator':
lambda response, input: ((
input.isdigit() and int(input) > 1 and is_square(int(input))
and # noqa: W504
str(get_square_root(int(input))) in response))
},
{
'title':
'第6题 得寸进狗',
'description':
'请提一个不包含“狗”这个字的10个字以内的问题使回答中“狗”这个字出现的次数至少是问题字数的2倍。',
'validator':
lambda response, input: (
'狗' not in input and len(input) <= 10 and # noqa: W504
response.count('狗') >= 2 * len(input))
},
{
'title':
'第7题 暗藏玄机',
'description':
'请输入一个大于一的正整数作为问题使回答包含一个两位数其十位上的数字比个位上的数字小5。',
'validator':
lambda response, input:
(input.isdigit() and int(input) > 1 and (len({
int(num)
for num in re.findall(r'\d+', response)
if len(num) == 2 and int(num[1]) - int(num[0]) == 5
}) >= 1))
},
{
'title':
'第8题 点兵点将',
'description':
'请输入十个字以内的问题问题中不能包含0到9的数字以及零到十的汉字使模型的回答中包含数字0到9。',
'validator':
lambda response, input:
(len(input) <= 10 and not re.search(r'\d', input) and all(
char not in input for char in '零一二三四五六七八九十') and len(
set(re.findall(r'\d', response))) == 10)
},
{
'title': '第9题 微言大义',
'description': '请输入10个字以内的问题使得模型的回答是一个超过一亿的素数',
'validator': validate_9
}
]
}

View File

@@ -0,0 +1,101 @@
import re
import unicodedata
def validate_palindrome_inverse(origin_input, response, generate_response):
"""
Validator: checks that, for a non-palindromic question, the answers obtained by asking it
forwards and reversed are themselves non-palindromic and are reverses of each other.
:param origin_input: the original question
:param response: the model's answer to the original question
:param generate_response: callable that returns the model's answer for a given question
:return: True if both answers are non-palindromic and mutually reversed, otherwise False
"""
# 检查原始输入不是回文
if origin_input == origin_input[::-1]:
return False
# 将原始问题倒序并获取模型回答
reversed_question = origin_input[::-1]
reversed_response = generate_response(reversed_question).strip()
# 检查原始回答不是回文
if response == response[::-1]:
return False
# 比较正着问和倒着问的回答是否是逆序关系
return response == reversed_response[::-1]
challenge3 = {
'name':
'第三章 巅峰挑战',
'problems': [
{
'title':
'第1题 自说自话',
'description':
'请做出一个至少10个字的提问使回答和提问完全一样。',
'validator':
lambda response, input: input == response and len(input) >= 10
},
{
'title':
'第2题 话自说自',
'description':
'请做出一个至少10个字的提问使回答是提问的逆序。',
'validator':
lambda response, input:
(input[::-1] == response and len(input) >= 10)
},
{
'title':
'第3题 就是敢说',
'description':
'请在不提及1,3这两个字符的情况下让模型回答1+1=3不要有多余的字符。',
'validator':
lambda response, input: response.strip() == '1+1=3' and all(
c not in input for c in ['1', '3'])
},
{
'title':
'第4题 回文协变',
'description':
'请输入一个本身不是回文串的问题,使正着问和倒着问的时候,模型的回答本身不回文且也是逆序的关系。',
'validator':
lambda response, input, generate_response:
validate_palindrome_inverse(input, response, generate_response)
},
{
'title':
'第5题 答非所问',
'description':
'请输入一个以“为什么”开头,以“?”结尾的不包含数字的问题,使回答中只包含数字。',
'validator':
lambda response, input: '为什么' == input[:3] and '？' == input[-1]
and # noqa: W504
len(re.findall(r'\d+', input)) == 0 and response.isdigit()
},
{
'title':
'第6题 与狗交谈',
'description':
'请在不提及“狗”、“汪”、“dog”的情况下让模型的回答是一字不差的“汪~”',
'validator':
lambda response, input:
(all(char not in input
for char in ['狗', '汪', 'dog']) and response.strip() == '汪~')
},
{
'title':
'第7题 魔搭 (🐶)',
'description':
'请提一个不包含任何中文字符、换行或空格的问题,使模型正好回答"魔搭"',
'validator':
lambda response, input: not any(
'CJK' in unicodedata.name(char, '') or char in '\t\n '
for char in input) and (response.strip() == '魔搭')
},
]
}

View File

@@ -0,0 +1,68 @@
import re
def validate_reciprocal_question(input, response, generate_response):
"""
Validator: checks whether, for a question A answered with B, asking B as a new question
yields the original question A as the answer.
:param input: the original question A
:param response: the model's answer B to question A
:param generate_response: callable that takes a question and returns the model's answer
:return: True if asking B as a new question returns A, otherwise False
"""
# 确保回答B不等于原始问题A
if response == input:
return False
# 以回答B作为新的提问
new_response = generate_response(response)
# 检查新的回答是否与原始问题A相等
return new_response.strip() == input.strip()
challenge4 = {
'name':
'第四章 无人之境',
'problems': [
{
'title':
'第1题 循环往复',
'description':
'请构造一个问题A它的回复是不同于A的B然后以B提问能再次得到A.',
'validator':
lambda response, input, generate_response:
validate_reciprocal_question(input, response, generate_response)
},
{
'title':
'第2题 惜字如金',
'description':
'请输入一个字的问题使模型的回答在16个字以内。',
'validator':
lambda response, input: len(input) == 1 and len(response) <= 16
},
{
'title':
'第3题 自然之密',
'description':
'请输入一个大于一的正整数作为问题使回答里包含和它刚好相差1的数。',
'validator':
lambda response, input:
(input.isdigit() and int(input) > 1 and any(
abs(int(num) - int(input)) == 1
for num in re.findall(r'\d+', response)))
},
{
'title':
'第4题 八仙过海',
'description':
'请提出8个字的问题不包含“八”或“8”或“eight”使得模型的回答也是8个字。',
'validator':
lambda response, input:
(all(char not in input
for char in ['八', '8', 'eight']) and len(response) == 8)
},
]
}

View File

@@ -0,0 +1,35 @@
def check_word_in_sentence(words, sentence):
return [word in sentence for word in words]
challenge5 = {
'name':
'第五章 登堂入室',
'problems': [
{
'title':
'第1题 盛夏少年',
'description':
'模型的回答应该包含“盛夏”、“蝉鸣”、“少年”、“橘子味汽水”这几个词,同时输入的问题不能包含其中任一个词。',
'validator':
lambda response, input: all(
check_word_in_sentence(['盛夏', '蝉鸣', '少年', '橘子味汽水'], response))
and not any(
check_word_in_sentence(['盛夏', '蝉鸣', '少年', '橘子味汽水'], input))
},
{
'title':
'第2题 蝉鸣日出',
'description':
'模型的回答应该包含“盛夏”、“蝉鸣”、“少年”、“橘子味汽水”、“日出”这几个词,同时输入的问题不能包含其中任一个字。',
'validator':
lambda response, input: all(
check_word_in_sentence(
['盛夏', '蝉鸣', '少年', '橘子味汽水', '日出'], response)) and not any(
check_word_in_sentence([
'盛', '夏', '蝉', '鸣', '少', '年', '橘', '子', '味', '汽',
'水', '日', '出'
], input))
},
]
}

View File

@@ -0,0 +1,28 @@
from app import challenges, generate_response
def check_answer(chap_idx,
challenge_idx,
input='input',
model_name='qwen-max'):
print('第{}章 第{}题'.format(chap_idx + 1, challenge_idx + 1))
challenge = challenges[chap_idx]['problems'][challenge_idx]
print(challenge['description'])
val_fn = challenge['validator']
response = generate_response(input, model_name)
try:
res = val_fn(response, input)
print('input:\n', input)
print('response:\n', response)
print('validation result: ', res)
except Exception:
import traceback
traceback.print_exc()
print('failed')
if __name__ == '__main__':
chap = 5
ques = 1
input = '请使用“盛 夏”、“蝉 鸣”、“少 年”、“橘 子味汽水”这几个词造句'
check_answer(chap - 1, ques - 1, input)

View File

@@ -0,0 +1,170 @@
import os
import random
from http import HTTPStatus
from typing import Any, Dict, List, Union
class DashScope:
"""A class to interact with the Dashscope AI service for response generation.
This class provides an interface to call a specific model from the Dashscope service
to generate responses based on the input provided.
Attributes:
model (str): The name of the model to be used for generation.
"""
def __init__(self, model_name: str = 'qwen-plus'):
"""Initializes the DashScope instance with a given model name.
The constructor sets up the model name that will be used for response generation
and initializes the Dashscope API key from environment variables.
Args:
model_name (str): The name of the model to be used. Defaults to 'qwen-plus'.
"""
import dashscope # Import dashscope module at runtime
dashscope.api_key = os.getenv(
'DASHSCOPE_API_KEY') # Set the API key from environment variable
self.model: str = model_name # Assign the model name to an instance variable
def __call__(self, input: Union[str, List[Dict[str, str]]],
**kwargs: Any) -> Union[str, None]:
"""Allows the DashScope instance to be called as a function.
This method processes the input, sends it to the Dashscope service, and returns
the generated response.
Args:
input (Union[str, List[Dict[str, str]]]): The input str to generate a
response for. Can be a string or a list of messages.
**kwargs: Arbitrary keyword arguments.
Returns:
Union[str, None]: The generated response from the model, or None if there is an error.
Raises:
RuntimeError: If there is an error in accessing the Dashscope service.
"""
import dashscope # Import dashscope module at runtime
# Format the input into the required structure
if isinstance(input, str):
messages: List[Dict[str, str]] = [{
'role':
'system',
'content':
'You are a helpful assistant.'
}, {
'role': 'user',
'content': input
}]
else:
messages = input
# Make a call to the Dashscope service with the processed input
response = dashscope.Generation.call(
model=self.model,
messages=messages,
seed=random.randint(1,
10000), # Generate a random seed for each call
result_format='message', # Specify the format of the result
top_p=kwargs.get('top_p',
0.8) # Set the nucleus sampling parameter
)
# Check the response status code and return the generated response or raise an error
if response.status_code == HTTPStatus.OK:
return response.output.choices[0].message.content
else:
print('Error accessing dashscope, please try again.',
response.request_id, response.message)
return ''
class ZhiPu:
def __init__(self, model_name: str = 'chatglm_turbo'):
"""Initializes the ZhiPu instance with a given model name.
The constructor sets up the model name that will be used for response generation
and initializes the ZhiPu API key from environment variables.
Args:
model_name (str): The name of the model to be used. Defaults to 'chatglm_turbo'.
"""
import zhipuai # Import zhipuai module at runtime
zhipuai.api_key = os.getenv(
'ZHIPU_API_KEY') # Set the API key from environment variable
self.model: str = model_name # Assign the model name to an instance variable
def __call__(self, input: Union[str, List[Dict[str, str]]],
**kwargs: Any) -> Union[str, None]:
"""Allows the ZhiPu instance to be called as a function.
{
"code":200,
"msg":"操作成功",
"data":{
"request_id":"8098024428488935671",
"task_id":"8098024428488935671",
"task_status":"SUCCESS",
"choices":[
{
"role":"assistant",
"content":"\" 您好!作为人工智能助手,我很乐意为您提供帮助。请问您有什么问题或者需要解决的事情吗?您可以向我提问,我会尽力为您解答。\""
}
],
"usage":{
"prompt_tokens":2,
"completion_tokens":32,
"total_tokens":34
}
},
"success":true
}
"""
import zhipuai
if isinstance(input, str):
messages: List[Dict[str, str]] = [{
'role': 'user',
'content': input
}]
else:
messages = input
response = zhipuai.model_api.invoke(
model=self.model,
prompt=messages,
top_p=0.7,
temperature=0.9,
return_type='text',
)
if response['code'] == 200:
return response['data']['choices'][0]['content']
else:
print(f'{self.model} error: ', response)
return ''
def create_model(model_name: str):
"""Factory function to create a DashScope model instance based on the model name.
Args:
model_name (str): The name of the model to create an instance of.
Returns:
DashScope: An instance of the DashScope class.
Raises:
ValueError: If the model name provided does not start with 'qwen'.
"""
if model_name.startswith('qwen'):
return DashScope(model_name)
elif model_name.startswith('chatglm'):
return ZhiPu(model_name)
else:
raise ValueError('Other model implementations need to be provided.')
if __name__ == '__main__':
model = create_model('chatglm_turbo')
print(model('输入'))
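The factory above dispatches on model-name prefixes, so adding another backend only requires a class with the same `__call__(input, **kwargs) -> str` convention plus one more branch in `create_model`. A hedged sketch with a made-up `EchoModel` placeholder (not a real provider):
```python
# Hypothetical sketch of an extra backend following the DashScope/ZhiPu calling convention.
class EchoModel:
    """Stand-in backend that simply echoes the prompt; replace with a real client."""

    def __init__(self, model_name: str = 'echo'):
        self.model = model_name

    def __call__(self, input, **kwargs) -> str:
        # Accept either a raw string or a chat-message list, like the other backends.
        return input if isinstance(input, str) else input[-1]['content']

# create_model would then gain one more branch, e.g.:
#     elif model_name.startswith('echo'):
#         return EchoModel(model_name)
```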

View File

@@ -0,0 +1,5 @@
dashscope
gradio==3.39.0
pillow
sympy
zhipuai

View File

@@ -0,0 +1,13 @@
from app import challenges
def test_valid():
for challenge in challenges:
for p in challenge['problems']:
val_fn = p['validator']
try:
val_fn('response', 'input')
except Exception:
import traceback
traceback.print_exc()
print(p, 'failed')

View File

@@ -2,6 +2,7 @@
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
from .utils.automodel_utils import fix_transformers_upgrade
if TYPE_CHECKING:
from .exporters import Exporter, TfModelExporter, TorchModelExporter
@@ -33,7 +34,8 @@ if TYPE_CHECKING:
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
AutoModelForTokenClassification, AutoTokenizer,
GenerationConfig)
GenerationConfig, AutoImageProcessor,
BatchFeature)
from .utils.hub import create_model_if_not_exist, read_config
from .utils.logger import get_logger
from .version import __release_datetime__, __version__
@@ -81,7 +83,8 @@ else:
'BitsAndBytesConfig', 'AutoModelForCausalLM',
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
'AutoModelForSequenceClassification',
'AutoModelForTokenClassification'
'AutoModelForTokenClassification', 'AutoImageProcessor',
'BatchFeature'
],
'msdatasets': ['MsDataset']
}
@@ -95,3 +98,5 @@ else:
module_spec=__spec__,
extra_objects={},
)
fix_transformers_upgrade()

View File

@@ -127,6 +127,7 @@ class Models(object):
human_image_generation = 'human-image-generation'
image_view_transform = 'image-view-transform'
image_control_3d_portrait = 'image-control-3d-portrait'
anydoor = 'anydoor'
# nlp models
bert = 'bert'
@@ -457,6 +458,8 @@ class Pipelines(object):
human3d_animation = 'human3d-animation'
image_view_transform = 'image-view-transform'
image_control_3d_portrait = 'image-control-3d-portrait'
anydoor = 'anydoor'
image_to_3d = 'image-to-3d'
# nlp tasks
automatic_post_editing = 'automatic-post-editing'

View File

@@ -0,0 +1,20 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .anydoor_model import ControlLDM
else:
_import_structure = {'anydoor_model': ['ControlLDM']}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,519 @@
import einops
import torch
import torch.nn as nn
from einops import rearrange, repeat
from torchvision.utils import make_grid
from modelscope import Model
from modelscope.metainfo import Models
from modelscope.models.builder import MODELS
from modelscope.utils.constant import Tasks
from .cldm.ddim_hacked import DDIMSampler
from .ldm.models.diffusion.ddpm import LatentDiffusion
from .ldm.modules.attention import SpatialTransformer
from .ldm.modules.diffusionmodules.openaimodel import (AttentionBlock,
Downsample, ResBlock,
TimestepEmbedSequential,
UNetModel)
from .ldm.modules.diffusionmodules.util import (conv_nd, linear,
timestep_embedding,
zero_module)
from .ldm.util import exists
class ControlledUnetModel(UNetModel):
def forward(self,
x,
timesteps=None,
context=None,
control=None,
only_mid_control=False,
**kwargs):
hs = []
with torch.no_grad():
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
if control is not None:
h += control.pop()
for i, module in enumerate(self.output_blocks):
if only_mid_control or control is None:
h = torch.cat([h, hs.pop()], dim=1)
else:
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
return self.out(h)
class ControlNet(nn.Module):
def __init__(
self,
image_size,
in_channels,
model_channels,
hint_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Need to include the dimension of your cross-attention conditioning'
if context_dim is not None:
assert use_spatial_transformer, 'Need to use the spatial transformer for your cross-attention conditioning'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.dims = dims
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError(
'provide num_res_blocks either as an int (globally constant) or '
'as a list/tuple (per-level) with the same length as channel_mult'
)
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(
map(
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
],
range(len(num_attention_blocks))))
print(
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
f'attention will still not be set.')
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = torch.float16 if use_fp16 else torch.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1))
])
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
self.input_hint_block = TimestepEmbedSequential(
conv_nd(dims, hint_channels, 16, 3, padding=1), nn.SiLU(),
conv_nd(dims, 16, 16, 3, padding=1), nn.SiLU(),
conv_nd(dims, 16, 32, 3, padding=1, stride=2), nn.SiLU(),
conv_nd(dims, 32, 32, 3, padding=1), nn.SiLU(),
conv_nd(dims, 32, 96, 3, padding=1, stride=2), nn.SiLU(),
conv_nd(dims, 96, 96, 3, padding=1), nn.SiLU(),
conv_nd(dims, 96, 256, 3, padding=1, stride=2), nn.SiLU(),
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)))
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks
) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self.zero_convs.append(self.make_zero_conv(ch))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
) if resblock_updown else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
self.zero_convs.append(self.make_zero_conv(ch))
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else
SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self.middle_block_out = self.make_zero_conv(ch)
self._feature_size += ch
def make_zero_conv(self, channels):
return TimestepEmbedSequential(
zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
def forward(self, x, hint, timesteps, context, **kwargs):
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb) # 1,1280
# 1,320,64,64
guided_hint = self.input_hint_block(hint, emb, context)
outs = []
h = x.type(self.dtype)
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
if guided_hint is not None:
# skip the first layer
h = guided_hint
guided_hint = None
else:
h_new = module(h, emb, context)
h = h_new
outs.append(zero_conv(h, emb, context))
h_new = self.middle_block(h, emb, context)
outs.append(self.middle_block_out(h_new, emb, context))
return outs
@MODELS.register_module(
Tasks.image_to_image_generation, module_name=Models.anydoor)
class ControlLDM(LatentDiffusion, Model):
'''
This work presents AnyDoor, a diffusion-based image generator
with the power to teleport target objects to new scenes
at user-specified locations in a harmonious way.
Instead of tuning parameters for each object, our model
is trained only once and effortlessly generalizes
to diverse object-scene combinations at the inference stage.
arxiv: https://arxiv.org/abs/2307.09481
'''
def __init__(self, control_stage_config, control_key, only_mid_control,
*args, **kwargs):
super().__init__(*args, **kwargs)
self.control_model = ControlNet(**control_stage_config)
self.control_key = control_key
self.only_mid_control = only_mid_control
self.control_scales = [1.0] * 13
@torch.no_grad()
def get_input(self, batch, k, bs=None, *args, **kwargs):
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
control = batch[self.control_key]
if bs is not None:
control = control[:bs]
control = control.to(self.device)
control = einops.rearrange(control, 'b h w c -> b c h w')
control = control.to(memory_format=torch.contiguous_format).float()
self.time_steps = batch['time_steps']
return x, dict(c_crossattn=[c], c_concat=[control])
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
assert isinstance(cond, dict)
diffusion_model = self.model.diffusion_model
cond_txt = torch.cat(cond['c_crossattn'], 1)
if cond['c_concat'] is None:
eps = diffusion_model(
x=x_noisy,
timesteps=t,
context=cond_txt,
control=None,
only_mid_control=self.only_mid_control)
else:
control = self.control_model(
x=x_noisy,
hint=torch.cat(cond['c_concat'], 1),
timesteps=t,
context=cond_txt)
control = [
c * scale for c, scale in zip(control, self.control_scales)
]
eps = diffusion_model(
x=x_noisy,
timesteps=t,
context=cond_txt,
control=control,
only_mid_control=self.only_mid_control)
return eps
@torch.no_grad()
def get_unconditional_conditioning(self, N):
uncond = self.get_learned_conditioning([torch.zeros(
(1, 3, 224, 224))] * N)
return uncond
@torch.no_grad()
def log_images(self,
batch,
N=4,
n_row=2,
sample=False,
ddim_steps=50,
ddim_eta=0.0,
return_keys=None,
quantize_denoised=True,
inpaint=True,
plot_denoise_rows=False,
plot_progressive_rows=True,
plot_diffusion_rows=False,
unconditional_guidance_scale=9.0,
unconditional_guidance_label=None,
use_ema_scope=True,
**kwargs):
use_ddim = ddim_steps is not None
log = dict()
z, c = self.get_input(batch, self.first_stage_key, bs=N)
c_cat, c = c['c_concat'][0][:N], c['c_crossattn'][0][:N]
N = min(z.shape[0], N)
n_row = min(z.shape[0], n_row)
log['reconstruction'] = self.decode_first_stage(z)
# ==== visualize the shape mask or the high-frequency map ====
guide_mask = (c_cat[:, -1, :, :].unsqueeze(1) + 1) * 0.5
guide_mask = torch.cat([guide_mask, guide_mask, guide_mask], 1)
HF_map = c_cat[:, :3, :, :] # * 2.0 - 1.0
log['control'] = HF_map
cond_image = batch[self.cond_stage_key].cpu().numpy().copy()
log['conditioning'] = torch.permute(
torch.tensor(cond_image), (0, 3, 1, 2)) * 2.0 - 1.0
if plot_diffusion_rows:
# get diffusion row
diffusion_row = list()
z_start = z[:n_row]
for t in range(self.num_timesteps):
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
t = t.to(self.device).long()
noise = torch.randn_like(z_start)
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
diffusion_row.append(self.decode_first_stage(z_noisy))
diffusion_row = torch.stack(
diffusion_row) # n_log_step, n_row, C, H, W
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
diffusion_grid = rearrange(diffusion_grid,
'b n c h w -> (b n) c h w')
diffusion_grid = make_grid(
diffusion_grid, nrow=diffusion_row.shape[0])
log['diffusion_row'] = diffusion_grid
if sample:
# get denoise row
samples, z_denoise_row = self.sample_log(
cond={
'c_concat': [c_cat],
'c_crossattn': [c]
},
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta)
x_samples = self.decode_first_stage(samples)
log['samples'] = x_samples
if plot_denoise_rows:
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
log['denoise_row'] = denoise_grid
if unconditional_guidance_scale > 1.0:
uc_cross = self.get_unconditional_conditioning(N)
uc_cat = c_cat # torch.zeros_like(c_cat)
uc_full = {'c_concat': [uc_cat], 'c_crossattn': [uc_cross]}
samples_cfg, _ = self.sample_log(
cond={
'c_concat': [c_cat],
'c_crossattn': [c]
},
batch_size=N,
ddim=use_ddim,
ddim_steps=ddim_steps,
eta=ddim_eta,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=uc_full,
)
x_samples_cfg = self.decode_first_stage(samples_cfg)
log[f'samples_cfg_scale_{unconditional_guidance_scale:.2f}'] = x_samples_cfg # * 2.0 - 1.0
return log
@torch.no_grad()
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
ddim_sampler = DDIMSampler(self)
b, c, h, w = cond['c_concat'][0].shape
shape = (self.channels, h // 8, w // 8)
samples, intermediates = ddim_sampler.sample(
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
return samples, intermediates
def configure_optimizers(self):
lr = self.learning_rate
params = list(self.control_model.parameters())
if not self.sd_locked:
params += list(
self.model.diffusion_model.output_blocks.parameters())
params += list(self.model.diffusion_model.out.parameters())
params += list(self.cond_stage_model.projector.parameters())
opt = torch.optim.AdamW(params, lr=lr)
return opt
def low_vram_shift(self, is_diffusing):
if is_diffusing:
self.model = self.model.cuda()
self.control_model = self.control_model.cuda()
self.first_stage_model = self.first_stage_model.cpu()
self.cond_stage_model = self.cond_stage_model.cpu()
else:
self.model = self.model.cpu()
self.control_model = self.control_model.cpu()
self.first_stage_model = self.first_stage_model.cuda()
self.cond_stage_model = self.cond_stage_model.cuda()
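# --- Illustrative sketch, not part of the original model code: how the 13
# ControlNet residuals are rescaled in apply_model before being injected into
# the UNet. The list holds the 12 zero-conv outputs of the encoder blocks plus
# the middle-block output; the tensor shapes below are placeholders only.
if __name__ == '__main__':
    dummy_control = [torch.randn(1, 4, 8, 8) for _ in range(13)]
    control_scales = [1.0] * 13
    scaled = [c * s for c, s in zip(dummy_control, control_scales)]
    assert len(scaled) == 13 and scaled[0].shape == (1, 4, 8, 8)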

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,428 @@
"""SAMPLING ONLY."""
import numpy as np
import torch
from tqdm import tqdm
from ..ldm.modules.diffusionmodules.util import (extract_into_tensor,
make_ddim_sampling_parameters,
make_ddim_timesteps,
noise_like)
class DDIMSampler(object):
def __init__(self, model, schedule='linear', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device('cuda'):
attr = attr.to(torch.device('cuda'))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.,
verbose=True):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
tmp1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
tmp2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(tmp1 * tmp2)
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None, # this has to come in the same format as the conditioning
dynamic_threshold=None,
ucg_schedule=None,
**kwargs):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(
f'Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
# x_T: optional starting latent, shape (B, 4, 64, 64)
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
model_t = self.model.apply_model(x, t, c)
model_uncond = self.model.apply_model(x, t,
unconditional_conditioning)
model_output = model_uncond + unconditional_guidance_scale * (
model_t - model_uncond)
if self.model.parameterization == 'v':
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == 'eps', 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod \
if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
sqrt_one_minus_alphas[index],
device=device)
# current prediction for x_0
if self.model.parameterization != 'v':
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
@torch.no_grad()
def encode(self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
callback=None):
timesteps = np.arange(self.ddpm_num_timesteps
) if use_original_steps else self.ddim_timesteps
num_reference_steps = timesteps.shape[0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0], ),
timesteps[i],
device=self.model.device,
dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
noise_pred = e_t_uncond + unconditional_guidance_scale * (
noise_pred - e_t_uncond)
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
tmp = (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()
weighted_noise_pred = alphas_next[i].sqrt() * tmp * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (num_steps // return_intermediates
) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback:
callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise)
@torch.no_grad()
def decode(self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
callback=None):
timesteps = np.arange(self.ddpm_num_timesteps
) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0], ),
step,
device=x_latent.device,
dtype=torch.long)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback:
callback(i)
return x_dec
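# --- Illustrative sketch, not part of the original sampler: a single DDIM
# update with scalar placeholder alphas, mirroring the arithmetic in
# p_sample_ddim. With eta = 0 (sigma_t = 0) the step is deterministic:
#   pred_x0 = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
#   dir_xt  = sqrt(1 - a_prev - sigma_t**2) * eps
#   x_prev  = sqrt(a_prev) * pred_x0 + dir_xt + sigma_t * noise
if __name__ == '__main__':
    x_t = torch.randn(1, 4, 64, 64)
    eps = torch.randn_like(x_t)
    a_t, a_prev, sigma_t = 0.5, 0.7, 0.0  # placeholder values for one step
    pred_x0 = (x_t - (1 - a_t) ** 0.5 * eps) / a_t ** 0.5
    dir_xt = (1. - a_prev - sigma_t ** 2) ** 0.5 * eps
    x_prev = a_prev ** 0.5 * pred_x0 + dir_xt + sigma_t * torch.randn_like(x_t)
    assert x_prev.shape == x_t.shape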

View File

@@ -0,0 +1,364 @@
import cv2
import numpy as np
import torch
def mask_score(mask):
'''Score a mask by how much of its area falls in the largest connected contour.'''
mask = mask.astype(np.uint8)
if mask.sum() < 10:
return 0
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
cv2.CHAIN_APPROX_NONE)
cnt_area = [cv2.contourArea(cnt) for cnt in contours]
conc_score = np.max(cnt_area) / sum(cnt_area)
return conc_score
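# Worked example (illustrative, not in the original file): for a binary mask
# containing two blobs with contour areas 400 and 100 pixels, the returned
# connectivity score is 400 / (400 + 100) = 0.8; a single connected blob
# scores 1.0, and masks with fewer than 10 foreground pixels score 0.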
def sobel(img, mask, thresh=50):
'''Compute a high-frequency (Sobel edge) map of the masked object.'''
H, W = img.shape[0], img.shape[1]
img = cv2.resize(img, (256, 256))
mask = (cv2.resize(mask, (256, 256)) > 0.5).astype(np.uint8)
kernel = np.ones((5, 5), np.uint8)
mask = cv2.erode(mask, kernel, iterations=2)
Ksize = 3
sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=Ksize)
sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=Ksize)
sobel_X = cv2.convertScaleAbs(sobelx)
sobel_Y = cv2.convertScaleAbs(sobely)
scharr = cv2.addWeighted(sobel_X, 0.5, sobel_Y, 0.5, 0)
scharr = np.max(scharr, -1) * mask
scharr[scharr < thresh] = 0.0
scharr = np.stack([scharr, scharr, scharr], -1)
scharr = (scharr.astype(np.float32) / 255 * img.astype(np.float32)).astype(
np.uint8)
scharr = cv2.resize(scharr, (W, H))
return scharr
def resize_and_pad(image, box):
'''Fitting an image to the box region while keeping the aspect ratio.'''
y1, y2, x1, x2 = box
H, W = y2 - y1, x2 - x1
h, w = image.shape[0], image.shape[1]
r_box = W / H
r_image = w / h
if r_box >= r_image:
h_target = H
w_target = int(w * H / h)
image = cv2.resize(image, (w_target, h_target))
w1 = (W - w_target) // 2
w2 = W - w_target - w1
pad_param = ((0, 0), (w1, w2), (0, 0))
image = np.pad(image, pad_param, 'constant', constant_values=255)
else:
w_target = W
h_target = int(h * W / w)
image = cv2.resize(image, (w_target, h_target))
h1 = (H - h_target) // 2
h2 = H - h_target - h1
pad_param = ((h1, h2), (0, 0), (0, 0))
image = np.pad(image, pad_param, 'constant', constant_values=255)
return image
def expand_image_mask(image, mask, ratio=1.4):
h, w = image.shape[0], image.shape[1]
H, W = int(h * ratio), int(w * ratio)
h1 = int((H - h) // 2)
h2 = H - h - h1
w1 = int((W - w) // 2)
w2 = W - w - w1
pad_param_image = ((h1, h2), (w1, w2), (0, 0))
pad_param_mask = ((h1, h2), (w1, w2))
image = np.pad(image, pad_param_image, 'constant', constant_values=255)
mask = np.pad(mask, pad_param_mask, 'constant', constant_values=0)
return image, mask
def resize_box(yyxx, H, W, h, w):
y1, y2, x1, x2 = yyxx
y1, y2 = int(y1 / H * h), int(y2 / H * h)
x1, x2 = int(x1 / W * w), int(x2 / W * w)
y1, y2 = min(y1, h), min(y2, h)
x1, x2 = min(x1, w), min(x2, w)
return (y1, y2, x1, x2)
def get_bbox_from_mask(mask):
h, w = mask.shape[0], mask.shape[1]
if mask.sum() < 10:
return 0, h, 0, w
rows = np.any(mask, axis=1)
cols = np.any(mask, axis=0)
y1, y2 = np.where(rows)[0][[0, -1]]
x1, x2 = np.where(cols)[0][[0, -1]]
return (y1, y2, x1, x2)
def expand_bbox(mask, yyxx, ratio=[1.2, 2.0], min_crop=0):
y1, y2, x1, x2 = yyxx
ratio = np.random.randint(ratio[0] * 10, ratio[1] * 10) / 10
H, W = mask.shape[0], mask.shape[1]
xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
h = ratio * (y2 - y1 + 1)
w = ratio * (x2 - x1 + 1)
h = max(h, min_crop)
w = max(w, min_crop)
x1 = int(xc - w * 0.5)
x2 = int(xc + w * 0.5)
y1 = int(yc - h * 0.5)
y2 = int(yc + h * 0.5)
x1 = max(0, x1)
x2 = min(W, x2)
y1 = max(0, y1)
y2 = min(H, y2)
return (y1, y2, x1, x2)
def box2squre(image, box):
H, W = image.shape[0], image.shape[1]
y1, y2, x1, x2 = box
cx = (x1 + x2) // 2
cy = (y1 + y2) // 2
h, w = y2 - y1, x2 - x1
if h >= w:
x1 = cx - h // 2
x2 = cx + h // 2
else:
y1 = cy - w // 2
y2 = cy + w // 2
x1 = max(0, x1)
x2 = min(W, x2)
y1 = max(0, y1)
y2 = min(H, y2)
return (y1, y2, x1, x2)
def pad_to_square(image, pad_value=255, random=False):
H, W = image.shape[0], image.shape[1]
if H == W:
return image
padd = abs(H - W)
if random:
padd_1 = int(np.random.randint(0, padd))
else:
padd_1 = int(padd / 2)
padd_2 = padd - padd_1
if H > W:
pad_param = ((0, 0), (padd_1, padd_2), (0, 0))
else:
pad_param = ((padd_1, padd_2), (0, 0), (0, 0))
image = np.pad(image, pad_param, 'constant', constant_values=pad_value)
return image
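# --- Illustrative helper, not part of the original utilities: a minimal sketch
# of how the box functions above are typically chained to turn a binary mask
# into a square crop of the surrounding image. The function name and the
# min_crop / padding choices are placeholders.
def _demo_crop_region(image, mask):
    yyxx = get_bbox_from_mask(mask)  # tight (y1, y2, x1, x2) box around the mask
    yyxx = expand_bbox(mask, yyxx, min_crop=30)  # add a random context margin
    yyxx = box2squre(image, yyxx)  # grow the box to a square inside the image
    y1, y2, x1, x2 = yyxx
    return pad_to_square(image[y1:y2, x1:x2], pad_value=255)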
def box_in_box(small_box, big_box):
y1, y2, x1, x2 = small_box
y1_b, _, x1_b, _ = big_box
y1, y2, x1, x2 = y1 - y1_b, y2 - y1_b, x1 - x1_b, x2 - x1_b
return (y1, y2, x1, x2)
def shuffle_image(image, N):
height, width = image.shape[:2]
block_height = height // N
block_width = width // N
blocks = []
for i in range(N):
for j in range(N):
block = image[i * block_height:(i + 1) * block_height,
j * block_width:(j + 1) * block_width]
blocks.append(block)
np.random.shuffle(blocks)
shuffled_image = np.zeros((height, width, 3), dtype=np.uint8)
for i in range(N):
for j in range(N):
shuffled_image[i * block_height:(i + 1) * block_height,
j * block_width:(j + 1)
* block_width] = blocks[i * N + j]
return shuffled_image
def get_mosaic_mask(image, fg_mask, N=16, ratio=0.5):
ids = [i for i in range(N * N)]
masked_number = int(N * N * ratio)
masked_id = np.random.choice(ids, masked_number, replace=False)
height, width = image.shape[:2]
mask = np.ones((height, width))
block_height = height // N
block_width = width // N
b_id = 0
for i in range(N):
for j in range(N):
if b_id in masked_id:
mask[i * block_height:(i + 1) * block_height,
j * block_width:(j + 1)
* block_width] = mask[i * block_height:(i + 1)
* block_height, j * block_width:
(j + 1) * block_width] * 0
b_id += 1
mask = mask * fg_mask
mask3 = np.stack([mask, mask, mask], -1).copy().astype(np.uint8)
noise = q_x(image)
noise_mask = image * mask3 + noise * (1 - mask3)
return noise_mask
def extract_canney_noise(image, mask, dilate=True):
h, w = image.shape[0], image.shape[1]
mask = cv2.resize(mask.astype(np.uint8), (w, h)) > 0.5
kernel = np.ones((8, 8), dtype=np.uint8)
mask = cv2.erode(mask.astype(np.uint8), kernel, 10)
canny = cv2.Canny(image, 50, 100) * mask
kernel = np.ones((8, 8), dtype=np.uint8)
mask = (cv2.dilate(canny, kernel, 5) > 128).astype(np.uint8)
mask = np.stack([mask, mask, mask], -1)
pure_noise = q_x(image, t=1) * 0 + 255
canny_noise = mask * image + (1 - mask) * pure_noise
return canny_noise
def get_random_structure(size):
choice = np.random.randint(1, 5)
if choice == 1:
return cv2.getStructuringElement(cv2.MORPH_RECT, (size, size))
elif choice == 2:
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
elif choice == 3:
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size // 2))
elif choice == 4:
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size // 2, size))
def random_dilate(seg, min=3, max=10):
size = np.random.randint(min, max)
kernel = get_random_structure(size)
seg = cv2.dilate(seg, kernel, iterations=1)
return seg
def random_erode(seg, min=3, max=10):
size = np.random.randint(min, max)
kernel = get_random_structure(size)
seg = cv2.erode(seg, kernel, iterations=1)
return seg
def compute_iou(seg, gt):
intersection = seg * gt
union = seg + gt
return (np.count_nonzero(intersection) + 1e-6) / (
np.count_nonzero(union) + 1e-6)
def select_max_region(mask):
nums, labels, stats, centroids = cv2.connectedComponentsWithStats(
mask, connectivity=8)
background = 0
for row in range(stats.shape[0]):
if stats[row, :][0] == 0 and stats[row, :][1] == 0:
background = row
stats_no_bg = np.delete(stats, background, axis=0)
max_idx = stats_no_bg[:, 4].argmax()
max_region = np.where(labels == max_idx + 1, 1, 0)
return max_region.astype(np.uint8)
def perturb_mask(gt, min_iou=0.3, max_iou=0.99):
iou_target = np.random.uniform(min_iou, max_iou)
h, w = gt.shape
gt = gt.astype(np.uint8)
seg = gt.copy()
# Rare case
if h <= 2 or w <= 2:
print('GT too small, returning original')
return seg
# Do a bunch of random operations
for _ in range(250):
for _ in range(4):
lx, ly = np.random.randint(w), np.random.randint(h)
lw, lh = np.random.randint(lx + 1, w + 1), np.random.randint(
ly + 1, h + 1)
# Randomly set one pixel to 1/0. With the following dilate/erode, we can create holes/external regions
if np.random.rand() < 0.1:
cx = int((lx + lw) / 2)
cy = int((ly + lh) / 2)
seg[cy, cx] = np.random.randint(2) * 255
# Dilate/erode
if np.random.rand() < 0.5:
seg[ly:lh, lx:lw] = random_dilate(seg[ly:lh, lx:lw])
else:
seg[ly:lh, lx:lw] = random_erode(seg[ly:lh, lx:lw])
seg = np.logical_or(seg, gt).astype(np.uint8)
# seg = select_max_region(seg)
if compute_iou(seg, gt) < iou_target:
break
seg = select_max_region(seg.astype(np.uint8))
return seg.astype(np.uint8)
def q_x(x_0, t=65):
'''Add noise to a given image via the DDPM forward process.'''
x_0 = torch.from_numpy(x_0).float() / 127.5 - 1
num_steps = 100
betas = torch.linspace(-6, 6, num_steps)
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
alphas = 1 - betas
alphas_prod = torch.cumprod(alphas, 0)
alphas_bar_sqrt = torch.sqrt(alphas_prod)
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
noise = torch.randn_like(x_0)
alphas_t = alphas_bar_sqrt[t]
alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
return (alphas_t * x_0 + alphas_1_m_t * noise).numpy() * 127.5 + 127.5
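# Worked note (illustrative): q_x applies the closed-form DDPM forward process
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
# to a uint8 image rescaled to [-1, 1], then maps the result back to [0, 255];
# alpha_bar_t is the cumulative product of (1 - beta_i) under the sigmoid
# beta schedule defined above.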
def extract_target_boundary(img, target_mask):
Ksize = 3
sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=Ksize)
sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=Ksize)
# sobel-x
sobel_X = cv2.convertScaleAbs(sobelx)
# sobel-y
sobel_Y = cv2.convertScaleAbs(sobely)
# sobel-xy
scharr = cv2.addWeighted(sobel_X, 0.5, sobel_Y, 0.5, 0)
scharr = np.max(scharr, -1).astype(np.float32) / 255
scharr = scharr * target_mask.astype(np.float32)
return scharr

View File

@@ -0,0 +1,12 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .attention import MemEffAttention
from .block import NestedTensorBlock
from .dino_head import DINOHead
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused

View File

@@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
from torch import Tensor, nn
logger = logging.getLogger('dinov2')
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning('xFormers not available')
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, 'xFormers is required for nested tensors usage'
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
if attn_bias is not None:
self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
else:
self_att_op = None
x = memory_efficient_attention(
q, k, v, attn_bias=attn_bias, op=self_att_op)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x
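# --- Illustrative sketch, not part of the original file: the shapes flowing
# through Attention.forward for a toy input; the dimensions are placeholders.
if __name__ == '__main__':
    import torch
    attn = Attention(dim=64, num_heads=8)
    tokens = torch.randn(2, 16, 64)  # (batch, sequence, dim)
    out = attn(tokens)
    assert out.shape == (2, 16, 64)  # attention preserves (B, N, C)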

View File

@@ -0,0 +1,286 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Any, Callable, Dict, List, Tuple
import torch
from torch import Tensor, nn
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger('dinov2')
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning('xFormers not available')
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(
dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(
dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(
drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(
x_flat,
0,
brange,
residual.to(dtype=x.dtype),
alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x,
brange,
residual,
residual_scale_factor,
scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(
x_flat,
0,
brange,
residual.to(dtype=x.dtype),
alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x,
brange,
residual.to(dtype=x.dtype),
scaling=scaling_vector,
alpha=residual_scale_factor)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
Index-select and concatenate the tensors in x_list, and return the (cached) block-diagonal attn_bias.
"""
batch_sizes = [b.shape[0] for b in branges
] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list],
branges).view(1, -1,
x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio)
for x in x_list
]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(
x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(
x_list, branges, residual_list, residual_scale_factors):
outputs.append(
add_residual(x, brange, residual, residual_scale_factor,
scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(
self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls2.gamma if isinstance(
self.ls1, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, 'Please install xFormers for nested tensors usage'
return self.forward_nested(x_or_x_list)
else:
raise AssertionError

View File

@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
from torch.nn.utils import weight_norm
class DINOHead(nn.Module):
def __init__(
self,
in_dim,
out_dim,
use_bn=False,
nlayers=3,
hidden_dim=2048,
bottleneck_dim=256,
mlp_bias=True,
):
super().__init__()
nlayers = max(nlayers, 1)
self.mlp = _build_mlp(
nlayers,
in_dim,
bottleneck_dim,
hidden_dim=hidden_dim,
use_bn=use_bn,
bias=mlp_bias)
self.apply(self._init_weights)
self.last_layer = weight_norm(
nn.Linear(bottleneck_dim, out_dim, bias=False))
self.last_layer.weight_g.data.fill_(1)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x):
x = self.mlp(x)
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
x = self.last_layer(x)
return x
def _build_mlp(nlayers,
in_dim,
bottleneck_dim,
hidden_dim=None,
use_bn=False,
bias=True):
if nlayers == 1:
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
else:
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
for _ in range(nlayers - 2):
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
if use_bn:
layers.append(nn.BatchNorm1d(hidden_dim))
layers.append(nn.GELU())
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
return nn.Sequential(*layers)

View File

@@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
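# --- Illustrative sketch, not part of the original file: because surviving
# samples are rescaled by 1 / keep_prob, stochastic depth keeps the expected
# activation unchanged, i.e. E[drop_path(x)] == x.
if __name__ == '__main__':
    import torch
    x = torch.ones(10000, 1, 1, 1)
    y = drop_path(x, drop_prob=0.3, training=True)
    # roughly 70% of samples survive, each scaled by 1 / 0.7, so the mean stays near 1
    assert abs(y.mean().item() - 1.0) < 0.05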

View File

@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import torch
from torch import Tensor, nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma

View File

@@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

View File

@@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
import torch.nn as nn
from torch import Tensor
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f'Input image height {H} is not a multiple of patch height {patch_H}'
assert W % patch_W == 0, f'Input image width {W} is not a multiple of patch width: {patch_W}'
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (
self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops
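# --- Illustrative sketch, not part of the original file: a 224x224 RGB image
# with 16x16 patches yields (224 / 16)^2 = 196 patch tokens of dimension 768.
if __name__ == '__main__':
    import torch
    embed = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
    tokens = embed(torch.randn(1, 3, 224, 224))
    assert tokens.shape == (1, 196, 768)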

View File

@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import torch.nn.functional as F
from torch import Tensor, nn
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)
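# Worked example (illustrative): SwiGLUFFNFused shrinks the requested hidden
# width by 2/3 and rounds up to a multiple of 8, so the gated FFN (w12 + w3)
# keeps roughly the parameter count of a standard 4x MLP. For example, with
# embed_dim = 1024 and hidden_features = 4 * 1024 = 4096:
#   (int(4096 * 2 / 3) + 7) // 8 * 8 = (2730 + 7) // 8 * 8 = 2736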

View File

@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import logging
from . import vision_transformer as vits
logger = logging.getLogger('dinov2')
def build_model(args, only_teacher=False, img_size=224):
args.arch = args.arch.removesuffix('_memeff')
if 'vit' in args.arch:
vit_kwargs = dict(
img_size=img_size,
patch_size=args.patch_size,
init_values=args.layerscale,
ffn_layer=args.ffn_layer,
block_chunks=args.block_chunks,
qkv_bias=args.qkv_bias,
proj_bias=args.proj_bias,
ffn_bias=args.ffn_bias,
)
teacher = vits.__dict__[args.arch](**vit_kwargs)
if only_teacher:
return teacher, teacher.embed_dim
student = vits.__dict__[args.arch](
**vit_kwargs,
drop_path_rate=args.drop_path_rate,
drop_path_uniform=args.drop_path_uniform,
)
embed_dim = student.embed_dim
return student, teacher, embed_dim
def build_model_from_cfg(cfg, only_teacher=False):
return build_model(
cfg.student,
only_teacher=only_teacher,
img_size=cfg.crops.global_crops_size)

View File

@@ -0,0 +1,390 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
import math
from functools import partial
from typing import Callable, Sequence, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from ..layers import MemEffAttention, Mlp
from ..layers import NestedTensorBlock as Block
from ..layers import PatchEmbed, SwiGLUFFNFused
logger = logging.getLogger('dinov2')
def named_apply(fn: Callable,
module: nn.Module,
name='',
depth_first=True,
include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = '.'.join((name, child_name)) if name else child_name
named_apply(
fn=fn,
module=child_module,
name=child_name,
depth_first=depth_first,
include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer='mlp',
block_chunks=1,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.patch_embed = embed_layer(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + self.num_tokens, embed_dim))
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
if ffn_layer == 'mlp':
logger.info('using MLP layer as FFN')
ffn_layer = Mlp
elif ffn_layer == 'swiglufused' or ffn_layer == 'swiglu':
logger.info('using SwiGLU layer as FFN')
ffn_layer = SwiGLUFFNFused
elif ffn_layer == 'identity':
logger.info('using Identity layer as FFN')
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
) for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i
+ blocks_list[i:i + chunksize])
self.blocks = nn.ModuleList(
[BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
w0, h0 = w0 + 0.1, h0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
dim).permute(0, 3, 1, 2),
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
mode='bicubic',
)
assert int(w0) == patch_pos_embed.shape[-2] and int(
h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(
masks.unsqueeze(-1),
self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
return x
def forward_features_list(self, x_list, masks_list):
x = [
self.prepare_tokens_with_masks(x, masks)
for x, masks in zip(x_list, masks_list)
]
for blk in self.blocks:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append({
'x_norm_clstoken': x_norm[:, 0],
'x_norm_patchtokens': x_norm[:, 1:],
'x_prenorm': x,
'masks': masks,
})
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
x = blk(x)
x_norm = self.norm(x)
return {
'x_norm_clstoken': x_norm[:, 0],
'x_norm_patchtokens': x_norm[:, 1:],
'x_prenorm': x,
'masks': masks,
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n,
total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(
blocks_to_take
), f'only {len(output)} / {len(blocks_to_take)} blocks found'
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n,
total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(
blocks_to_take
), f'only {len(output)} / {len(blocks_to_take)} blocks found'
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size,
-1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret['x_norm_clstoken'])
def init_weights_vit_timm(module: nn.Module, name: str = ''):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
**kwargs,
)
return model
def vit_base(patch_size=16, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
**kwargs,
)
return model
def vit_large(patch_size=16, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
**kwargs,
)
return model
def vit_giant2(patch_size=16, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=40,
num_heads=24,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
**kwargs,
)
return model
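# --- Illustrative sketch, not part of the original file: building the small
# variant and running a dummy forward pass; the default head is nn.Identity,
# so the output is the normalized class token of dimension 384.
if __name__ == '__main__':
    model = vit_small(patch_size=16, img_size=224)
    out = model(torch.randn(1, 3, 224, 224))
    assert out.shape == (1, 384)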

View File

@@ -0,0 +1,195 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
dependencies = ['torch']
_DINOV2_BASE_URL = 'https://dl.fbaipublicfiles.com/dinov2'
def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
compact_arch_name = arch_name.replace('_', '')[:4]
return f'dinov2_{compact_arch_name}{patch_size}'
def _make_dinov2_model(
*,
arch_name: str = 'vit_large',
img_size: int = 518,
patch_size: int = 14,
init_values: float = 1.0,
ffn_layer: str = 'mlp',
block_chunks: int = 0,
pretrained: bool = True,
**kwargs,
):
from .dinov2.models import vision_transformer as vits
_ = _make_dinov2_model_name(arch_name, patch_size)
vit_kwargs = dict(
img_size=img_size,
patch_size=patch_size,
init_values=init_values,
ffn_layer=ffn_layer,
block_chunks=block_chunks,
)
vit_kwargs.update(**kwargs)
model = vits.__dict__[arch_name](**vit_kwargs)
# if pretrained:
# state_dict = torch.load('')
# model.load_state_dict(state_dict, strict=False)
return model
def dinov2_vits14(*, pretrained: bool = True, **kwargs):
"""
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(
arch_name='vit_small', pretrained=pretrained, **kwargs)
def dinov2_vitb14(*, pretrained: bool = True, **kwargs):
"""
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(
arch_name='vit_base', pretrained=pretrained, **kwargs)
def dinov2_vitl14(*, pretrained: bool = True, **kwargs):
"""
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(
arch_name='vit_large', pretrained=pretrained, **kwargs)
def dinov2_vitg14(*, pretrained: bool = True, **kwargs):
"""
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
"""
return _make_dinov2_model(
arch_name='vit_giant2',
ffn_layer='swiglufused',
pretrained=pretrained,
**kwargs)
def _make_dinov2_linear_head(
*,
model_name: str = 'dinov2_vitl14',
embed_dim: int = 1024,
layers: int = 4,
pretrained: bool = True,
**kwargs,
):
assert layers in (1, 4), f'Unsupported number of layers: {layers}'
linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
if pretrained:
layers_str = str(layers) if layers == 4 else ''
url = _DINOV2_BASE_URL + f'/{model_name}/{model_name}_linear{layers_str}_head.pth'
state_dict = torch.hub.load_state_dict_from_url(
url, map_location='cpu')
linear_head.load_state_dict(state_dict, strict=False)
return linear_head
class _LinearClassifierWrapper(nn.Module):
def __init__(self,
*,
backbone: nn.Module,
linear_head: nn.Module,
layers: int = 4):
super().__init__()
self.backbone = backbone
self.linear_head = linear_head
self.layers = layers
def forward(self, x):
if self.layers == 1:
x = self.backbone.forward_features(x)
cls_token = x['x_norm_clstoken'].squeeze(0)
patch_tokens = x['x_norm_patchtokens'].squeeze(0)
linear_input = torch.cat([cls_token, patch_tokens.mean(0)])
elif self.layers == 4:
x = self.backbone.get_intermediate_layers(
x, n=4, return_class_token=True)
linear_input = torch.cat([
x[0][1].squeeze(0), x[1][1].squeeze(0), x[2][1].squeeze(0),
x[3][1].squeeze(0), x[3][0].squeeze(0).mean(0)
])
else:
assert False, f'Unsupported number of layers: {self.layers}'
return self.linear_head(linear_input)
def _make_dinov2_linear_classifier(
*,
arch_name: str = 'vit_large',
layers: int = 4,
pretrained: bool = True,
**kwargs,
):
backbone = _make_dinov2_model(
arch_name=arch_name, pretrained=pretrained, **kwargs)
embed_dim = backbone.embed_dim
patch_size = backbone.patch_size
model_name = _make_dinov2_model_name(arch_name, patch_size)
linear_head = _make_dinov2_linear_head(
model_name=model_name,
embed_dim=embed_dim,
layers=layers,
pretrained=pretrained)
return _LinearClassifierWrapper(
backbone=backbone, linear_head=linear_head, layers=layers)
def dinov2_vits14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally)
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name='vit_small', layers=layers, pretrained=pretrained, **kwargs)
def dinov2_vitb14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally)
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name='vit_base', layers=layers, pretrained=pretrained, **kwargs)
def dinov2_vitl14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally)
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name='vit_large', layers=layers, pretrained=pretrained, **kwargs)
def dinov2_vitg14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
"""
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally)
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
"""
return _make_dinov2_linear_classifier(
arch_name='vit_giant2',
ffn_layer='swiglufused',
layers=layers,
pretrained=pretrained,
**kwargs)
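# Hedged usage sketch (illustrative only): because this file uses relative imports,
# these entry points are normally reached through torch.hub / the package rather than
# by executing the file directly. With pretrained=False no weights are downloaded.
def _demo_linear_classifier():
    clf = dinov2_vitl14_lc(pretrained=False)  # 4-layer head by default
    logits = clf(torch.randn(1, 3, 518, 518))
    print(logits.shape)  # torch.Size([1000]); the wrapper squeezes the batch dim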

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,274 @@
from contextlib import contextmanager
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from ...ldm.modules.diffusionmodules.model import Decoder, Encoder
from ...ldm.modules.distributions.distributions import \
DiagonalGaussianDistribution
from ...ldm.modules.ema import LitEma
from ...ldm.util import instantiate_from_config
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key='image',
colorize_nlabels=None,
monitor=None,
ema_decay=None,
learn_logvar=False):
super().__init__()
self.learn_logvar = learn_logvar
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig['double_z']
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
2 * embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
ddconfig['z_channels'], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels) == int
self.register_buffer('colorize',
torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.use_ema = ema_decay is not None
if self.use_ema:
self.ema_decay = ema_decay
assert 0. < ema_decay < 1.
self.model_ema = LitEma(self, decay=ema_decay)
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location='cpu')['state_dict']
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print('Deleting key {} from state_dict.'.format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f'Restored from {path}')
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f'{context}: Switched to EMA weights')
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f'{context}: Restored training weights')
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1,
2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log(
'aeloss',
aeloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(
log_dict_ae,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
optimizer_idx,
self.global_step,
last_layer=self.get_last_layer(),
split='train')
self.log(
'discloss',
discloss,
prog_bar=True,
logger=True,
on_step=True,
on_epoch=True)
self.log_dict(
log_dict_disc,
prog_bar=False,
logger=True,
on_step=True,
on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
_ = self._validation_step(batch, batch_idx, postfix='_ema')
return log_dict
def _validation_step(self, batch, batch_idx, postfix=''):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(
inputs,
reconstructions,
posterior,
0,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + postfix)
discloss, log_dict_disc = self.loss(
inputs,
reconstructions,
posterior,
1,
self.global_step,
last_layer=self.get_last_layer(),
split='val' + postfix)
self.log(f'val{postfix}/rec_loss',
log_dict_ae[f'val{postfix}/rec_loss'])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
ae_params_list = list(self.encoder.parameters()) + list(
self.decoder.parameters()) + list(
self.quant_conv.parameters()) + list(
self.post_quant_conv.parameters())
if self.learn_logvar:
print(f'{self.__class__.__name__}: Learning logvar')
ae_params_list.append(self.loss.logvar)
opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
log['reconstructions'] = xrec
if log_ema or self.use_ema:
with self.ema_scope():
xrec_ema, posterior_ema = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec_ema.shape[1] > 3
xrec_ema = self.to_rgb(xrec_ema)
log['samples_ema'] = self.decode(
torch.randn_like(posterior_ema.sample()))
log['reconstructions_ema'] = xrec_ema
log['inputs'] = x
return log
def to_rgb(self, x):
assert self.image_key == 'segmentation'
if not hasattr(self, 'colorize'):
self.register_buffer('colorize',
torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
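# Hedged sketch of the AutoencoderKL above. The ddconfig is the common 8x-downsampling
# VAE layout, not necessarily this repo's shipped config, and the identity loss is only
# a placeholder so the module can be constructed without a full lossconfig.
def _demo_autoencoder_roundtrip():
    ddconfig = dict(
        double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
        ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[],
        dropout=0.0)
    ae = AutoencoderKL(
        ddconfig, lossconfig={'target': 'torch.nn.Identity'}, embed_dim=4)
    img = torch.randn(1, 3, 256, 256)
    posterior = ae.encode(img)  # DiagonalGaussianDistribution over a (1, 4, 32, 32) latent
    rec = ae.decode(posterior.sample())
    print(rec.shape)  # torch.Size([1, 3, 256, 256])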

View File

@@ -0,0 +1,446 @@
"""SAMPLING ONLY."""
import numpy as np
import torch
from tqdm import tqdm
from ....ldm.modules.diffusionmodules.util import (
extract_into_tensor, make_ddim_sampling_parameters, make_ddim_timesteps,
noise_like)
class DDIMSampler(object):
def __init__(self, model, schedule='linear', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device('cuda'):
attr = attr.to(torch.device('cuda'))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.,
verbose=True):
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
tmp1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
tmp2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(tmp1 * tmp2)
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None,
ucg_schedule=None,
**kwargs):
if conditioning is not None:
if isinstance(conditioning, dict):
ctmp = conditioning[list(conditioning.keys())[0]]
while isinstance(ctmp, list):
ctmp = ctmp[0]
cbs = ctmp.shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
elif isinstance(conditioning, list):
for ctmp in conditioning:
if ctmp.shape[0] != batch_size:
print(
f'Warning: Got {ctmp.shape[0]} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
samples, intermediates = self.ddim_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
ucg_schedule=ucg_schedule)
return samples, intermediates
@torch.no_grad()
def ddim_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None,
ucg_schedule=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = reversed(range(
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
if ucg_schedule is not None:
assert len(ucg_schedule) == len(time_range)
unconditional_guidance_scale = ucg_schedule[i]
outs = self.p_sample_ddim(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold)
img, pred_x0 = outs
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_ddim(self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
model_output = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
assert isinstance(unconditional_conditioning, dict)
c_in = dict()
for k in c:
if isinstance(c[k], list):
c_in[k] = [
torch.cat(
[unconditional_conditioning[k][i], c[k][i]])
for i in range(len(c[k]))
]
else:
c_in[k] = torch.cat(
[unconditional_conditioning[k], c[k]])
elif isinstance(c, list):
c_in = list()
assert isinstance(unconditional_conditioning, list)
for i in range(len(c)):
c_in.append(
torch.cat([unconditional_conditioning[i], c[i]]))
else:
c_in = torch.cat([unconditional_conditioning, c])
model_uncond, model_t = self.model.apply_model(x_in, t_in,
c_in).chunk(2)
model_output = model_uncond + unconditional_guidance_scale * (
model_t - model_uncond)
if self.model.parameterization == 'v':
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
else:
e_t = model_output
if score_corrector is not None:
assert self.model.parameterization == 'eps', 'not implemented'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod \
if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
sqrt_one_minus_alphas[index],
device=device)
# current prediction for x_0
if self.model.parameterization != 'v':
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
else:
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
raise NotImplementedError()
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
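# i.e. x_{t-1} = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t^2) * e_t + sigma_t * z;
# with eta = 0 the sigma term vanishes and the DDIM step becomes deterministic.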
return x_prev, pred_x0
@torch.no_grad()
def encode(self,
x0,
c,
t_enc,
use_original_steps=False,
return_intermediates=None,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
callback=None):
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[
0]
assert t_enc <= num_reference_steps
num_steps = t_enc
if use_original_steps:
alphas_next = self.alphas_cumprod[:num_steps]
alphas = self.alphas_cumprod_prev[:num_steps]
else:
alphas_next = self.ddim_alphas[:num_steps]
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
x_next = x0
intermediates = []
inter_steps = []
for i in tqdm(range(num_steps), desc='Encoding Image'):
t = torch.full((x0.shape[0], ),
i,
device=self.model.device,
dtype=torch.long)
if unconditional_guidance_scale == 1.:
noise_pred = self.model.apply_model(x_next, t, c)
else:
assert unconditional_conditioning is not None
e_t_uncond, noise_pred = torch.chunk(
self.model.apply_model(
torch.cat((x_next, x_next)), torch.cat((t, t)),
torch.cat((unconditional_conditioning, c))), 2)
tmp = noise_pred - e_t_uncond
noise_pred = e_t_uncond + unconditional_guidance_scale * tmp
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
tmp = (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()
weighted_noise_pred = alphas_next[i].sqrt() * tmp * noise_pred
x_next = xt_weighted + weighted_noise_pred
if return_intermediates and i % (num_steps // return_intermediates
) == 0 and i < num_steps - 1:
intermediates.append(x_next)
inter_steps.append(i)
elif return_intermediates and i >= num_steps - 2:
intermediates.append(x_next)
inter_steps.append(i)
if callback:
callback(i)
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
if return_intermediates:
out.update({'intermediates': intermediates})
return x_next, out
@torch.no_grad()
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
# fast, but does not allow for exact reconstruction
# t serves as an index to gather the correct alphas
if use_original_steps:
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
else:
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
if noise is None:
noise = torch.randn_like(x0)
return (
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
* noise)
@torch.no_grad()
def decode(self,
x_latent,
cond,
t_start,
unconditional_guidance_scale=1.0,
unconditional_conditioning=None,
use_original_steps=False,
callback=None):
timesteps = np.arange(self.ddpm_num_timesteps
) if use_original_steps else self.ddim_timesteps
timesteps = timesteps[:t_start]
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
print(f'Running DDIM Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
x_dec = x_latent
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((x_latent.shape[0], ),
step,
device=x_latent.device,
dtype=torch.long)
x_dec, _ = self.p_sample_ddim(
x_dec,
cond,
ts,
index=index,
use_original_steps=use_original_steps,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning)
if callback:
callback(i)
return x_dec
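# Hedged usage sketch: `ldm_model` stands in for an already-loaded LatentDiffusion model
# exposing the usual interface (betas, alphas_cumprod, apply_model,
# get_learned_conditioning, decode_first_stage); the helper name is illustrative.
def _demo_ddim_text2img(ldm_model, prompt='a photo of a cat'):
    sampler = DDIMSampler(ldm_model)
    cond = ldm_model.get_learned_conditioning([prompt])
    uc = ldm_model.get_learned_conditioning([''])
    samples, _ = sampler.sample(
        S=50, batch_size=1, shape=(4, 64, 64), conditioning=cond, eta=0.0,
        unconditional_guidance_scale=7.5, unconditional_conditioning=uc)
    # e.g. (1, 3, 512, 512) with the usual 8x first-stage VAE
    return ldm_model.decode_first_stage(samples)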

File diff suppressed because it is too large

View File

@@ -0,0 +1,328 @@
"""SAMPLING ONLY."""
from functools import partial
import numpy as np
import torch
from tqdm import tqdm
from ....ldm.models.diffusion.sampling_util import norm_thresholding
from ....ldm.modules.diffusionmodules.util import (
make_ddim_sampling_parameters, make_ddim_timesteps, noise_like)
class PLMSSampler(object):
def __init__(self, model, schedule='linear', **kwargs):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.schedule = schedule
def register_buffer(self, name, attr):
if type(attr) == torch.Tensor:
if attr.device != torch.device('cuda'):
attr = attr.to(torch.device('cuda'))
setattr(self, name, attr)
def make_schedule(self,
ddim_num_steps,
ddim_discretize='uniform',
ddim_eta=0.,
verbose=True):
if ddim_eta != 0:
raise ValueError('ddim_eta must be 0 for PLMS')
self.ddim_timesteps = make_ddim_timesteps(
ddim_discr_method=ddim_discretize,
num_ddim_timesteps=ddim_num_steps,
num_ddpm_timesteps=self.ddpm_num_timesteps,
verbose=verbose)
alphas_cumprod = self.model.alphas_cumprod
assert alphas_cumprod.shape[
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
def to_torch(x):
return x.clone().detach().to(torch.float32).to(self.model.device)
self.register_buffer('betas', to_torch(self.model.betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(self.model.alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod.cpu())))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod.cpu())))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
# ddim sampling parameters
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
alphacums=alphas_cumprod.cpu(),
ddim_timesteps=self.ddim_timesteps,
eta=ddim_eta,
verbose=verbose)
self.register_buffer('ddim_sigmas', ddim_sigmas)
self.register_buffer('ddim_alphas', ddim_alphas)
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
self.register_buffer('ddim_sqrt_one_minus_alphas',
np.sqrt(1. - ddim_alphas))
tmp1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
tmp2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(tmp1 * tmp2)
self.register_buffer('ddim_sigmas_for_original_num_steps',
sigmas_for_original_sampling_steps)
@torch.no_grad()
def sample(
self,
S,
batch_size,
shape,
conditioning=None,
callback=None,
normals_sequence=None,
img_callback=None,
quantize_x0=False,
eta=0.,
mask=None,
x0=None,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
verbose=True,
x_T=None,
log_every_t=100,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
# this has to come in the same format as the conditioning, e.g. as encoded tokens
dynamic_threshold=None,
**kwargs):
if conditioning is not None:
if isinstance(conditioning, dict):
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
if cbs != batch_size:
print(
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
)
else:
if conditioning.shape[0] != batch_size:
print(
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
)
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
# sampling
C, H, W = shape
size = (batch_size, C, H, W)
print(f'Data shape for PLMS sampling is {size}')
samples, intermediates = self.plms_sampling(
conditioning,
size,
callback=callback,
img_callback=img_callback,
quantize_denoised=quantize_x0,
mask=mask,
x0=x0,
ddim_use_original_steps=False,
noise_dropout=noise_dropout,
temperature=temperature,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
x_T=x_T,
log_every_t=log_every_t,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
dynamic_threshold=dynamic_threshold,
)
return samples, intermediates
@torch.no_grad()
def plms_sampling(self,
cond,
shape,
x_T=None,
ddim_use_original_steps=False,
callback=None,
timesteps=None,
quantize_denoised=False,
mask=None,
x0=None,
img_callback=None,
log_every_t=100,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
dynamic_threshold=None):
device = self.model.betas.device
b = shape[0]
if x_T is None:
img = torch.randn(shape, device=device)
else:
img = x_T
if timesteps is None:
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
elif timesteps is not None and not ddim_use_original_steps:
subset_end = int(
min(timesteps / self.ddim_timesteps.shape[0], 1)
* self.ddim_timesteps.shape[0]) - 1
timesteps = self.ddim_timesteps[:subset_end]
intermediates = {'x_inter': [img], 'pred_x0': [img]}
time_range = list(reversed(range(
0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
0]
print(f'Running PLMS Sampling with {total_steps} timesteps')
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
old_eps = []
for i, step in enumerate(iterator):
index = total_steps - i - 1
ts = torch.full((b, ), step, device=device, dtype=torch.long)
ts_next = torch.full((b, ),
time_range[min(i + 1,
len(time_range) - 1)],
device=device,
dtype=torch.long)
if mask is not None:
assert x0 is not None
img_orig = self.model.q_sample(
x0, ts) # TODO: deterministic forward pass?
img = img_orig * mask + (1. - mask) * img
outs = self.p_sample_plms(
img,
cond,
ts,
index=index,
use_original_steps=ddim_use_original_steps,
quantize_denoised=quantize_denoised,
temperature=temperature,
noise_dropout=noise_dropout,
score_corrector=score_corrector,
corrector_kwargs=corrector_kwargs,
unconditional_guidance_scale=unconditional_guidance_scale,
unconditional_conditioning=unconditional_conditioning,
old_eps=old_eps,
t_next=ts_next,
dynamic_threshold=dynamic_threshold)
img, pred_x0, e_t = outs
old_eps.append(e_t)
if len(old_eps) >= 4:
old_eps.pop(0)
if callback:
callback(i)
if img_callback:
img_callback(pred_x0, i)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(img)
intermediates['pred_x0'].append(pred_x0)
return img, intermediates
@torch.no_grad()
def p_sample_plms(self,
x,
c,
t,
index,
repeat_noise=False,
use_original_steps=False,
quantize_denoised=False,
temperature=1.,
noise_dropout=0.,
score_corrector=None,
corrector_kwargs=None,
unconditional_guidance_scale=1.,
unconditional_conditioning=None,
old_eps=None,
t_next=None,
dynamic_threshold=None):
b, *_, device = *x.shape, x.device
def get_model_output(x, t):
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
e_t = self.model.apply_model(x, t, c)
else:
x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
c_in = torch.cat([unconditional_conditioning, c])
e_t_uncond, e_t = self.model.apply_model(x_in, t_in,
c_in).chunk(2)
e_t = e_t_uncond + unconditional_guidance_scale * (
e_t - e_t_uncond)
if score_corrector is not None:
assert self.model.parameterization == 'eps'
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
**corrector_kwargs)
return e_t
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod \
if use_original_steps else self.ddim_sqrt_one_minus_alphas
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
def get_x_prev_and_pred_x0(e_t, index):
# select parameters corresponding to the currently considered timestep
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
a_prev = torch.full((b, 1, 1, 1),
alphas_prev[index],
device=device)
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
sqrt_one_minus_alphas[index],
device=device)
# current prediction for x_0
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
if quantize_denoised:
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
if dynamic_threshold is not None:
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
# direction pointing to x_t
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
noise = sigma_t * noise_like(x.shape, device,
repeat_noise) * temperature
if noise_dropout > 0.:
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
return x_prev, pred_x0
e_t = get_model_output(x, t)
if len(old_eps) == 0:
# Pseudo Improved Euler (2nd order)
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
e_t_next = get_model_output(x_prev, t_next)
e_t_prime = (e_t + e_t_next) / 2
elif len(old_eps) == 1:
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (3 * e_t - old_eps[-1]) / 2
elif len(old_eps) == 2:
# 3rd order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
elif len(old_eps) >= 3:
# 4th order Pseudo Linear Multistep (Adams-Bashforth)
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2]
- 9 * old_eps[-3]) / 24
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
return x_prev, pred_x0, e_t

View File

@@ -0,0 +1,25 @@
import numpy as np
import torch
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
dims_to_append = target_dims - x.ndim
if dims_to_append < 0:
raise ValueError(
f'input has {x.ndim} dims but target_dims is {target_dims}, which is less'
)
return x[(..., ) + (None, ) * dims_to_append]
def norm_thresholding(x0, value):
s = append_dims(
x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
return x0 * (value / s)
def spatial_norm_thresholding(x0, value):
# b c h w
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
return x0 * (value / s)
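# Quick sanity check of the helpers above (illustrative only).
def _demo_thresholding():
    x0 = torch.randn(2, 4, 8, 8)
    print(append_dims(torch.ones(2), 4).shape)  # torch.Size([2, 1, 1, 1])
    clamped = norm_thresholding(x0, 1.0)
    # per-sample RMS of `clamped` is now at most 1.0
    print(clamped.flatten(1).pow(2).mean(1).sqrt())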

View File

@@ -0,0 +1,367 @@
import math
# CrossAttn precision handling
import os
from inspect import isfunction
from typing import Any, Optional
import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum, nn
from ...ldm.modules.diffusionmodules.util import checkpoint
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except Exception:
XFORMERS_IS_AVAILBLE = False
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
def exists(val):
return val is not None
def uniq(arr):
return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(nn.Linear(
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x + h_
class CrossAttention(nn.Module):
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head**-0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
(q, k, v))
# force cast to fp32 to avoid overflowing
if _ATTN_PRECISION == 'fp32':
with torch.autocast(enabled=False, device_type='cuda'):
q, k = q.float(), k.float()
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
else:
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
del q, k
if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
sim = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', sim, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
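# Shape walk-through for CrossAttention above: q/k/v are reshaped to (b*heads, n, dim_head),
# `sim` is (b*heads, n_q, n_kv), and the final rearrange folds the heads back into the
# channel dimension before to_out projects back to query_dim.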
class MemoryEfficientCrossAttention(nn.Module):
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
def __init__(self,
query_dim,
context_dim=None,
heads=8,
dim_head=64,
dropout=0.0):
super().__init__()
print(
f'Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using '
f'{heads} heads.')
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
self.attention_op: Optional[Any] = None
def forward(self, x, context=None, mask=None):
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
b, _, _ = q.shape
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(b, t.shape[
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
b * self.heads, t.shape[1], self.dim_head).contiguous(),
(q, k, v),
)
# actually compute the attention, what we cannot get enough of
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op)
if exists(mask):
raise NotImplementedError
out = (
out.unsqueeze(0).reshape(
b, self.heads, out.shape[1],
self.dim_head).permute(0, 2, 1,
3).reshape(b, out.shape[1],
self.heads * self.dim_head))
return self.to_out(out)
class BasicTransformerBlock(nn.Module):
ATTENTION_MODES = {
'softmax': CrossAttention, # vanilla attention
'softmax-xformers': MemoryEfficientCrossAttention
}
def __init__(self,
dim,
n_heads,
d_head,
dropout=0.,
context_dim=None,
gated_ff=True,
checkpoint=True,
disable_self_attn=False):
super().__init__()
attn_mode = 'softmax-xformers' if XFORMERS_IS_AVAILBLE else 'softmax'
assert attn_mode in self.ATTENTION_MODES
attn_cls = self.ATTENTION_MODES[attn_mode]
self.disable_self_attn = disable_self_attn
self.attn1 = attn_cls(
query_dim=dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout,
context_dim=context_dim if self.disable_self_attn else
None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = attn_cls(
query_dim=dim,
context_dim=context_dim,
heads=n_heads,
dim_head=d_head,
dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(),
self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(
self.norm1(x),
context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding) and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape back to an image.
NEW: use_linear for more efficiency instead of the 1x1 convs.
"""
def __init__(self,
in_channels,
n_heads,
d_head,
depth=1,
dropout=0.,
context_dim=None,
disable_self_attn=False,
use_linear=False,
use_checkpoint=True):
super().__init__()
if exists(context_dim) and not isinstance(context_dim, list):
context_dim = [context_dim]
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
if not use_linear:
self.proj_in = nn.Conv2d(
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
else:
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(
inner_dim,
n_heads,
d_head,
dropout=dropout,
context_dim=context_dim[d],
disable_self_attn=disable_self_attn,
checkpoint=use_checkpoint) for d in range(depth)
])
if not use_linear:
self.proj_out = zero_module(
nn.Conv2d(
inner_dim, in_channels, kernel_size=1, stride=1,
padding=0))
else:
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
self.use_linear = use_linear
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
if not isinstance(context, list):
context = [context]
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
if not self.use_linear:
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
if self.use_linear:
x = self.proj_in(x)
for i, block in enumerate(self.transformer_blocks):
x = block(x, context=context[i])
if self.use_linear:
x = self.proj_out(x)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
if not self.use_linear:
x = self.proj_out(x)
return x + x_in
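# Hedged sketch of SpatialTransformer above on a dummy feature map. Which attention
# class is used depends on whether xformers imports; the memory-efficient path
# generally expects CUDA tensors.
def _demo_spatial_transformer():
    st = SpatialTransformer(
        in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768,
        use_linear=True, use_checkpoint=False)
    feat = torch.randn(2, 320, 32, 32)
    ctx = torch.randn(2, 77, 768)  # e.g. text-encoder hidden states
    print(st(feat, context=ctx).shape)  # same shape as `feat`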

View File

@@ -0,0 +1,966 @@
# pytorch_diffusion + derived encoder decoder
import math
from typing import Any, Optional
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from ....ldm.modules.attention import MemoryEfficientCrossAttention
try:
import xformers
import xformers.ops
XFORMERS_IS_AVAILBLE = True
except Exception:
XFORMERS_IS_AVAILBLE = False
print("No module 'xformers'. Proceeding without it.")
def get_timestep_embedding(timesteps, embedding_dim):
"""
Build sinusoidal embeddings.
This matches the implementation in Denoising Diffusion Probabilistic Models
(from Fairseq) and in tensor2tensor, but differs slightly from the description
in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
return emb
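# In short: for timestep t and embedding size d, channel i of the first half is
# sin(t * 10000^(-i / (d//2 - 1))) and the second half holds the matching cos terms,
# giving a geometric ladder of frequencies shared across timesteps.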
def nonlinearity(x):
# swish
return x * torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(
x, scale_factor=2.0, mode='nearest')
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x):
if self.with_conv:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout,
temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + h
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b, c, h, w = q.shape
q = q.reshape(b, c, h * w)
q = q.permute(0, 2, 1) # b,hw,c
k = k.reshape(b, c, h * w) # b,c,hw
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b, c, h * w)
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b, c, h, w)
h_ = self.proj_out(h_)
return x + h_
class MemoryEfficientAttnBlock(nn.Module):
"""
Uses the xformers memory-efficient implementation.
Note: this is a single-head self-attention operation.
"""
#
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.attention_op: Optional[Any] = None
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
B, C, H, W = q.shape
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'),
(q, k, v))
q, k, v = map(
lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(
0, 2, 1, 3).reshape(B * 1, t.shape[1], C).contiguous(),
(q, k, v),
)
out = xformers.ops.memory_efficient_attention(
q, k, v, attn_bias=None, op=self.attention_op)
out = (
out.unsqueeze(0).reshape(B, 1, out.shape[1],
C).permute(0, 2, 1,
3).reshape(B, out.shape[1], C))
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
out = self.proj_out(out)
return x + out
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
def forward(self, x, context=None, mask=None):
b, c, h, w = x.shape
x_in = x
x = rearrange(x, 'b c h w -> b (h w) c')
out = super().forward(x, context=context, mask=mask)
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
# residual on the original (b, c, h, w) tensor; adding the flattened `x`
# would not broadcast against the reshaped `out`
return x_in + out
def make_attn(in_channels, attn_type='vanilla', attn_kwargs=None):
assert attn_type in [
'vanilla', 'vanilla-xformers', 'memory-efficient-cross-attn', 'linear',
'none'
], f'attn_type {attn_type} unknown'
if XFORMERS_IS_AVAILBLE and attn_type == 'vanilla':
attn_type = 'vanilla-xformers'
print(
f"making attention of type '{attn_type}' with {in_channels} in_channels"
)
if attn_type == 'vanilla':
assert attn_kwargs is None
return AttnBlock(in_channels)
elif attn_type == 'vanilla-xformers':
print(
f'building MemoryEfficientAttnBlock with {in_channels} in_channels...'
)
return MemoryEfficientAttnBlock(in_channels)
elif attn_type == 'memory-efficient-cross-attn':
attn_kwargs['query_dim'] = in_channels
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
elif attn_type == 'none':
return nn.Identity(in_channels)
else:
raise NotImplementedError()
class Model(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
use_timestep=True,
use_linear_attn=False,
attn_type='vanilla'):
super().__init__()
if use_linear_attn:
attn_type = 'linear'
self.ch = ch
self.temb_ch = self.ch * 4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch, self.temb_ch),
torch.nn.Linear(self.temb_ch, self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
skip_in = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
if i_block == self.num_res_blocks:
skip_in = ch * in_ch_mult[i_level]
block.append(
ResnetBlock(
in_channels=block_in + skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, x, t=None, context=None):
# assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
double_z=True,
use_linear_attn=False,
attn_type='vanilla',
**ignore_kwargs):
super().__init__()
if use_linear_attn:
attn_type = 'linear'
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
curr_res = resolution
in_ch_mult = (1, ) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions - 1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self,
*,
ch,
out_ch,
ch_mult=(1, 2, 4, 8),
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
in_channels,
resolution,
z_channels,
give_pre_end=False,
tanh_out=False,
use_linear_attn=False,
attn_type='vanilla',
**ignorekwargs):
super().__init__()
if use_linear_attn:
attn_type = 'linear'
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2**(self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
print('Working with z of shape {} = {} dimensions.'.format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_ch, kernel_size=3, stride=1, padding=1)
def forward(self, z):
# assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
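# A minimal usage sketch mirroring the Encoder above, under the same assumed
# hyperparameters: the decoder starts at resolution // 2**(len(ch_mult) - 1)
# and upsamples back to the full resolution.
#
#     dec = Decoder(ch=128, out_ch=3, ch_mult=(1, 2, 4, 4), num_res_blocks=2,
#                   attn_resolutions=[], in_channels=3, resolution=256,
#                   z_channels=4)
#     img = dec(torch.randn(1, 4, 32, 32))  # -> torch.Size([1, 3, 256, 256])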
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([
nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(
in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0),
ResnetBlock(
in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0,
dropout=0.0),
ResnetBlock(
in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0,
dropout=0.0),
nn.Conv2d(2 * in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)
])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1, 2, 3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(self,
in_channels,
out_channels,
ch,
num_res_blocks,
resolution,
ch_mult=(2, 2),
dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2**(self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(
ResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(
block_in, out_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class LatentRescaler(nn.Module):
def __init__(self,
factor,
in_channels,
mid_channels,
out_channels,
depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(
in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
self.res_block1 = nn.ModuleList([
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)
])
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([
ResnetBlock(
in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)
])
self.conv_out = nn.Conv2d(
mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(
x,
size=(int(round(x.shape[2] * self.factor)),
int(round(x.shape[3] * self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
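# A minimal sketch of the rescaling path: conv_in, `depth` residual blocks, a
# nearest-neighbour resize by `factor`, one attention block, `depth` more
# residual blocks and a 1x1 conv_out, so only the spatial size and channel
# count change. The concrete numbers below are illustrative assumptions.
#
#     rs = LatentRescaler(factor=0.5, in_channels=64, mid_channels=64,
#                         out_channels=32, depth=2)
#     y = rs(torch.randn(1, 64, 32, 32))  # -> torch.Size([1, 32, 16, 16])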
class MergedRescaleEncoder(nn.Module):
def __init__(self,
in_channels,
ch,
resolution,
out_ch,
num_res_blocks,
attn_resolutions,
dropout=0.0,
resamp_with_conv=True,
ch_mult=(1, 2, 4, 8),
rescale_factor=1.0,
rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(
in_channels=in_channels,
num_res_blocks=num_res_blocks,
ch=ch,
ch_mult=ch_mult,
z_channels=intermediate_chn,
double_z=False,
resolution=resolution,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=intermediate_chn,
mid_channels=intermediate_chn,
out_channels=out_ch,
depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(self,
z_channels,
out_ch,
resolution,
num_res_blocks,
attn_resolutions,
ch,
ch_mult=(1, 2, 4, 8),
dropout=0.0,
resamp_with_conv=True,
rescale_factor=1.0,
rescale_module_depth=1):
super().__init__()
tmp_chn = z_channels * ch_mult[-1]
self.decoder = Decoder(
out_ch=out_ch,
z_channels=tmp_chn,
attn_resolutions=attn_resolutions,
dropout=dropout,
resamp_with_conv=resamp_with_conv,
in_channels=None,
num_res_blocks=num_res_blocks,
ch_mult=ch_mult,
resolution=resolution,
ch=ch)
self.rescaler = LatentRescaler(
factor=rescale_factor,
in_channels=z_channels,
mid_channels=tmp_chn,
out_channels=tmp_chn,
depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self,
in_size,
out_size,
in_channels,
out_channels,
ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size // in_size)) + 1
factor_up = 1. + (out_size % in_size)
print(
f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}'
)
self.rescaler = LatentRescaler(
factor=factor_up,
in_channels=in_channels,
mid_channels=2 * in_channels,
out_channels=in_channels)
self.decoder = Decoder(
out_ch=out_channels,
resolution=out_size,
z_channels=in_channels,
num_res_blocks=2,
attn_resolutions=[],
in_channels=None,
ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode='bilinear'):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
print(
                f'Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode'
)
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(
in_channels, in_channels, kernel_size=4, stride=2, padding=1)
def forward(self, x, scale_factor=1.0):
if scale_factor == 1.0:
return x
else:
x = torch.nn.functional.interpolate(
x,
mode=self.mode,
align_corners=False,
scale_factor=scale_factor)
return x

View File

@@ -0,0 +1,820 @@
import math
from abc import abstractmethod
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from ....ldm.modules.attention import SpatialTransformer
from ....ldm.modules.diffusionmodules.util import (avg_pool_nd, checkpoint,
conv_nd, linear,
normalization,
timestep_embedding,
zero_module)
from ....ldm.util import exists
# dummy replace
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
# go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
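# A minimal sketch of the dispatch rule above: TimestepBlock children receive
# (x, emb), SpatialTransformer children receive (x, context), and any other
# layer only sees x. The layer sizes below are illustrative assumptions; the
# positional SpatialTransformer arguments follow the calls made later in this
# file.
#
#     block = TimestepEmbedSequential(
#         ResBlock(64, 256, 0.0),
#         SpatialTransformer(64, 8, 8, depth=1, context_dim=768),
#         Downsample(64, True))
#     h = block(x, emb, context)  # x: [N,64,H,W], emb: [N,256], context: [N,77,768]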
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(
dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
mode='nearest')
else:
x = F.interpolate(x, scale_factor=2, mode='nearest')
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(
self.channels, self.out_channels, kernel_size=ks, stride=2)
def forward(self, x):
return self.up(x)
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self,
channels,
use_conv,
dims=2,
out_channels=None,
padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims,
self.channels,
self.out_channels,
3,
stride=stride,
padding=padding)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels
if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(
dims, self.out_channels, self.out_channels, 3, padding=1)),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels,
1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(self._forward, (x, emb), self.parameters(),
self.use_checkpoint)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm:
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
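# With use_scale_shift_norm=True the embedding is projected to 2*out_channels,
# chunked into (scale, shift) and applied FiLM-style as
# out_norm(h) * (1 + scale) + shift before dropout and the zero-initialised
# output conv. A minimal shape sketch under assumed sizes:
#
#     blk = ResBlock(channels=64, emb_channels=256, dropout=0.1,
#                    out_channels=128, use_scale_shift_norm=True)
#     y = blk(th.randn(2, 64, 32, 32), th.randn(2, 256))  # -> [2, 128, 32, 32]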
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(
self._forward, (x, ), self.parameters(), True
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
# return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial**2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
'bct,bcs->bts', q * scale,
k * scale) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum('bts,bcs->bct', weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
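# The legacy layout splits heads first: a qkv tensor of shape [N, H*3*C, T] is
# reshaped to (N*H, 3C, T) and split into q, k, v of (N*H, C, T) each. Scaling
# both q and k by 1/sqrt(sqrt(C)) puts the usual 1/sqrt(C) factor on the logits
# while staying fp16-friendly. A minimal sketch, assuming 4 heads of 32
# channels over 256 tokens:
#
#     attn = QKVAttentionLegacy(n_heads=4)
#     out = attn(th.randn(2, 4 * 3 * 32, 256))  # -> [2, 128, 256]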
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
'bct,bcs->bts',
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum('bts,bcs->bct', weight,
v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None,
disable_middle_self_attn=False,
use_linear_in_transformer=False,
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None
if context_dim is not None:
assert use_spatial_transformer
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError(
'provide num_res_blocks either as an int (globally constant) or '
'as a list/tuple (per-level) with the same length as channel_mult'
)
self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(
map(
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
],
range(len(num_attention_blocks))))
print(
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
f'attention will still not be set.')
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
if isinstance(self.num_classes, int):
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
elif self.num_classes == 'continuous':
print('setting up linear c_adm embedding layer')
self.label_emb = nn.Linear(1, time_embed_dim)
else:
raise ValueError()
self.input_blocks = nn.ModuleList([
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1))
])
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks
) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint))
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
) if resblock_updown else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else
SpatialTransformer( # always uses a self-attn
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disable_middle_self_attn,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
# num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks
) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else
SpatialTransformer(
ch,
num_heads,
dim_head,
depth=transformer_depth,
context_dim=context_dim,
disable_self_attn=disabled_sa,
use_linear=use_linear_in_transformer,
use_checkpoint=use_checkpoint))
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
) if resblock_updown else Upsample(
ch, conv_resample, dims=dims, out_channels=out_ch))
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), 'must specify y if and only if the model is class-conditional'
hs = []
t_emb = timestep_embedding(
timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)

View File

@@ -0,0 +1,103 @@
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from ....ldm.modules.diffusionmodules.util import (extract_into_tensor,
make_beta_schedule)
from ....ldm.util import default
class AbstractLowScaleModel(nn.Module):
# for concatenating a downsampled image to the latent representation
def __init__(self, noise_schedule_config=None):
super(AbstractLowScaleModel, self).__init__()
if noise_schedule_config is not None:
self.register_schedule(**noise_schedule_config)
def register_schedule(self,
beta_schedule='linear',
timesteps=1000,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
betas = make_beta_schedule(
beta_schedule,
timesteps,
linear_start=linear_start,
linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[
0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev',
to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod',
to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod',
to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod',
to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod',
to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
* x_start
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
x_start.shape) * noise)
def forward(self, x):
return x, None
def decode(self, x):
return x
class SimpleImageConcat(AbstractLowScaleModel):
# no noise level conditioning
def __init__(self):
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
self.max_noise_level = 0
def forward(self, x):
# fix to constant noise level
return x, torch.zeros(x.shape[0], device=x.device).long()
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
def __init__(self,
noise_schedule_config,
max_noise_level=1000,
to_cuda=False):
super().__init__(noise_schedule_config=noise_schedule_config)
self.max_noise_level = max_noise_level
def forward(self, x, noise_level=None):
if noise_level is None:
noise_level = torch.randint(
0, self.max_noise_level, (x.shape[0], ),
device=x.device).long()
else:
assert isinstance(noise_level, torch.Tensor)
z = self.q_sample(x, noise_level)
return z, noise_level
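# A minimal sketch of the noise augmentation, assuming a linear schedule with
# the defaults of register_schedule: each low-resolution conditioning image is
# diffused to its own randomly drawn integer noise level, and that level is
# returned for conditioning.
#
#     aug = ImageConcatWithNoiseAugmentation(
#         noise_schedule_config={'beta_schedule': 'linear', 'timesteps': 1000},
#         max_noise_level=250)
#     z, lvl = aug(torch.randn(4, 3, 64, 64))  # z: [4, 3, 64, 64], lvl: [4]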

View File

@@ -0,0 +1,310 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import math
import os
import numpy as np
import torch
import torch.nn as nn
from einops import repeat
from ....ldm.util import instantiate_from_config
def make_beta_schedule(schedule,
n_timestep,
linear_start=1e-4,
linear_end=2e-2,
cosine_s=8e-3):
if schedule == 'linear':
betas = (
torch.linspace(
linear_start**0.5,
linear_end**0.5,
n_timestep,
dtype=torch.float64)**2)
elif schedule == 'cosine':
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
+ cosine_s)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == 'sqrt_linear':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == 'sqrt':
betas = torch.linspace(
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
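# A quick numerical check of the 'linear' branch: it interpolates sqrt(beta)
# linearly between sqrt(linear_start) and sqrt(linear_end) and squares the
# result, so the endpoints reproduce linear_start and linear_end exactly.
#
#     betas = make_beta_schedule('linear', 1000)
#     betas[0], betas[-1]  # -> 1e-4, 2e-2 (up to float rounding)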
def make_ddim_timesteps(ddim_discr_method,
num_ddim_timesteps,
num_ddpm_timesteps,
verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
num_ddim_timesteps))**2).astype(int)
else:
raise NotImplementedError(
f'There is no ddim discretization method called "{ddim_discr_method}"'
)
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
def make_ddim_sampling_parameters(alphacums,
ddim_timesteps,
eta,
verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]]
+ alphacums[ddim_timesteps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
tmp = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
sigmas = eta * np.sqrt(tmp)
if verbose:
print(
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
)
print(
f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
)
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
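# extract_into_tensor gathers one schedule coefficient per batch element and
# reshapes it to (b, 1, 1, ...) so it broadcasts against an image-shaped
# tensor. A minimal sketch with an assumed schedule:
#
#     a = torch.linspace(0, 1, 1000)  # e.g. sqrt_alphas_cumprod
#     t = torch.tensor([0, 999])
#     extract_into_tensor(a, t, (2, 3, 64, 64)).shape  # -> torch.Size([2, 1, 1, 1])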
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
ctx.gpu_autocast_kwargs = {
'enabled': torch.is_autocast_enabled(),
'dtype': torch.get_autocast_gpu_dtype(),
'cache_enabled': torch.is_autocast_cache_enabled()
}
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [
x.detach().requires_grad_(True) for x in ctx.input_tensors
]
with torch.enable_grad(), \
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period)
* torch.arange(start=0, end=half, dtype=torch.float32)
/ half).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat(
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
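# The embedding concatenates cos and sin of t * freq for dim // 2 geometric
# frequencies between 1 and 1/max_period, padding one zero column when dim is
# odd. A minimal sketch:
#
#     emb = timestep_embedding(torch.arange(4), dim=128)  # -> [4, 128]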
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f'unsupported dimensions: {dims}')
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(
c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
def repeat_noise():
        return torch.randn(
            (1, *shape[1:]),
            device=device).repeat(shape[0], *((1, ) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device) # noqa
return repeat_noise() if repeat else noise()

View File

@@ -0,0 +1,93 @@
import numpy as np
import torch
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(
self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(
self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1, 2, 3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar
+ torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
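# kl() with other=None is the closed-form KL to a standard normal, summed over
# the non-batch dims: 0.5 * sum(mean**2 + var - 1 - logvar). A minimal sketch,
# assuming the usual [N, 2*C, H, W] moments layout produced by a double_z
# encoder:
#
#     post = DiagonalGaussianDistribution(torch.zeros(2, 8, 32, 32))
#     post.sample().shape  # -> torch.Size([2, 4, 32, 32])
#     post.kl()            # -> tensor([0., 0.]) for zero mean and unit variance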
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, 'at least one argument must be a Tensor'
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
tmp = ((mean1 - mean2)**2) * torch.exp(-logvar2)
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
+ tmp)

View File

@@ -0,0 +1,87 @@
import torch
from torch import nn
class LitEma(nn.Module):
def __init__(self, model, decay=0.9999, use_num_upates=True):
super().__init__()
if decay < 0.0 or decay > 1.0:
raise ValueError('Decay must be between 0 and 1')
self.m_name2s_name = {}
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
self.register_buffer(
'num_updates',
torch.tensor(0, dtype=torch.int)
if use_num_upates else torch.tensor(-1, dtype=torch.int))
for name, p in model.named_parameters():
if p.requires_grad:
# remove as '.'-character is not allowed in buffers
s_name = name.replace('.', '')
self.m_name2s_name.update({name: s_name})
self.register_buffer(s_name, p.clone().detach().data)
self.collected_params = []
def reset_num_updates(self):
del self.num_updates
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
def forward(self, model):
decay = self.decay
if self.num_updates >= 0:
self.num_updates += 1
tmp = (1 + self.num_updates) / (10 + self.num_updates)
decay = min(self.decay, tmp)
one_minus_decay = 1.0 - decay
with torch.no_grad():
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
sname = self.m_name2s_name[key]
shadow_params[sname] = shadow_params[sname].type_as(
m_param[key])
tmp = shadow_params[sname] - m_param[key]
shadow_params[sname].sub_(one_minus_decay * tmp)
else:
assert key not in self.m_name2s_name
def copy_to(self, model):
m_param = dict(model.named_parameters())
shadow_params = dict(self.named_buffers())
for key in m_param:
if m_param[key].requires_grad:
m_param[key].data.copy_(
shadow_params[self.m_name2s_name[key]].data)
else:
assert key not in self.m_name2s_name
def store(self, parameters):
"""
Save the current parameters for restoring later.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
temporarily stored.
"""
self.collected_params = [param.clone() for param in parameters]
def restore(self, parameters):
"""
Restore the parameters stored with the `store` method.
Useful to validate the model with EMA parameters without affecting the
original optimization process. Store the parameters before the
`copy_to` method. After validation (or model saving), use this to
restore the former parameters.
Args:
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
updated with the stored parameters.
"""
for c_param, param in zip(self.collected_params, parameters):
param.data.copy_(c_param.data)
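# A minimal usage sketch: update the shadow weights after each optimizer step,
# swap them in for evaluation, then restore the raw training weights. The
# `model` and `optimizer` names are placeholders.
#
#     ema = LitEma(model)
#     optimizer.step()
#     ema(model)                                          # EMA update
#     ema.store(model.parameters()); ema.copy_to(model)   # evaluate with EMA weights
#     ema.restore(model.parameters())                     # back to raw weights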

View File

@@ -0,0 +1,371 @@
import os
import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
T5Tokenizer)
from ....dinov2 import hubconf
from ....ldm.util import count_params
class LayerNormFp32(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(
x.to(torch.float32), self.normalized_shape, self.weight, self.bias,
self.eps)
return x.to(orig_type)
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
self.eps)
return x.to(orig_type)
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
self.n_classes = n_classes
self.ucg_rate = ucg_rate
def forward(self, batch, key=None, disable_dropout=False):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
if self.ucg_rate > 0. and not disable_dropout:
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
c = mask * c + (1 - mask) * torch.ones_like(c) * (
self.n_classes - 1)
c = c.long()
c = self.embedding(c)
return c
def get_unconditional_conditioning(self, bs, device='cuda'):
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
uc = torch.ones((bs, ), device=device) * uc_class
uc = {self.key: uc}
return uc
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self,
version='google/t5-v1_1-large',
device='cuda',
max_length=77,
freeze=True
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version)
self.transformer = T5EncoderModel.from_pretrained(version)
self.device = device
self.max_length = max_length # TODO: typical value?
if freeze:
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
LAYERS = ['last', 'pooled', 'hidden']
def __init__(self,
version='openai/clip-vit-large-patch14',
device='cuda',
max_length=77,
freeze=True,
layer='last',
layer_idx=None): # clip-vit-base-patch32
super().__init__()
assert layer in self.LAYERS
self.tokenizer = CLIPTokenizer.from_pretrained(version)
self.transformer = CLIPTextModel.from_pretrained(version)
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
self.layer_idx = layer_idx
if layer == 'hidden':
assert layer_idx is not None
assert 0 <= abs(layer_idx) <= 12
def freeze(self):
self.transformer = self.transformer.eval()
# self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(
text,
truncation=True,
max_length=self.max_length,
return_length=True,
return_overflowing_tokens=False,
padding='max_length',
return_tensors='pt')
tokens = batch_encoding['input_ids'].to(self.device)
outputs = self.transformer(
input_ids=tokens, output_hidden_states=self.layer == 'hidden')
if self.layer == 'last':
z = outputs.last_hidden_state
elif self.layer == 'pooled':
z = outputs.pooler_output[:, None, :]
else:
z = outputs.hidden_states[self.layer_idx]
return z
def encode(self, text):
return self(text)
class FrozenOpenCLIPEmbedder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for text
"""
LAYERS = [
# "pooled",
'last',
'penultimate'
]
def __init__(self,
arch='ViT-H-14',
version='laion2b_s32b_b79k',
device='cuda',
max_length=77,
freeze=True,
layer='last'):
super().__init__()
assert layer in self.LAYERS
model, _, _ = open_clip.create_model_and_transforms(
arch, device=torch.device('cpu'), pretrained=version)
del model.visual
self.model = model
self.device = device
self.max_length = max_length
if freeze:
self.freeze()
self.layer = layer
if self.layer == 'last':
self.layer_idx = 0
elif self.layer == 'penultimate':
self.layer_idx = 1
else:
raise NotImplementedError()
def freeze(self):
self.model = self.model.eval()
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
tokens = open_clip.tokenize(text)
z = self.encode_with_transformer(tokens.to(self.device))
return z
def encode_with_transformer(self, text):
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
x = x + self.model.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.model.ln_final(x)
return x
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
for i, r in enumerate(self.model.transformer.resblocks):
if i == len(self.model.transformer.resblocks) - self.layer_idx:
break
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
):
x = checkpoint(r, x, attn_mask)
else:
x = r(x, attn_mask=attn_mask)
return x
def encode(self, text):
return self(text)
class FrozenCLIPT5Encoder(AbstractEncoder):
def __init__(self,
clip_version='openai/clip-vit-large-patch14',
t5_version='google/t5-v1_1-xl',
device='cuda',
clip_max_length=77,
t5_max_length=77):
super().__init__()
self.clip_encoder = FrozenCLIPEmbedder(
clip_version, device, max_length=clip_max_length)
self.t5_encoder = FrozenT5Embedder(
t5_version, device, max_length=t5_max_length)
print(
f'{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, '
f'{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.'
)
def encode(self, text):
return self(text)
def forward(self, text):
clip_z = self.clip_encoder.encode(text)
t5_z = self.t5_encoder.encode(text)
return [clip_z, t5_z]
class FrozenOpenCLIPImageEncoder(AbstractEncoder):
"""
Uses the OpenCLIP transformer encoder for image
"""
def __init__(self,
arch='ViT-H-14',
version='laion2b_s32b_b79k',
device='cuda',
freeze=True):
super().__init__()
model, _, preprocess = open_clip.create_model_and_transforms(
arch, device=torch.device('cpu'), pretrained=version)
del model.transformer
self.model = model
self.model.visual.output_tokens = True
self.device = device
if freeze:
self.freeze()
self.image_mean = torch.tensor(
[0.48145466, 0.4578275,
0.40821073]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
self.image_std = torch.tensor(
[0.26862954, 0.26130258,
0.275777]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
self.projector_token = nn.Linear(1280, 1024)
self.projector_embed = nn.Linear(1024, 1024)
def freeze(self):
self.model.visual.eval()
for param in self.model.parameters():
param.requires_grad = False
def forward(self, image):
if isinstance(image, list):
image = torch.cat(image, 0)
image = (image.to(self.device) - self.image_mean.to(
self.device)) / self.image_std.to(self.device)
image_features, tokens = self.model.visual(image)
image_features = image_features.unsqueeze(1)
image_features = self.projector_embed(image_features)
tokens = self.projector_token(tokens)
hint = torch.cat([image_features, tokens], 1)
return hint
def encode(self, image):
return self(image)
class FrozenDinoV2Encoder(AbstractEncoder):
"""
Uses the DINOv2 encoder for image
"""
def __init__(self, model_path, device='cuda', freeze=True):
DINOv2_weight_path = model_path
super().__init__()
dinov2 = hubconf.dinov2_vitg14()
state_dict = torch.load(DINOv2_weight_path)
dinov2.load_state_dict(state_dict, strict=False)
self.model = dinov2.to(device)
self.device = device
if freeze:
self.freeze()
self.image_mean = torch.tensor(
[0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
self.image_std = torch.tensor(
[0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
self.projector = nn.Linear(1536, 1024)
def freeze(self):
self.model.eval()
for param in self.model.parameters():
param.requires_grad = False
def forward(self, image):
if isinstance(image, list):
image = torch.cat(image, 0)
image = (image.to(self.device) - self.image_mean.to(
self.device)) / self.image_std.to(self.device)
features = self.model.forward_features(image)
tokens = features['x_norm_patchtokens']
image_features = features['x_norm_clstoken']
image_features = image_features.unsqueeze(1)
hint = torch.cat([image_features, tokens], 1) # 8,257,1024
hint = self.projector(hint)
return hint
def encode(self, image):
return self(image)

View File

@@ -0,0 +1,221 @@
import importlib
from inspect import isfunction
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
from torch import optim
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new('RGB', wh, color='white')
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = '\n'.join(xc[bi][start:start + nc]
for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill='black', font=font)
except UnicodeEncodeError:
print("Can't encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
)
return total_params
def instantiate_from_config(config):
if 'target' not in config:
if config == '__is_first_stage__':
return None
elif config == '__is_unconditional__':
return None
raise KeyError('Expected key `target` to instantiate.')
return get_obj_from_str(config['target'])(**config.get('params', dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit('.', 1)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
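# Illustrative note (not part of the original file): instantiate_from_config turns a
# {'target': ..., 'params': ...} dict into an object by importing the dotted path with
# get_obj_from_str, e.g.
#   cfg = {'target': 'torch.nn.Linear', 'params': {'in_features': 4, 'out_features': 2}}
#   layer = instantiate_from_config(cfg)  # equivalent to torch.nn.Linear(in_features=4, out_features=2)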
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self,
params,
lr=1.e-3,
betas=(0.9, 0.999),
eps=1.e-8,
weight_decay=1.e-2,
amsgrad=False,
ema_decay=0.9999,
ema_power=1.,
param_names=()):
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError('Invalid learning rate: {}'.format(lr))
if not 0.0 <= eps:
raise ValueError('Invalid epsilon value: {}'.format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError('Invalid beta parameter at index 0: {}'.format(
betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError('Invalid beta parameter at index 1: {}'.format(
betas[1]))
if not 0.0 <= weight_decay:
raise ValueError(
'Invalid weight_decay value: {}'.format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
defaults = dict(
lr=lr,
betas=betas,
eps=eps,
weight_decay=weight_decay,
amsgrad=amsgrad,
ema_decay=ema_decay,
ema_power=ema_power,
param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
ema_decay = group['ema_decay']
ema_power = group['ema_power']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError(
'AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(
p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
ema_params_with_grad.append(state['param_exp_avg'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(
params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
for param, ema_param in zip(params_with_grad,
ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(
param.float(), alpha=1 - cur_ema_decay)
return loss
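A minimal usage sketch for the optimizer above (illustrative; step() relies on the private torch.optim._functional.adamw interface, so this assumes a torch version compatible with that call). Besides the usual AdamW update, an exponential moving average of every parameter is kept in state['param_exp_avg']:
import torch
model = torch.nn.Linear(4, 2)
opt = AdamWwithEMAandWings(model.parameters(), lr=1e-3, ema_decay=0.999)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
opt.step()
ema_weight = opt.state[model.weight]['param_exp_avg']  # EMA copy of model.weight, kept in float32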

View File

@@ -0,0 +1,2 @@
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
from . import ldm

View File

@@ -0,0 +1,158 @@
import pickle
import numpy as np
import cv2
from skimage.io import imread
def save_pickle(data, pkl_path):
# os.system('mkdir -p {}'.format(os.path.dirname(pkl_path)))
with open(pkl_path, 'wb') as f:
pickle.dump(data, f)
def read_pickle(pkl_path):
with open(pkl_path, 'rb') as f:
return pickle.load(f)
def draw_epipolar_line(F, img0, img1, pt0, color):
h1,w1=img1.shape[:2]
hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None]
l = F @ hpt
l = l[:, 0]
a, b, c = l[0], l[1], l[2]
pt1 = np.asarray([0, -c / b]).astype(np.int32)
pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32)
img0 = cv2.circle(img0, tuple(pt0.astype(np.int32)), 5, color, 2)
img1 = cv2.line(img1, tuple(pt1), tuple(pt2), color, 2)
return img0, img1
def draw_epipolar_lines(F, img0, img1,num=20):
img0,img1=img0.copy(),img1.copy()
h0, w0, _ = img0.shape
h1, w1, _ = img1.shape
for k in range(num):
color = np.random.randint(0, 255, [3], dtype=np.int32)
color = [int(c) for c in color]
pt = np.random.uniform(0, 1, 2)
pt[0] *= w0
pt[1] *= h0
pt = pt.astype(np.int32)
img0, img1 = draw_epipolar_line(F, img0, img1, pt, color)
return img0, img1
def compute_F(K1, K2, Rt0, Rt1=None):
if Rt1 is None:
R, t = Rt0[:,:3], Rt0[:,3:]
else:
Rt = compute_dR_dt(Rt0,Rt1)
R, t = Rt[:,:3], Rt[:,3:]
A = K1 @ R.T @ t # [3,1]
C = np.asarray([[0,-A[2,0],A[1,0]],
[A[2,0],0,-A[0,0]],
[-A[1,0],A[0,0],0]])
F = (np.linalg.inv(K2)).T @ R @ K1.T @ C
return F
def compute_dR_dt(Rt0, Rt1):
R0, t0 = Rt0[:,:3], Rt0[:,3:]
R1, t1 = Rt1[:,:3], Rt1[:,3:]
dR = np.dot(R1, R0.T)
dt = t1 - np.dot(dR, t0)
return np.concatenate([dR, dt], -1)
def concat_images(img0,img1,vert=False):
if not vert:
h0,h1=img0.shape[0],img1.shape[0],
if h0<h1: img0=cv2.copyMakeBorder(img0,0,h1-h0,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
if h1<h0: img1=cv2.copyMakeBorder(img1,0,h0-h1,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
img = np.concatenate([img0, img1], axis=1)
else:
w0,w1=img0.shape[1],img1.shape[1]
if w0<w1: img0=cv2.copyMakeBorder(img0,0,0,0,w1-w0,borderType=cv2.BORDER_CONSTANT,value=0)
if w1<w0: img1=cv2.copyMakeBorder(img1,0,0,0,w0-w1,borderType=cv2.BORDER_CONSTANT,value=0)
img = np.concatenate([img0, img1], axis=0)
return img
def concat_images_list(*args,vert=False):
if len(args)==1: return args[0]
img_out=args[0]
for img in args[1:]:
img_out=concat_images(img_out,img,vert)
return img_out
def pose_inverse(pose):
R = pose[:,:3].T
t = - R @ pose[:,3:]
return np.concatenate([R,t],-1)
def project_points(pts,RT,K):
pts = np.matmul(pts,RT[:,:3].transpose())+RT[:,3:].transpose()
pts = np.matmul(pts,K.transpose())
dpt = pts[:,2]
mask0 = (np.abs(dpt)<1e-4) & (np.abs(dpt)>0)
if np.sum(mask0)>0: dpt[mask0]=1e-4
mask1=(np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
if np.sum(mask1)>0: dpt[mask1]=-1e-4
pts2d = pts[:,:2]/dpt[:,None]
return pts2d, dpt
def draw_keypoints(img, kps, colors=None, radius=2):
out_img=img.copy()
for pi, pt in enumerate(kps):
pt = np.round(pt).astype(np.int32)
if colors is not None:
color=[int(c) for c in colors[pi]]
cv2.circle(out_img, tuple(pt), radius, color, -1)
else:
cv2.circle(out_img, tuple(pt), radius, (0,255,0), -1)
return out_img
def output_points(fn,pts,colors=None):
with open(fn, 'w') as f:
for pi, pt in enumerate(pts):
f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
if colors is not None:
f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}')
f.write('\n')
DEPTH_MAX, DEPTH_MIN = 2.4, 0.6
DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63
def read_depth_objaverse(depth_fn):
depth = imread(depth_fn)
depth = depth.astype(np.float32) / 65535 * (DEPTH_MAX-DEPTH_MIN) + DEPTH_MIN
mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX)
return depth, mask
def mask_depth_to_pts(mask,depth,K,rgb=None):
hs,ws=np.nonzero(mask)
depth=depth[hs,ws]
pts=np.asarray([ws,hs,depth],np.float32).transpose()
pts[:,:2]*=pts[:,2:]
if rgb is not None:
return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs,ws]
else:
return np.dot(pts, np.linalg.inv(K).transpose())
def transform_points_pose(pts, pose):
R, t = pose[:, :3], pose[:, 3]
if len(pts.shape)==1:
return (R @ pts[:,None] + t[:,None])[:,0]
return pts @ R.T + t[None,:]
def pose_apply(pose,pts):
return transform_points_pose(pts, pose)
def downsample_gaussian_blur(img, ratio):
sigma = (1 / ratio) / 3
# ksize=np.ceil(2*sigma)
ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1))
ksize = ksize + 1 if ksize % 2 == 0 else ksize
img = cv2.GaussianBlur(img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
return img
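A minimal sketch of the projection helpers above (illustrative, with made-up camera parameters): project_points applies a 3x4 world-to-camera pose [R|t] followed by the 3x3 intrinsics K and returns pixel coordinates plus per-point depth.
import numpy as np
pts = np.random.rand(100, 3).astype(np.float32)  # points in world coordinates
RT = np.concatenate([np.eye(3), [[0.0], [0.0], [1.5]]], -1)  # camera 1.5 units in front of the origin
K = np.array([[280.0, 0.0, 128.0], [0.0, 280.0, 128.0], [0.0, 0.0, 1.0]])
pts2d, depth = project_points(pts, RT, K)  # pts2d: (100, 2) pixel coordinates, depth: (100,)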

View File

@@ -0,0 +1,443 @@
import torch
import numpy as np
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
from packaging import version
from torch.optim.lr_scheduler import LambdaLR
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import Encoder, Decoder
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
from modelscope.models.cv.image_to_3d.ldm.modules.ema import LitEma  # assumed path, mirroring the other ldm imports; needed when use_ema=True
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
class VQModel(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
n_embed,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
batch_resize_range=None,
scheduler_config=None,
lr_g_factor=1.0,
remap=None,
sane_index_shape=False, # tell vector quantizer to return indices as bhw
use_ema=False
):
super().__init__()
self.embed_dim = embed_dim
self.n_embed = n_embed
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
remap=remap,
sane_index_shape=sane_index_shape)
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
self.batch_resize_range = batch_resize_range
if self.batch_resize_range is not None:
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
self.use_ema = use_ema
if self.use_ema:
self.model_ema = LitEma(self)
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
self.scheduler_config = scheduler_config
self.lr_g_factor = lr_g_factor
@contextmanager
def ema_scope(self, context=None):
if self.use_ema:
self.model_ema.store(self.parameters())
self.model_ema.copy_to(self)
if context is not None:
print(f"{context}: Switched to EMA weights")
try:
yield None
finally:
if self.use_ema:
self.model_ema.restore(self.parameters())
if context is not None:
print(f"{context}: Restored training weights")
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
missing, unexpected = self.load_state_dict(sd, strict=False)
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
if len(missing) > 0:
print(f"Missing Keys: {missing}")
print(f"Unexpected Keys: {unexpected}")
def on_train_batch_end(self, *args, **kwargs):
if self.use_ema:
self.model_ema(self)
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
quant, emb_loss, info = self.quantize(h)
return quant, emb_loss, info
def encode_to_prequant(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, quant):
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
def decode_code(self, code_b):
quant_b = self.quantize.embed_code(code_b)
dec = self.decode(quant_b)
return dec
def forward(self, input, return_pred_indices=False):
quant, diff, (_,_,ind) = self.encode(input)
dec = self.decode(quant)
if return_pred_indices:
return dec, diff, ind
return dec, diff
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
if self.batch_resize_range is not None:
lower_size = self.batch_resize_range[0]
upper_size = self.batch_resize_range[1]
if self.global_step <= 4:
# do the first few batches with max size to avoid later oom
new_resize = upper_size
else:
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
if new_resize != x.shape[2]:
x = F.interpolate(x, size=new_resize, mode="bicubic")
x = x.detach()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
# https://github.com/pytorch/pytorch/issues/37142
# try not to fool the heuristics
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
if optimizer_idx == 0:
# autoencode
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train",
predicted_indices=ind)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return aeloss
if optimizer_idx == 1:
# discriminator
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
return discloss
def validation_step(self, batch, batch_idx):
log_dict = self._validation_step(batch, batch_idx)
with self.ema_scope():
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
return log_dict
def _validation_step(self, batch, batch_idx, suffix=""):
x = self.get_input(batch, self.image_key)
xrec, qloss, ind = self(x, return_pred_indices=True)
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
self.global_step,
last_layer=self.get_last_layer(),
split="val"+suffix,
predicted_indices=ind
)
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
self.log(f"val{suffix}/rec_loss", rec_loss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
self.log(f"val{suffix}/aeloss", aeloss,
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
if version.parse(pl.__version__) >= version.parse('1.4.0'):
del log_dict_ae[f"val{suffix}/rec_loss"]
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr_d = self.learning_rate
lr_g = self.lr_g_factor*self.learning_rate
print("lr_d", lr_d)
print("lr_g", lr_g)
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quantize.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr_g, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr_d, betas=(0.5, 0.9))
if self.scheduler_config is not None:
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [
{
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
{
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
'interval': 'step',
'frequency': 1
},
]
return [opt_ae, opt_disc], scheduler
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if only_inputs:
log["inputs"] = x
return log
xrec, _ = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["inputs"] = x
log["reconstructions"] = xrec
if plot_ema:
with self.ema_scope():
xrec_ema, _ = self(x)
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
log["reconstructions_ema"] = xrec_ema
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class VQModelInterface(VQModel):
def __init__(self, embed_dim, *args, **kwargs):
super().__init__(embed_dim=embed_dim, *args, **kwargs)
self.embed_dim = embed_dim
def encode(self, x):
h = self.encoder(x)
h = self.quant_conv(h)
return h
def decode(self, h, force_not_quantize=False):
# also go through quantization layer
if not force_not_quantize:
quant, emb_loss, info = self.quantize(h)
else:
quant = h
quant = self.post_quant_conv(quant)
dec = self.decoder(quant)
return dec
class AutoencoderKL(pl.LightningModule):
def __init__(self,
ddconfig,
lossconfig,
embed_dim,
ckpt_path=None,
ignore_keys=[],
image_key="image",
colorize_nlabels=None,
monitor=None,
):
super().__init__()
self.image_key = image_key
self.encoder = Encoder(**ddconfig)
self.decoder = Decoder(**ddconfig)
self.loss = instantiate_from_config(lossconfig)
assert ddconfig["double_z"]
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
self.embed_dim = embed_dim
if colorize_nlabels is not None:
assert type(colorize_nlabels)==int
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
if monitor is not None:
self.monitor = monitor
if ckpt_path is not None:
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
def init_from_ckpt(self, path, ignore_keys=list()):
sd = torch.load(path, map_location="cpu")["state_dict"]
keys = list(sd.keys())
for k in keys:
for ik in ignore_keys:
if k.startswith(ik):
print("Deleting key {} from state_dict.".format(k))
del sd[k]
self.load_state_dict(sd, strict=False)
print(f"Restored from {path}")
def encode(self, x):
h = self.encoder(x)
moments = self.quant_conv(h)
posterior = DiagonalGaussianDistribution(moments)
return posterior
def decode(self, z):
z = self.post_quant_conv(z)
dec = self.decoder(z)
return dec
def forward(self, input, sample_posterior=True):
posterior = self.encode(input)
if sample_posterior:
z = posterior.sample()
else:
z = posterior.mode()
dec = self.decode(z)
return dec, posterior
def get_input(self, batch, k):
x = batch[k]
if len(x.shape) == 3:
x = x[..., None]
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
return x
def training_step(self, batch, batch_idx, optimizer_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
if optimizer_idx == 0:
# train encoder+decoder+logvar
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return aeloss
if optimizer_idx == 1:
# train the discriminator
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
last_layer=self.get_last_layer(), split="train")
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
return discloss
def validation_step(self, batch, batch_idx):
inputs = self.get_input(batch, self.image_key)
reconstructions, posterior = self(inputs)
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
last_layer=self.get_last_layer(), split="val")
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
last_layer=self.get_last_layer(), split="val")
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
self.log_dict(log_dict_ae)
self.log_dict(log_dict_disc)
return self.log_dict
def configure_optimizers(self):
lr = self.learning_rate
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
list(self.decoder.parameters())+
list(self.quant_conv.parameters())+
list(self.post_quant_conv.parameters()),
lr=lr, betas=(0.5, 0.9))
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
lr=lr, betas=(0.5, 0.9))
return [opt_ae, opt_disc], []
def get_last_layer(self):
return self.decoder.conv_out.weight
@torch.no_grad()
def log_images(self, batch, only_inputs=False, **kwargs):
log = dict()
x = self.get_input(batch, self.image_key)
x = x.to(self.device)
if not only_inputs:
xrec, posterior = self(x)
if x.shape[1] > 3:
# colorize with random projection
assert xrec.shape[1] > 3
x = self.to_rgb(x)
xrec = self.to_rgb(xrec)
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
log["reconstructions"] = xrec
log["inputs"] = x
return log
def to_rgb(self, x):
assert self.image_key == "segmentation"
if not hasattr(self, "colorize"):
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
x = F.conv2d(x, weight=self.colorize)
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
return x
class IdentityFirstStage(torch.nn.Module):
def __init__(self, *args, vq_interface=False, **kwargs):
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
super().__init__()
def encode(self, x, *args, **kwargs):
return x
def decode(self, x, *args, **kwargs):
return x
def quantize(self, x, *args, **kwargs):
if self.vq_interface:
return x, None, [None, None, None]
return x
def forward(self, x, *args, **kwargs):
return x
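A minimal construction sketch for AutoencoderKL (illustrative, randomly initialized, no checkpoint loaded); the ddconfig below mirrors the first-stage configuration used by the SyncDreamer model later in this commit:
import torch
ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3, ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
ae = AutoencoderKL(ddconfig, lossconfig={'target': 'torch.nn.Identity'}, embed_dim=4)
x = torch.randn(1, 3, 256, 256)
posterior = ae.encode(x)  # DiagonalGaussianDistribution over a (1, 4, 32, 32) latent
recon, _ = ae(x)  # decode(sample) round trip back to (1, 3, 256, 256)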

View File

@@ -0,0 +1,673 @@
from pathlib import Path
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.io import imsave
from torch.optim.lr_scheduler import LambdaLR
from tqdm import tqdm
from modelscope.models.cv.image_to_3d.ldm.base_utils import read_pickle, concat_images_list
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_utils import get_warp_coordinates, create_target_volume
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_network import NoisyTargetViewEncoder, SpatialTime3DNet, FrustumTV3DNet
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_ddim_timesteps, timestep_embedding
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.modules import FrozenCLIPImageEmbedder
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
def disable_training_module(module: nn.Module):
module = module.eval()
module.train = disabled_train
for para in module.parameters():
para.requires_grad = False
return module
def repeat_to_batch(tensor, B, VN):
t_shape = tensor.shape
ones = [1 for _ in range(len(t_shape)-1)]
tensor_new = tensor.view(B,1,*t_shape[1:]).repeat(1,VN,*ones).view(B*VN,*t_shape[1:])
return tensor_new
class UNetWrapper(nn.Module):
def __init__(self, diff_model_config, drop_conditions=False, drop_scheme='default', use_zero_123=True):
super().__init__()
self.diffusion_model = instantiate_from_config(diff_model_config)
self.drop_conditions = drop_conditions
self.drop_scheme=drop_scheme
self.use_zero_123 = use_zero_123
def drop(self, cond, mask):
shape = cond.shape
B = shape[0]
cond = mask.view(B,*[1 for _ in range(len(shape)-1)]) * cond
return cond
def get_trainable_parameters(self):
return self.diffusion_model.get_trainable_parameters()
def get_drop_scheme(self, B, device):
if self.drop_scheme=='default':
random = torch.rand(B, dtype=torch.float32, device=device)
drop_clip = (random > 0.15) & (random <= 0.2)
drop_volume = (random > 0.1) & (random <= 0.15)
drop_concat = (random > 0.05) & (random <= 0.1)
drop_all = random <= 0.05
else:
raise NotImplementedError
return drop_clip, drop_volume, drop_concat, drop_all
def forward(self, x, t, clip_embed, volume_feats, x_concat, is_train=False):
"""
@param x: B,4,H,W
@param t: B,
@param clip_embed: B,M,768
@param volume_feats: B,C,D,H,W
@param x_concat: B,C,H,W
@param is_train:
@return:
"""
if self.drop_conditions and is_train:
B = x.shape[0]
drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme(B, x.device)
clip_mask = 1.0 - (drop_clip | drop_all).float()
clip_embed = self.drop(clip_embed, clip_mask)
volume_mask = 1.0 - (drop_volume | drop_all).float()
for k, v in volume_feats.items():
volume_feats[k] = self.drop(v, mask=volume_mask)
concat_mask = 1.0 - (drop_concat | drop_all).float()
x_concat = self.drop(x_concat, concat_mask)
if self.use_zero_123:
# zero123 does not multiply this when encoding, maybe a bug for zero123
first_stage_scale_factor = 0.18215
x_concat_ = x_concat * 1.0
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
else:
x_concat_ = x_concat
x = torch.cat([x, x_concat_], 1)
pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats)
return pred
def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scale):
x_ = torch.cat([x] * 2, 0)
t_ = torch.cat([t] * 2, 0)
clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed)], 0)
v_ = {}
for k, v in volume_feats.items():
v_[k] = torch.cat([v, torch.zeros_like(v)], 0)
x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
if self.use_zero_123:
# zero123 does not multiply this when encoding, maybe a bug for zero123
first_stage_scale_factor = 0.18215
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
x_ = torch.cat([x_, x_concat_], 1)
s, s_uc = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(2)
s = s_uc + unconditional_scale * (s - s_uc)
return s
class SpatialVolumeNet(nn.Module):
def __init__(self, time_dim, view_dim, view_num,
input_image_size=256, frustum_volume_depth=48,
spatial_volume_size=32, spatial_volume_length=0.5,
frustum_volume_length=0.86603 # sqrt(3)/2
):
super().__init__()
self.target_encoder = NoisyTargetViewEncoder(time_dim, view_dim, output_dim=16)
self.spatial_volume_feats = SpatialTime3DNet(input_dim=16 * view_num, time_dim=time_dim, dims=(64, 128, 256, 512))
self.frustum_volume_feats = FrustumTV3DNet(64, time_dim, view_dim, dims=(64, 128, 256, 512))
self.frustum_volume_length = frustum_volume_length
self.input_image_size = input_image_size
self.spatial_volume_size = spatial_volume_size
self.spatial_volume_length = spatial_volume_length
self.frustum_volume_size = self.input_image_size // 8
self.frustum_volume_depth = frustum_volume_depth
self.time_dim = time_dim
self.view_dim = view_dim
self.default_origin_depth = 1.5 # the rendered images assume the camera sits 1.5 units from the origin
def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, target_Ks):
"""
@param x: B,N,4,H,W
@param t_embed: B,t_dim
@param v_embed: B,N,v_dim
@param target_poses: N,3,4
@param target_Ks: N,3,3
@return:
"""
B, N, _, H, W = x.shape
V = self.spatial_volume_size
device = x.device
spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
# encode source features
t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
# v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
v_embed_ = v_embed
target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
# extract 2D image features
spatial_volume_feats = []
# project source features
for ni in range(0, N):
pose_source_ = target_poses[:, ni]
K_source_ = target_Ks[:, ni]
x_ = self.target_encoder(x[:, ni], t_embed_[:, ni], v_embed_[:, ni])
C = x_.shape[1]
coords_source = get_warp_coordinates(spatial_volume_verts, x_.shape[-1], self.input_image_size, K_source_, pose_source_).view(B, V, V * V, 2)
unproj_feats_ = F.grid_sample(x_, coords_source, mode='bilinear', padding_mode='zeros', align_corners=True)
unproj_feats_ = unproj_feats_.view(B, C, V, V, V)
spatial_volume_feats.append(unproj_feats_)
spatial_volume_feats = torch.stack(spatial_volume_feats, 1) # B,N,C,V,V,V
N = spatial_volume_feats.shape[1]
spatial_volume_feats = spatial_volume_feats.view(B, N*C, V, V, V)
spatial_volume_feats = self.spatial_volume_feats(spatial_volume_feats, t_embed) # b,64,32,32,32
return spatial_volume_feats
def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, poses, Ks, target_indices):
"""
@param spatial_volume: B,C,V,V,V
@param t_embed: B,t_dim
@param v_embed: B,N,v_dim
@param poses: N,3,4
@param Ks: N,3,3
@param target_indices: B,TN
@return: B*TN,C,H,W
"""
B, TN = target_indices.shape
H, W = self.frustum_volume_size, self.frustum_volume_size
D = self.frustum_volume_depth
V = self.spatial_volume_size
near = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth - self.frustum_volume_length
far = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth + self.frustum_volume_length
target_indices = target_indices.view(B*TN) # B*TN
poses_ = poses[target_indices] # B*TN,3,4
Ks_ = Ks[target_indices] # B*TN,3,4
volume_xyz, volume_depth = create_target_volume(D, self.frustum_volume_size, self.input_image_size, poses_, Ks_, near, far) # B*TN,3 or 1,D,H,W
volume_xyz_ = volume_xyz / self.spatial_volume_length # since the spatial volume is constructed in [-spatial_volume_length,spatial_volume_length]
volume_xyz_ = volume_xyz_.permute(0, 2, 3, 4, 1) # B*TN,D,H,W,3
spatial_volume_ = spatial_volume.unsqueeze(1).repeat(1, TN, 1, 1, 1, 1).view(B * TN, -1, V, V, V)
volume_feats = F.grid_sample(spatial_volume_, volume_xyz_, mode='bilinear', padding_mode='zeros', align_corners=True) # B*TN,C,D,H,W
v_embed_ = v_embed[torch.arange(B)[:,None], target_indices.view(B,TN)].view(B*TN, -1) # B*TN
t_embed_ = t_embed.unsqueeze(1).repeat(1,TN,1).view(B*TN,-1)
volume_feats_dict = self.frustum_volume_feats(volume_feats, t_embed_, v_embed_)
return volume_feats_dict, volume_depth
"""
SyncDreamer is a state-of-the-art novel view synthesis model that generates 16 view-consistent images of an object from a single input view.
Please refer to https://arxiv.org/abs/2309.03453 for more technical details.
"""
class SyncMultiviewDiffusion(pl.LightningModule):
def __init__(self, unet_config, scheduler_config,
finetune_unet=False, finetune_projection=True,
view_num=16, image_size=256,
cfg_scale=3.0, output_num=8, batch_view_num=4,
drop_conditions=False, drop_scheme='default',
clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"):
super().__init__()
self.finetune_unet = finetune_unet
self.finetune_projection = finetune_projection
self.view_num = view_num
self.viewpoint_dim = 4
self.output_num = output_num
self.image_size = image_size
self.batch_view_num = batch_view_num
self.cfg_scale = cfg_scale
self.clip_image_encoder_path = clip_image_encoder_path
self._init_time_step_embedding()
self._init_first_stage()
self._init_schedule()
self._init_multiview()
self._init_clip_image_encoder()
self._init_clip_projection()
self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, self.viewpoint_dim, self.view_num)
self.model = UNetWrapper(unet_config, drop_conditions=drop_conditions, drop_scheme=drop_scheme)
self.scheduler_config = scheduler_config
latent_size = image_size//8
self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size)
def _init_clip_projection(self):
self.cc_projection = nn.Linear(772, 768)
nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
nn.init.zeros_(list(self.cc_projection.parameters())[1])
self.cc_projection.requires_grad_(True)
if not self.finetune_projection:
disable_training_module(self.cc_projection)
def _init_multiview(self):
K, azs, _, _, poses = read_pickle(self.clip_image_encoder_path.replace("ViT-L-14.pt",f'camera-{self.view_num}.pkl'))
default_image_size = 256
ratio = self.image_size/default_image_size
K = np.diag([ratio,ratio,1]) @ K
K = torch.from_numpy(K.astype(np.float32)) # [3,3]
K = K.unsqueeze(0).repeat(self.view_num,1,1) # N,3,3
poses = torch.from_numpy(poses.astype(np.float32)) # N,3,4
self.register_buffer('poses', poses)
self.register_buffer('Ks', K)
azs = (azs + np.pi) % (np.pi * 2) - np.pi # scale to [-pi,pi] and the index=0 has az=0
self.register_buffer('azimuth', torch.from_numpy(azs.astype(np.float32)))
def get_viewpoint_embedding(self, batch_size, elevation_ref):
"""
@param batch_size:
@param elevation_ref: B
@return:
"""
azimuth_input = self.azimuth[0].unsqueeze(0) # 1
azimuth_target = self.azimuth # N
elevation_input = -elevation_ref # note that zero123 uses a negative elevation here
elevation_target = -np.deg2rad(30)
d_e = elevation_target - elevation_input # B
N = self.azimuth.shape[0]
B = batch_size
d_e = d_e.unsqueeze(1).repeat(1, N)
d_a = azimuth_target - azimuth_input # N
d_a = d_a.unsqueeze(0).repeat(B, 1)
d_z = torch.zeros_like(d_a)
embedding = torch.stack([d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1) # B,N,4
return embedding
def _init_first_stage(self):
first_stage_config={
"target": "modelscope.models.cv.image_to_3d.ldm.models.autoencoder.AutoencoderKL",
"params": {
"embed_dim": 4,
"monitor": "val/rec_loss",
"ddconfig":{
"double_z": True,
"z_channels": 4,
"resolution": self.image_size,
"in_channels": 3,
"out_ch": 3,
"ch": 128,
"ch_mult": [1,2,4,4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0
},
"lossconfig": {"target": "torch.nn.Identity"},
}
}
self.first_stage_scale_factor = 0.18215
self.first_stage_model = instantiate_from_config(first_stage_config)
self.first_stage_model = disable_training_module(self.first_stage_model)
def _init_clip_image_encoder(self):
self.clip_image_encoder = FrozenCLIPImageEmbedder(model=self.clip_image_encoder_path)
self.clip_image_encoder = disable_training_module(self.clip_image_encoder)
def _init_schedule(self):
self.num_timesteps = 1000
linear_start = 0.00085
linear_end = 0.0120
num_timesteps = 1000
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2 # T
assert betas.shape[0] == self.num_timesteps
# all in float64 first
alphas = 1. - betas
alphas_cumprod = torch.cumprod(alphas, dim=0) # T
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # T
posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20))
posterior_log_variance_clipped = torch.clamp(posterior_log_variance_clipped, min=-10)
self.register_buffer("betas", betas.float())
self.register_buffer("alphas", alphas.float())
self.register_buffer("alphas_cumprod", alphas_cumprod.float())
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float())
self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float())
self.register_buffer("posterior_variance", posterior_variance.float())
self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped.float())
def _init_time_step_embedding(self):
self.time_embed_dim = 256
self.time_embed = nn.Sequential(
nn.Linear(self.time_embed_dim, self.time_embed_dim),
nn.SiLU(True),
nn.Linear(self.time_embed_dim, self.time_embed_dim),
)
def encode_first_stage(self, x, sample=True):
with torch.no_grad():
posterior = self.first_stage_model.encode(x) # b,4,h//8,w//8
if sample:
return posterior.sample().detach() * self.first_stage_scale_factor
else:
return posterior.mode().detach() * self.first_stage_scale_factor
def decode_first_stage(self, z):
with torch.no_grad():
z = 1. / self.first_stage_scale_factor * z
return self.first_stage_model.decode(z)
def prepare(self, batch):
# encode target
if 'target_image' in batch:
image_target = batch['target_image'].permute(0, 1, 4, 2, 3) # b,n,3,h,w
N = image_target.shape[1]
x = [self.encode_first_stage(image_target[:,ni], True) for ni in range(N)]
x = torch.stack(x, 1) # b,n,4,h//8,w//8
else:
x = None
image_input = batch['input_image'].permute(0, 3, 1, 2)
elevation_input = batch['input_elevation'][:, 0] # b
x_input = self.encode_first_stage(image_input)
input_info = {'image': image_input, 'elevation': elevation_input, 'x': x_input}
with torch.no_grad():
clip_embed = self.clip_image_encoder.encode(image_input)
return x, clip_embed, input_info
def embed_time(self, t):
t_embed = timestep_embedding(t, self.time_embed_dim, repeat_only=False) # B,TED
t_embed = self.time_embed(t_embed) # B,TED
return t_embed
def get_target_view_feats(self, x_input, spatial_volume, clip_embed, t_embed, v_embed, target_index):
"""
@param x_input: B,4,H,W
@param spatial_volume: B,C,V,V,V
@param clip_embed: B,1,768
@param t_embed: B,t_dim
@param v_embed: B,N,v_dim
@param target_index: B,TN
@return:
tensors of size B*TN,*
"""
B, _, H, W = x_input.shape
frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume(spatial_volume, t_embed, v_embed, self.poses, self.Ks, target_index)
# clip
TN = target_index.shape[1]
v_embed_ = v_embed[torch.arange(B)[:,None], target_index].view(B*TN, self.viewpoint_dim) # B*TN,v_dim
clip_embed_ = clip_embed.unsqueeze(1).repeat(1,TN,1,1).view(B*TN,1,768)
clip_embed_ = self.cc_projection(torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1)) # B*TN,1,768
x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, 1).view(B * TN, 4, H, W)
x_concat = x_input_
return clip_embed_, frustum_volume_feats, x_concat
def training_step(self, batch):
B = batch['image'].shape[0]
time_steps = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
x, clip_embed, input_info = self.prepare(batch)
x_noisy, noise = self.add_noise(x, time_steps) # B,N,4,H,W
N = self.view_num
target_index = torch.randint(0, N, (B, 1), device=self.device).long() # B, 1
v_embed = self.get_viewpoint_embedding(B, input_info['elevation']) # N,v_dim
t_embed = self.embed_time(time_steps)
spatial_volume = self.spatial_volume.construct_spatial_volume(x_noisy, t_embed, v_embed, self.poses, self.Ks)
clip_embed, volume_feats, x_concat = self.get_target_view_feats(input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, target_index)
x_noisy_ = x_noisy[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W
noise_predict = self.model(x_noisy_, time_steps, clip_embed, volume_feats, x_concat, is_train=True) # B,4,H,W
noise_target = noise[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W
# loss simple for diffusion
loss_simple = torch.nn.functional.mse_loss(noise_target, noise_predict, reduction='none')
loss = loss_simple.mean()
self.log('sim', loss_simple.mean(), prog_bar=True, logger=True, on_step=True, on_epoch=True, rank_zero_only=True)
# log others
lr = self.optimizers().param_groups[0]['lr']
self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
self.log("step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
return loss
def add_noise(self, x_start, t):
"""
@param x_start: B,*
@param t: B,
@return:
"""
B = x_start.shape[0]
noise = torch.randn_like(x_start) # B,*
sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t] # B,
sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[t] # B
sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)])
sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)])
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
return x_noisy, noise
def sample(self, batch, cfg_scale, batch_view_num, use_ddim=True,
return_inter_results=False, inter_interval=50, inter_view_interval=2):
_, clip_embed, input_info = self.prepare(batch)
if use_ddim:
x_sample, inter = self.ddim.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
else:
raise NotImplementedError
N = x_sample.shape[1]
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
if return_inter_results:
torch.cuda.synchronize()
torch.cuda.empty_cache()
inter = torch.stack(inter['x_inter'], 2) # # B,N,T,C,H,W
B,N,T,C,H,W = inter.shape
inter_results = []
for ni in tqdm(range(0, N, inter_view_interval)):
inter_results_ = []
for ti in range(T):
inter_results_.append(self.decode_first_stage(inter[:, ni, ti]))
inter_results.append(torch.stack(inter_results_, 1)) # B,T,3,H,W
inter_results = torch.stack(inter_results,1) # B,N,T,3,H,W
return x_sample, inter_results
else:
return x_sample
def log_image(self, x_sample, batch, step, output_dir, only_first_row=False):
process = lambda x: ((torch.clip(x, min=-1, max=1).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
B = x_sample.shape[0]
N = x_sample.shape[1]
image_cond = []
for bi in range(B):
img_pr_ = concat_images_list(process(batch['ref_image'][bi]),*[process(x_sample[bi, ni].permute(1, 2, 0)) for ni in range(N)])
img_gt_ = concat_images_list(process(batch['ref_image'][bi]),*[process(batch['image'][bi, ni]) for ni in range(N)])
if not only_first_row or bi==0:
image_cond.append(concat_images_list(img_gt_, img_pr_, vert=True))
else:
image_cond.append(img_pr_)
output_dir = Path(output_dir)
imsave(str(output_dir/f'{step}.jpg'), concat_images_list(*image_cond, vert=True))
@torch.no_grad()
def validation_step(self, batch, batch_idx):
if batch_idx==0 and self.global_rank==0:
self.eval()
step = self.global_step
batch_ = {}
for k, v in batch.items(): batch_[k] = v[:self.output_num]
x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num)
output_dir = Path(self.image_dir) / 'images' / 'val'
output_dir.mkdir(exist_ok=True, parents=True)
self.log_image(x_sample, batch, step, output_dir=output_dir)
def configure_optimizers(self):
lr = self.learning_rate
print(f'setting learning rate to {lr:.4f} ...')
paras = []
if self.finetune_projection:
paras.append({"params": self.cc_projection.parameters(), "lr": lr},)
if self.finetune_unet:
paras.append({"params": self.model.parameters(), "lr": lr},)
else:
paras.append({"params": self.model.get_trainable_parameters(), "lr": lr},)
paras.append({"params": self.time_embed.parameters(), "lr": lr*10.0},)
paras.append({"params": self.spatial_volume.parameters(), "lr": lr*10.0},)
opt = torch.optim.AdamW(paras, lr=lr)
scheduler = instantiate_from_config(self.scheduler_config)
print("Setting up LambdaLR scheduler...")
scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
return [opt], scheduler
class SyncDDIMSampler:
def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., latent_size=32):
super().__init__()
self.model = model
self.ddpm_num_timesteps = model.num_timesteps
self.latent_size = latent_size
self._make_schedule(ddim_num_steps, ddim_discretize, ddim_eta)
self.eta = ddim_eta
def _make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) # DT
ddim_timesteps_ = torch.from_numpy(self.ddim_timesteps.astype(np.int64)) # DT
alphas_cumprod = self.model.alphas_cumprod # T
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
self.ddim_alphas = alphas_cumprod[ddim_timesteps_].double() # DT
self.ddim_alphas_prev = torch.cat([alphas_cumprod[0:1], alphas_cumprod[ddim_timesteps_[:-1]]], 0) # DT
self.ddim_sigmas = ddim_eta * torch.sqrt((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * (1 - self.ddim_alphas / self.ddim_alphas_prev))
self.ddim_alphas_raw = self.model.alphas[ddim_timesteps_].float() # DT
self.ddim_sigmas = self.ddim_sigmas.float()
self.ddim_alphas = self.ddim_alphas.float()
self.ddim_alphas_prev = self.ddim_alphas_prev.float()
self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - self.ddim_alphas).float()
@torch.no_grad()
def denoise_apply_impl(self, x_target_noisy, index, noise_pred, is_step0=False):
"""
@param x_target_noisy: B,N,4,H,W
@param index: index
@param noise_pred: B,N,4,H,W
@param is_step0: bool
@return:
"""
device = x_target_noisy.device
B,N,_,H,W = x_target_noisy.shape
# apply noise
a_t = self.ddim_alphas[index].to(device).float().view(1,1,1,1,1)
a_prev = self.ddim_alphas_prev[index].to(device).float().view(1,1,1,1,1)
sqrt_one_minus_at = self.ddim_sqrt_one_minus_alphas[index].to(device).float().view(1,1,1,1,1)
sigma_t = self.ddim_sigmas[index].to(device).float().view(1,1,1,1,1)
pred_x0 = (x_target_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt()
dir_xt = torch.clamp(1. - a_prev - sigma_t**2, min=1e-7).sqrt() * noise_pred
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
if not is_step0:
noise = sigma_t * torch.randn_like(x_target_noisy)
x_prev = x_prev + noise
return x_prev
@torch.no_grad()
def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False):
"""
@param x_target_noisy: B,N,4,H,W
@param input_info:
@param clip_embed: B,M,768
@param time_steps: B,
@param index: int
@param unconditional_scale:
@param batch_view_num: int
@param is_step0: bool
@return:
"""
x_input, elevation_input = input_info['x'], input_info['elevation']
B, N, C, H, W = x_target_noisy.shape
# construct source data
v_embed = self.model.get_viewpoint_embedding(B, elevation_input) # B,N,v_dim
t_embed = self.model.embed_time(time_steps) # B,t_dim
spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
e_t = []
target_indices = torch.arange(N) # N
for ni in range(0, N, batch_view_num):
x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num]
VN = x_target_noisy_.shape[1]
x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W)
time_steps_ = repeat_to_batch(time_steps, B, VN)
target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
if unconditional_scale!=1.0:
noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
else:
noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
e_t.append(noise.view(B,VN,4,H,W))
e_t = torch.cat(e_t, 1)
x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0)
return x_prev
@torch.no_grad()
def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1):
"""
@param input_info: x, elevation
@param clip_embed: B,M,768
@param unconditional_scale:
@param log_every_t:
@param batch_view_num:
@return:
"""
print(f"unconditional scale {unconditional_scale:.1f}")
C, H, W = 4, self.latent_size, self.latent_size
B = clip_embed.shape[0]
N = self.model.view_num
device = self.model.device
x_target_noisy = torch.randn([B, N, C, H, W], device=device)
timesteps = self.ddim_timesteps
intermediates = {'x_inter': []}
time_range = np.flip(timesteps)
total_steps = timesteps.shape[0]
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
for i, step in enumerate(iterator):
index = total_steps - i - 1 # index in ddim state
time_steps = torch.full((B,), step, device=device, dtype=torch.long)
x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0)
if index % log_every_t == 0 or index == total_steps - 1:
intermediates['x_inter'].append(x_target_noisy)
return x_target_noisy, intermediates
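A minimal sampling sketch (illustrative; it assumes a SyncMultiviewDiffusion instance named model that has already been built from its config and loaded with pretrained weights, and a batch dict whose 'input_image' is a (B, 256, 256, 3) tensor in [-1, 1] with 'input_elevation' of shape (B, 1) in radians):
views = model.sample(batch, cfg_scale=2.0, batch_view_num=4)  # (B, 16, 3, 256, 256) tensor holding the 16 generated views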

View File

@@ -0,0 +1,142 @@
import torch
import torch.nn as nn
from modelscope.models.cv.image_to_3d.ldm.modules.attention import default, zero_module, checkpoint
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import UNetModel
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import timestep_embedding
class DepthAttention(nn.Module):
def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.dim_head = dim_head
self.to_q = nn.Conv2d(query_dim, inner_dim, 1, 1, bias=False)
self.to_k = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
self.to_v = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
if output_bias:
self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1)
else:
self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1, bias=False)
def forward(self, x, context):
"""
@param x: b,f0,h,w
@param context: b,f1,d,h,w
@return:
"""
hn, hd = self.heads, self.dim_head
b, _, h, w = x.shape
b, _, d, h, w = context.shape
q = self.to_q(x).reshape(b,hn,hd,h,w) # b,t,h,w
k = self.to_k(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
v = self.to_v(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
attn = sim.softmax(dim=2)
# b,hn,hd,d,h,w * b,hn,1,d,h,w
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
out = out.reshape(b,hn*hd,h,w)
return self.to_out(out)
class DepthTransformer(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
super().__init__()
inner_dim = n_heads * d_head
self.proj_in = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1),
nn.GroupNorm(8, inner_dim),
nn.SiLU(True),
)
self.proj_context = nn.Sequential(
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
nn.GroupNorm(8, context_dim),
nn.ReLU(True), # ReLU only, so that a zero input yields a zero output
)
self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # cross-attention from the 2D feature map into the 3D context volume
self.proj_out = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context):
x_in = x
x = self.proj_in(x)
context = self.proj_context(context)
x = self.depth_attn(x, context)
x = self.proj_out(x) + x_in
return x
class DepthWiseAttention(UNetModel):
def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs):
super().__init__(*args, **kwargs)
# num_heads = 4
model_channels = kwargs['model_channels']
channel_mult = kwargs['channel_mult']
d0,d1,d2,d3 = volume_dims
# 4
ch = model_channels*channel_mult[2]
self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3)
self.output_conditions=nn.ModuleList()
self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8}
# 8
ch = model_channels*channel_mult[2]
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
# 16
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
ch = model_channels*channel_mult[1]
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
# 32
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
ch = model_channels*channel_mult[0]
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
emb = self.time_embed(t_emb)
h = x.type(self.dtype)
for index, module in enumerate(self.input_blocks):
h = module(h, emb, context)
hs.append(h)
h = self.middle_block(h, emb, context)
h = self.middle_conditions(h, context=source_dict[h.shape[-1]])
for index, module in enumerate(self.output_blocks):
h = torch.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
if index in self.output_b2c:
layer = self.output_conditions[self.output_b2c[index]]
h = layer(h, context=source_dict[h.shape[-1]])
h = h.type(x.dtype)
return self.out(h)
def get_trainable_parameters(self):
paras = [para for para in self.middle_conditions.parameters()] + [para for para in self.output_conditions.parameters()]
return paras
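A minimal shape sketch for DepthAttention (illustrative sizes): it attends a 2D feature map over the depth axis of a 3D context volume and keeps the query's spatial layout.
import torch
attn = DepthAttention(query_dim=64, context_dim=16, heads=4, dim_head=32)
x = torch.randn(2, 64, 32, 32)  # B, C_q, H, W query feature map
ctx = torch.randn(2, 16, 48, 32, 32)  # B, C_ctx, D, H, W context volume
out = attn(x, ctx)  # -> (2, 64, 32, 32)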

View File

@@ -0,0 +1,186 @@
import torch
import torch.nn as nn
class Image2DResBlockWithTV(nn.Module):
def __init__(self, dim, tdim, vdim):
super().__init__()
norm = lambda c: nn.GroupNorm(8, c)
self.time_embed = nn.Conv2d(tdim, dim, 1, 1)
self.view_embed = nn.Conv2d(vdim, dim, 1, 1)
self.conv = nn.Sequential(
norm(dim),
nn.SiLU(True),
nn.Conv2d(dim, dim, 3, 1, 1),
norm(dim),
nn.SiLU(True),
nn.Conv2d(dim, dim, 3, 1, 1),
)
def forward(self, x, t, v):
return x+self.conv(x+self.time_embed(t)+self.view_embed(v))
class NoisyTargetViewEncoder(nn.Module):
def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8):
super().__init__()
self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1)
self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
self.final_out = nn.Sequential(
nn.GroupNorm(8, run_dim),
nn.SiLU(True),
nn.Conv2d(run_dim, output_dim, 3, 1, 1)
)
def forward(self, x, t, v):
B, DT = t.shape
t = t.view(B, DT, 1, 1)
B, DV = v.shape
v = v.view(B, DV, 1, 1)
x = self.init_conv(x)
x = self.out_conv0(x, t, v)
x = self.out_conv1(x, t, v)
x = self.out_conv2(x, t, v)
x = self.final_out(x)
return x
class SpatialUpTimeBlock(nn.Module):
def __init__(self, x_in_dim, t_in_dim, out_dim):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
self.norm = norm_act(x_in_dim)
self.silu = nn.SiLU(True)
self.conv = nn.ConvTranspose3d(x_in_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
def forward(self, x, t):
x = x + self.t_conv(t)
return self.conv(self.silu(self.norm(x)))
class SpatialTimeBlock(nn.Module):
def __init__(self, x_in_dim, t_in_dim, out_dim, stride):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
self.bn = norm_act(x_in_dim)
self.silu = nn.SiLU(True)
self.conv = nn.Conv3d(x_in_dim, out_dim, 3, stride=stride, padding=1)
def forward(self, x, t):
x = x + self.t_conv(t)
return self.conv(self.silu(self.bn(x)))
class SpatialTime3DNet(nn.Module):
def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
super().__init__()
d0, d1, d2, d3 = dims
dt = time_dim
self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32
self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)
self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)
self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)
self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)
self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
self.conv9 = SpatialUpTimeBlock(d1, dt, d0)
def forward(self, x, t):
B, C = t.shape
t = t.view(B, C, 1, 1, 1)
x = self.init_conv(x)
conv0 = self.conv0(x, t)
x = self.conv1(conv0, t)
x = self.conv2_0(x, t)
conv2 = self.conv2_1(x, t)
x = self.conv3(conv2, t)
x = self.conv4_0(x, t)
conv4 = self.conv4_1(x, t)
x = self.conv5(conv4, t)
x = self.conv6_0(x, t)
x = self.conv6_1(x, t)
x = conv4 + self.conv7(x, t)
x = conv2 + self.conv8(x, t)
x = conv0 + self.conv9(x, t)
return x
class FrustumTVBlock(nn.Module):
def __init__(self, x_dim, t_dim, v_dim, out_dim, stride):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
self.bn = norm_act(x_dim)
self.silu = nn.SiLU(True)
self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1)
def forward(self, x, t, v):
x = x + self.t_conv(t) + self.v_conv(v)
return self.conv(self.silu(self.bn(x)))
class FrustumTVUpBlock(nn.Module):
def __init__(self, x_dim, t_dim, v_dim, out_dim):
super().__init__()
norm_act = lambda c: nn.GroupNorm(8, c)
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
self.norm = norm_act(x_dim)
self.silu = nn.SiLU(True)
self.conv = nn.ConvTranspose3d(x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
def forward(self, x, t, v):
x = x + self.t_conv(t) + self.v_conv(v)
return self.conv(self.silu(self.norm(x)))
class FrustumTV3DNet(nn.Module):
def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
super().__init__()
self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32
self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2)
self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1)
self.conv3 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[2], 2)
self.conv4 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[2], 1)
self.conv5 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[3], 2)
self.conv6 = FrustumTVBlock(dims[3], t_dim, v_dim, dims[3], 1)
self.up0 = FrustumTVUpBlock(dims[3], t_dim, v_dim, dims[2])
self.up1 = FrustumTVUpBlock(dims[2], t_dim, v_dim, dims[1])
self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0])
def forward(self, x, t, v):
B,DT = t.shape
t = t.view(B,DT,1,1,1)
B,DV = v.shape
v = v.view(B,DV,1,1,1)
b, _, d, h, w = x.shape
x0 = self.conv0(x)
x1 = self.conv2(self.conv1(x0, t, v), t, v)
x2 = self.conv4(self.conv3(x1, t, v), t, v)
x3 = self.conv6(self.conv5(x2, t, v), t, v)
x2 = self.up0(x3, t, v) + x2
x1 = self.up1(x2, t, v) + x1
x0 = self.up2(x1, t, v) + x0
return {w: x0, w//2: x1, w//4: x2, w//8: x3}
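# --- Minimal shape-check sketch (not part of the original file) --------------------
# The 3D nets above take a feature volume plus per-batch time/view embeddings and
# return either a refined volume (SpatialTime3DNet) or a dict of multi-scale volumes
# keyed by spatial width (FrustumTV3DNet), presumably what the depth-conditioned
# forward above indexes via source_dict[h.shape[-1]]. All sizes below are illustrative
# assumptions (chosen so GroupNorm(8, c) divides every channel count and the stride-2
# down/up path lines up).
if __name__ == '__main__':
    B, D, H, W = 2, 16, 16, 16
    t_emb = torch.randn(B, 256)                 # time embedding
    v_emb = torch.randn(B, 16)                  # viewpoint embedding
    spatial_net = SpatialTime3DNet(time_dim=256, input_dim=128, dims=(32, 64, 128, 256))
    print(spatial_net(torch.randn(B, 128, D, H, W), t_emb).shape)   # [2, 32, 16, 16, 16]
    frustum_net = FrustumTV3DNet(in_dim=8, t_dim=256, v_dim=16, dims=(32, 64, 128, 256))
    multi_scale = frustum_net(torch.randn(B, 8, D, H, W), t_emb, v_emb)
    print({k: tuple(v.shape) for k, v in multi_scale.items()})      # keys: 16, 8, 4, 2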


@@ -0,0 +1,103 @@
import torch
from kornia import create_meshgrid
def project_and_normalize(ref_grid, src_proj, length):
"""
@param ref_grid: b 3 n
@param src_proj: b 4 4
@param length: int
@return: b, n, 2
"""
src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n
div_val = src_grid[:, -1:]
div_val[div_val<1e-4] = 1e-4
src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n)
src_grid[:, 0] = src_grid[:, 0]/((length - 1) / 2) - 1 # scale to -1~1
src_grid[:, 1] = src_grid[:, 1]/((length - 1) / 2) - 1 # scale to -1~1
src_grid = src_grid.permute(0, 2, 1) # (b, n, 2)
return src_grid
def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
"""
@param x_ratio: float
@param y_ratio: float
@param Ks: b,3,3
@param poses: b,3,4
@return:
"""
rfn = Ks.shape[0]
scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device)
scale_m = torch.diag(scale_m)
ref_prj = scale_m[None, :, :] @ Ks @ poses # rfn,3,4
pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device)
pad_vals[:, :, 3] = 1.0
ref_prj = torch.cat([ref_prj, pad_vals], 1) # rfn,4,4
return ref_prj
def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
B, _, D, H, W = volume_xyz.shape
ratio = warp_size / input_size
warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4
warp_coords = project_and_normalize(volume_xyz.view(B,3,D*H*W), warp_proj, warp_size).view(B, D, H, W, 2)
return warp_coords
def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None):
device, dtype = pose_target.device, pose_target.dtype
# compute a depth range on the unit sphere
H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0]
if near is not None and far is not None :
# near, far b,1,h,w
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1
depth_values = depth_values * (far - near) + near # b d h w
depth_values = depth_values.view(B, 1, D, H * W)
else:
near, far = near_far_from_unit_sphere_using_camera_poses(pose_target) # b 1
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
depth_values = depth_values[None,:,None] * (far[:,None,:] - near[:,None,:]) + near[:,None,:] # b d 1
depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H*W)
ratio = volume_size / input_image_size
    # create a grid on the target (reference) view
    # H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0]
    # create the mesh grid; note that "reference" here also means the target view
ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2)
ref_grid = ref_grid.to(device).to(dtype)
ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W)
ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H*W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1) # (B, 3, H*W)
ref_grid = ref_grid.unsqueeze(2) * depth_values # (B, 3, D, H*W)
# unproject to space and transfer to world coordinates.
Ks = K
ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4
ref_proj_inv = torch.inverse(ref_proj) # B,4,4
ref_grid = ref_proj_inv[:,:3,:3] @ ref_grid.view(B,3,D*H*W) + ref_proj_inv[:,:3,3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
return ref_grid.reshape(B,3,D,H,W), depth_values.view(B,1,D,H,W)
def near_far_from_unit_sphere_using_camera_poses(camera_poses):
"""
@param camera_poses: b 3 4
@return:
near: b,1
far: b,1
"""
R_w2c = camera_poses[..., :3, :3] # b 3 3
t_w2c = camera_poses[..., :3, 3:] # b 3 1
camera_origin = -R_w2c.permute(0,2,1) @ t_w2c # b 3 1
# R_w2c.T @ (0,0,1) = z_dir
camera_orient = R_w2c.permute(0,2,1)[...,:3,2:3] # b 3 1
camera_origin, camera_orient = camera_origin[...,0], camera_orient[..., 0] # b 3
a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True) # b 1
b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1
mid = b / a # b 1
near, far = mid - 1.0, mid + 1.0
return near, far
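# --- Worked example (not part of the original file) --------------------------------
# For objects normalized to the unit sphere, near/far are where the camera's viewing
# ray enters and leaves the sphere around the origin. A world-to-camera pose with the
# camera 2 units from the origin and looking at it should therefore give near=1, far=3.
if __name__ == '__main__':
    R = torch.tensor([[[1., 0., 0.], [0., -1., 0.], [0., 0., -1.]]])  # 1,3,3
    t = torch.tensor([[[0.], [0.], [2.]]])                            # 1,3,1
    pose = torch.cat([R, t], dim=-1)                                  # 1,3,4 world-to-camera
    near, far = near_far_from_unit_sphere_using_camera_poses(pose)
    print(near.item(), far.item())  # ~1.0, ~3.0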


@@ -0,0 +1,336 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import checkpoint
def exists(val):
return val is not None
def uniq(arr):
    return {el: True for el in arr}.keys()
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def max_neg_value(t):
return -torch.finfo(t.dtype).max
def init_(tensor):
dim = tensor.shape[-1]
std = 1 / math.sqrt(dim)
tensor.uniform_(-std, std)
return tensor
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
# feedforward
class ConvGEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def Normalize(in_channels):
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
class LinearAttention(nn.Module):
def __init__(self, dim, heads=4, dim_head=32):
super().__init__()
self.heads = heads
hidden_dim = dim_head * heads
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
return self.to_out(out)
class SpatialSelfAttention(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = rearrange(q, 'b c h w -> b (h w) c')
k = rearrange(k, 'b c h w -> b c (h w)')
w_ = torch.einsum('bij,bjk->bik', q, k)
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = rearrange(v, 'b c h w -> b c (h w)')
w_ = rearrange(w_, 'b i j -> b j i')
h_ = torch.einsum('bij,bjk->bik', v, w_)
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
h_ = self.proj_out(h_)
return x+h_
class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
super().__init__()
inner_dim = dim_head * heads
context_dim = default(context_dim, query_dim)
self.scale = dim_head ** -0.5
self.heads = heads
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
)
def forward(self, x, context=None, mask=None):
h = self.heads
q = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
if exists(mask):
mask = mask>0
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)
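# --- Minimal usage sketch (not part of the original file) --------------------------
# CrossAttention attends from a sequence of queries (e.g. flattened image tokens) to
# an external context sequence (e.g. text embeddings); an optional mask zeroes out
# context positions. All sizes below are illustrative assumptions.
if __name__ == '__main__':
    xattn = CrossAttention(query_dim=64, context_dim=512, heads=4, dim_head=16)
    tokens = torch.randn(2, 256, 64)    # b, n_query, query_dim
    ctx = torch.randn(2, 77, 512)       # b, n_context, context_dim
    mask = torch.ones(2, 77)            # >0 = keep, 0 = mask out
    print(xattn(tokens, context=ctx, mask=mask).shape)  # torch.Size([2, 256, 64])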
class BasicSpatialTransformer(nn.Module):
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
super().__init__()
inner_dim = n_heads * d_head
self.proj_in = nn.Sequential(
nn.GroupNorm(8, dim),
nn.Conv2d(dim, inner_dim, kernel_size=1, stride=1, padding=0),
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
)
self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn
self.out_conv = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
nn.Conv2d(inner_dim, inner_dim, 1, 1),
)
self.proj_out = nn.Sequential(
nn.GroupNorm(8, inner_dim),
nn.ReLU(True),
zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context):
# input
b,_,h,w = x.shape
x_in = x
x = self.proj_in(x)
# attention
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
context = rearrange(context, 'b c h w -> b (h w) c').contiguous()
x = self.attn(x, context) + x
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
# output
x = self.out_conv(x) + x
x = self.proj_out(x) + x_in
return x
class BasicTransformerBlock(nn.Module):
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
super().__init__()
self.disable_self_attn = disable_self_attn
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def forward(self, x, context=None):
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
def _forward(self, x, context=None):
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
x = self.attn2(self.norm2(x), context=context) + x
x = self.ff(self.norm3(x)) + x
return x
class ConvFeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Conv2d(dim, inner_dim, 1, 1, 0),
nn.GELU()
) if not glu else ConvGEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Conv2d(inner_dim, dim_out, 1, 1, 0)
)
def forward(self, x):
return self.net(x)
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data.
First, project the input (aka embedding)
and reshape to b, t, d.
Then apply standard transformer action.
Finally, reshape to image
"""
def __init__(self, in_channels, n_heads, d_head,
depth=1, dropout=0., context_dim=None,
disable_self_attn=False):
super().__init__()
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = Normalize(in_channels)
self.proj_in = nn.Conv2d(in_channels,
inner_dim,
kernel_size=1,
stride=1,
padding=0)
self.transformer_blocks = nn.ModuleList(
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
disable_self_attn=disable_self_attn)
for d in range(depth)]
)
self.proj_out = zero_module(nn.Conv2d(inner_dim,
in_channels,
kernel_size=1,
stride=1,
padding=0))
def forward(self, x, context=None):
# note: if no context is given, cross-attention defaults to self-attention
b, c, h, w = x.shape
x_in = x
x = self.norm(x)
x = self.proj_in(x)
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
for block in self.transformer_blocks:
x = block(x, context=context)
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
x = self.proj_out(x)
return x + x_in
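# --- Minimal usage sketch (not part of the original file) --------------------------
# SpatialTransformer flattens a feature map to tokens, runs self-/cross-attention
# blocks against an optional context, and reshapes back; the zero-initialized proj_out
# makes the block behave as an identity at the start of training. Sizes are assumed.
if __name__ == '__main__':
    st = SpatialTransformer(in_channels=64, n_heads=4, d_head=16, depth=1, context_dim=512)
    feat = torch.randn(2, 64, 16, 16)
    ctx = torch.randn(2, 77, 512)
    print(st(feat, context=ctx).shape)  # torch.Size([2, 64, 16, 16])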


@@ -0,0 +1,835 @@
# pytorch_diffusion + derived encoder decoder
import math
import torch
import torch.nn as nn
import numpy as np
from einops import rearrange
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
from modelscope.models.cv.image_to_3d.ldm.modules.attention import LinearAttention
# Needed by FirstStagePostProcessor.encode_with_pretrained below; the module path is
# assumed to mirror the upstream ldm layout.
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
def get_timestep_embedding(timesteps, embedding_dim):
"""
This matches the implementation in Denoising Diffusion Probabilistic Models:
From Fairseq.
Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
"""
assert len(timesteps.shape) == 1
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
emb = emb.to(device=timesteps.device)
emb = timesteps.float()[:, None] * emb[None, :]
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
if embedding_dim % 2 == 1: # zero pad
emb = torch.nn.functional.pad(emb, (0,1,0,0))
return emb
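# --- Worked example (not part of the original file) --------------------------------
# Each timestep is mapped to sin/cos features at geometrically spaced frequencies;
# the first half of the vector holds the sines, the second half the cosines.
if __name__ == '__main__':
    emb = get_timestep_embedding(torch.tensor([0, 10, 999]), embedding_dim=8)
    print(emb.shape)   # torch.Size([3, 8])
    print(emb[0])      # t=0: the four sines are 0, the four cosines are 1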
def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
def Normalize(in_channels, num_groups=32):
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
class Upsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
if self.with_conv:
x = self.conv(x)
return x
class Downsample(nn.Module):
def __init__(self, in_channels, with_conv):
super().__init__()
self.with_conv = with_conv
if self.with_conv:
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=3,
stride=2,
padding=0)
def forward(self, x):
if self.with_conv:
pad = (0,1,0,1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
else:
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
return x
class ResnetBlock(nn.Module):
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
dropout, temb_channels=512):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.norm1 = Normalize(in_channels)
self.conv1 = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if temb_channels > 0:
self.temb_proj = torch.nn.Linear(temb_channels,
out_channels)
self.norm2 = Normalize(out_channels)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
else:
self.nin_shortcut = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x+h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
def __init__(self, in_channels):
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
class AttnBlock(nn.Module):
def __init__(self, in_channels):
super().__init__()
self.in_channels = in_channels
self.norm = Normalize(in_channels)
self.q = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.k = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.v = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
self.proj_out = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=1,
stride=1,
padding=0)
def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
v = self.v(h_)
# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
h_ = self.proj_out(h_)
return x+h_
def make_attn(in_channels, attn_type="vanilla"):
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
if attn_type == "vanilla":
return AttnBlock(in_channels)
elif attn_type == "none":
return nn.Identity(in_channels)
else:
return LinAttnBlock(in_channels)
class Model(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = self.ch*4
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.use_timestep = use_timestep
if self.use_timestep:
# timestep embedding
self.temb = nn.Module()
self.temb.dense = nn.ModuleList([
torch.nn.Linear(self.ch,
self.temb_ch),
torch.nn.Linear(self.temb_ch,
self.temb_ch),
])
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
skip_in = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
if i_block == self.num_res_blocks:
skip_in = ch*in_ch_mult[i_level]
block.append(ResnetBlock(in_channels=block_in+skip_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x, t=None, context=None):
#assert x.shape[2] == x.shape[3] == self.resolution
if context is not None:
# assume aligned context, cat along channel axis
x = torch.cat((x, context), dim=1)
if self.use_timestep:
# timestep embedding
assert t is not None
temb = get_timestep_embedding(t, self.ch)
temb = self.temb.dense[0](temb)
temb = nonlinearity(temb)
temb = self.temb.dense[1](temb)
else:
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](
torch.cat([h, hs.pop()], dim=1), temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
def get_last_layer(self):
return self.conv_out.weight
class Encoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
**ignore_kwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = torch.nn.Conv2d(in_channels,
self.ch,
kernel_size=3,
stride=1,
padding=1)
curr_res = resolution
in_ch_mult = (1,)+tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = nn.ModuleList()
for i_level in range(self.num_resolutions):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch*in_ch_mult[i_level]
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
down = nn.Module()
down.block = block
down.attn = attn
if i_level != self.num_resolutions-1:
down.downsample = Downsample(block_in, resamp_with_conv)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
2*z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# timestep embedding
temb = None
# downsampling
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level].block[i_block](hs[-1], temb)
if len(self.down[i_level].attn) > 0:
h = self.down[i_level].attn[i_block](h)
hs.append(h)
if i_level != self.num_resolutions-1:
hs.append(self.down[i_level].downsample(hs[-1]))
# middle
h = hs[-1]
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# end
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
attn_type="vanilla", **ignorekwargs):
super().__init__()
if use_linear_attn: attn_type = "linear"
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.give_pre_end = give_pre_end
self.tanh_out = tanh_out
# compute in_ch_mult, block_in and curr_res at lowest res
in_ch_mult = (1,)+tuple(ch_mult)
block_in = ch*ch_mult[self.num_resolutions-1]
curr_res = resolution // 2**(self.num_resolutions-1)
self.z_shape = (1,z_channels,curr_res,curr_res)
print("Working with z of shape {} = {} dimensions.".format(
self.z_shape, np.prod(self.z_shape)))
# z to block_in
self.conv_in = torch.nn.Conv2d(z_channels,
block_in,
kernel_size=3,
stride=1,
padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
self.mid.block_2 = ResnetBlock(in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout)
# upsampling
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch*ch_mult[i_level]
for i_block in range(self.num_res_blocks+1):
block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
if curr_res in attn_resolutions:
attn.append(make_attn(block_in, attn_type=attn_type))
up = nn.Module()
up.block = block
up.attn = attn
if i_level != 0:
up.upsample = Upsample(block_in, resamp_with_conv)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_ch,
kernel_size=3,
stride=1,
padding=1)
def forward(self, z):
#assert z.shape[1:] == self.z_shape[1:]
self.last_z_shape = z.shape
# timestep embedding
temb = None
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
h = torch.tanh(h)
return h
class SimpleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, *args, **kwargs):
super().__init__()
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
ResnetBlock(in_channels=in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=2 * in_channels,
out_channels=4 * in_channels,
temb_channels=0, dropout=0.0),
ResnetBlock(in_channels=4 * in_channels,
out_channels=2 * in_channels,
temb_channels=0, dropout=0.0),
nn.Conv2d(2*in_channels, in_channels, 1),
Upsample(in_channels, with_conv=True)])
# end
self.norm_out = Normalize(in_channels)
self.conv_out = torch.nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
for i, layer in enumerate(self.model):
if i in [1,2,3]:
x = layer(x, None)
else:
x = layer(x)
h = self.norm_out(x)
h = nonlinearity(h)
x = self.conv_out(h)
return x
class UpsampleDecoder(nn.Module):
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
ch_mult=(2,2), dropout=0.0):
super().__init__()
# upsampling
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
block_in = in_channels
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.res_blocks = nn.ModuleList()
self.upsample_blocks = nn.ModuleList()
for i_level in range(self.num_resolutions):
res_block = []
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
res_block.append(ResnetBlock(in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout))
block_in = block_out
self.res_blocks.append(nn.ModuleList(res_block))
if i_level != self.num_resolutions - 1:
self.upsample_blocks.append(Upsample(block_in, True))
curr_res = curr_res * 2
# end
self.norm_out = Normalize(block_in)
self.conv_out = torch.nn.Conv2d(block_in,
out_channels,
kernel_size=3,
stride=1,
padding=1)
def forward(self, x):
# upsampling
h = x
for k, i_level in enumerate(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.res_blocks[i_level][i_block](h, None)
if i_level != self.num_resolutions - 1:
h = self.upsample_blocks[k](h)
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
return h
class LatentRescaler(nn.Module):
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
super().__init__()
# residual block, interpolate, residual block
self.factor = factor
self.conv_in = nn.Conv2d(in_channels,
mid_channels,
kernel_size=3,
stride=1,
padding=1)
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.attn = AttnBlock(mid_channels)
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
out_channels=mid_channels,
temb_channels=0,
dropout=0.0) for _ in range(depth)])
self.conv_out = nn.Conv2d(mid_channels,
out_channels,
kernel_size=1,
)
def forward(self, x):
x = self.conv_in(x)
for block in self.res_block1:
x = block(x, None)
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
x = self.attn(x)
for block in self.res_block2:
x = block(x, None)
x = self.conv_out(x)
return x
class MergedRescaleEncoder(nn.Module):
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
attn_resolutions, dropout=0.0, resamp_with_conv=True,
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
intermediate_chn = ch * ch_mult[-1]
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
z_channels=intermediate_chn, double_z=False, resolution=resolution,
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
out_ch=None)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
def forward(self, x):
x = self.encoder(x)
x = self.rescaler(x)
return x
class MergedRescaleDecoder(nn.Module):
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
super().__init__()
tmp_chn = z_channels*ch_mult[-1]
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
ch_mult=ch_mult, resolution=resolution, ch=ch)
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
out_channels=tmp_chn, depth=rescale_module_depth)
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Upsampler(nn.Module):
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
super().__init__()
assert out_size >= in_size
num_blocks = int(np.log2(out_size//in_size))+1
factor_up = 1.+ (out_size % in_size)
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
out_channels=in_channels)
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
attn_resolutions=[], in_channels=None, ch=in_channels,
ch_mult=[ch_mult for _ in range(num_blocks)])
def forward(self, x):
x = self.rescaler(x)
x = self.decoder(x)
return x
class Resize(nn.Module):
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
super().__init__()
self.with_conv = learned
self.mode = mode
if self.with_conv:
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
raise NotImplementedError()
assert in_channels is not None
# no asymmetric padding in torch conv, must do it ourselves
self.conv = torch.nn.Conv2d(in_channels,
in_channels,
kernel_size=4,
stride=2,
padding=1)
def forward(self, x, scale_factor=1.0):
if scale_factor==1.0:
return x
else:
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
return x
class FirstStagePostProcessor(nn.Module):
def __init__(self, ch_mult:list, in_channels,
pretrained_model:nn.Module=None,
reshape=False,
n_channels=None,
dropout=0.,
pretrained_config=None):
super().__init__()
if pretrained_config is None:
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.pretrained_model = pretrained_model
else:
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
self.instantiate_pretrained(pretrained_config)
self.do_reshape = reshape
if n_channels is None:
n_channels = self.pretrained_model.encoder.ch
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
stride=1,padding=1)
blocks = []
downs = []
ch_in = n_channels
for m in ch_mult:
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
ch_in = m * n_channels
downs.append(Downsample(ch_in, with_conv=False))
self.model = nn.ModuleList(blocks)
self.downsampler = nn.ModuleList(downs)
def instantiate_pretrained(self, config):
model = instantiate_from_config(config)
self.pretrained_model = model.eval()
# self.pretrained_model.train = False
for param in self.pretrained_model.parameters():
param.requires_grad = False
@torch.no_grad()
def encode_with_pretrained(self,x):
c = self.pretrained_model.encode(x)
if isinstance(c, DiagonalGaussianDistribution):
c = c.mode()
return c
def forward(self,x):
z_fs = self.encode_with_pretrained(x)
z = self.proj_norm(z_fs)
z = self.proj(z)
z = nonlinearity(z)
for submodel, downmodel in zip(self.model,self.downsampler):
z = submodel(z,temb=None)
z = downmodel(z)
if self.do_reshape:
z = rearrange(z,'b c h w -> b (h w) c')
return z
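# --- Minimal autoencoder sketch (not part of the original file) --------------------
# Encoder maps an image to a 2*z_channels latent (mean and logvar when double_z=True);
# Decoder maps a z_channels latent back to image space. The hyperparameters below are
# small and purely illustrative; they are chosen so that GroupNorm's 32 groups divide
# every channel count.
if __name__ == '__main__':
    enc = Encoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
                  in_channels=3, resolution=32, z_channels=4, double_z=True)
    dec = Decoder(ch=32, out_ch=3, ch_mult=(1, 2), num_res_blocks=1, attn_resolutions=[],
                  in_channels=3, resolution=32, z_channels=4)
    img = torch.randn(1, 3, 32, 32)
    moments = enc(img)                        # 1, 2*z_channels, 16, 16
    mean, logvar = torch.chunk(moments, 2, dim=1)
    print(moments.shape, dec(mean).shape)     # (1, 8, 16, 16) (1, 3, 32, 32)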


@@ -0,0 +1,996 @@
from abc import abstractmethod
from functools import partial
import math
from typing import Iterable
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
checkpoint,
conv_nd,
linear,
avg_pool_nd,
zero_module,
normalization,
timestep_embedding,
)
from modelscope.models.cv.image_to_3d.ldm.modules.attention import SpatialTransformer
from modelscope.models.cv.image_to_3d.ldm.util import exists
# dummy replace
def convert_module_to_f16(x):
pass
def convert_module_to_f32(x):
pass
## go
class AttentionPool2d(nn.Module):
"""
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
"""
def __init__(
self,
spacial_dim: int,
embed_dim: int,
num_heads_channels: int,
output_dim: int = None,
):
super().__init__()
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
self.num_heads = embed_dim // num_heads_channels
self.attention = QKVAttention(self.num_heads)
def forward(self, x):
b, c, *_spatial = x.shape
x = x.reshape(b, c, -1) # NC(HW)
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
x = self.qkv_proj(x)
x = self.attention(x)
x = self.c_proj(x)
return x[:, :, 0]
class TimestepBlock(nn.Module):
"""
Any module where forward() takes timestep embeddings as a second argument.
"""
@abstractmethod
def forward(self, x, emb):
"""
Apply the module to `x` given `emb` timestep embeddings.
"""
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
"""
A sequential module that passes timestep embeddings to the children that
support it as an extra input.
"""
def forward(self, x, emb, context=None):
for layer in self:
if isinstance(layer, TimestepBlock):
x = layer(x, emb)
elif isinstance(layer, SpatialTransformer):
x = layer(x, context)
else:
x = layer(x)
return x
class Upsample(nn.Module):
"""
An upsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
upsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
if use_conv:
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
def forward(self, x):
assert x.shape[1] == self.channels
if self.dims == 3:
x = F.interpolate(
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
)
else:
x = F.interpolate(x, scale_factor=2, mode="nearest")
if self.use_conv:
x = self.conv(x)
return x
class TransposedUpsample(nn.Module):
'Learned 2x upsampling without padding'
def __init__(self, channels, out_channels=None, ks=5):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
def forward(self,x):
return self.up(x)
class Downsample(nn.Module):
"""
A downsampling layer with an optional convolution.
:param channels: channels in the inputs and outputs.
:param use_conv: a bool determining if a convolution is applied.
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
downsampling occurs in the inner-two dimensions.
"""
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.dims = dims
stride = 2 if dims != 3 else (1, 2, 2)
if use_conv:
self.op = conv_nd(
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
)
else:
assert self.channels == self.out_channels
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
def forward(self, x):
assert x.shape[1] == self.channels
return self.op(x)
class ResBlock(TimestepBlock):
"""
A residual block that can optionally change the number of channels.
:param channels: the number of input channels.
:param emb_channels: the number of timestep embedding channels.
:param dropout: the rate of dropout.
:param out_channels: if specified, the number of out channels.
:param use_conv: if True and out_channels is specified, use a spatial
convolution instead of a smaller 1x1 convolution to change the
channels in the skip connection.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param use_checkpoint: if True, use gradient checkpointing on this module.
:param up: if True, use this block for upsampling.
:param down: if True, use this block for downsampling.
"""
def __init__(
self,
channels,
emb_channels,
dropout,
out_channels=None,
use_conv=False,
use_scale_shift_norm=False,
dims=2,
use_checkpoint=False,
up=False,
down=False,
):
super().__init__()
self.channels = channels
self.emb_channels = emb_channels
self.dropout = dropout
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_checkpoint = use_checkpoint
self.use_scale_shift_norm = use_scale_shift_norm
self.in_layers = nn.Sequential(
normalization(channels),
nn.SiLU(),
conv_nd(dims, channels, self.out_channels, 3, padding=1),
)
self.updown = up or down
if up:
self.h_upd = Upsample(channels, False, dims)
self.x_upd = Upsample(channels, False, dims)
elif down:
self.h_upd = Downsample(channels, False, dims)
self.x_upd = Downsample(channels, False, dims)
else:
self.h_upd = self.x_upd = nn.Identity()
self.emb_layers = nn.Sequential(
nn.SiLU(),
linear(
emb_channels,
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
),
)
self.out_layers = nn.Sequential(
normalization(self.out_channels),
nn.SiLU(),
nn.Dropout(p=dropout),
zero_module(
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
),
)
if self.out_channels == channels:
self.skip_connection = nn.Identity()
elif use_conv:
self.skip_connection = conv_nd(
dims, channels, self.out_channels, 3, padding=1
)
else:
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
def forward(self, x, emb):
"""
Apply the block to a Tensor, conditioned on a timestep embedding.
:param x: an [N x C x ...] Tensor of features.
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
:return: an [N x C x ...] Tensor of outputs.
"""
return checkpoint(
self._forward, (x, emb), self.parameters(), self.use_checkpoint
)
def _forward(self, x, emb):
if self.updown:
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
h = in_rest(x)
h = self.h_upd(h)
x = self.x_upd(x)
h = in_conv(h)
else:
h = self.in_layers(x)
emb_out = self.emb_layers(emb).type(h.dtype)
while len(emb_out.shape) < len(h.shape):
emb_out = emb_out[..., None]
if self.use_scale_shift_norm: # False
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
scale, shift = th.chunk(emb_out, 2, dim=1)
h = out_norm(h) * (1 + scale) + shift
h = out_rest(h)
else:
h = h + emb_out
h = self.out_layers(h)
return self.skip_connection(x) + h
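# --- Minimal usage sketch (not part of the original file) --------------------------
# With use_scale_shift_norm=True the timestep embedding is projected to a scale and a
# shift that modulate the normalized features (FiLM-style); otherwise it is simply
# added to them. Sizes below are illustrative (divisible by 32 for the GroupNorm used
# by normalization()).
if __name__ == '__main__':
    block = ResBlock(channels=64, emb_channels=256, dropout=0.0,
                     out_channels=64, use_scale_shift_norm=True)
    x, emb = th.randn(2, 64, 16, 16), th.randn(2, 256)
    print(block(x, emb).shape)  # torch.Size([2, 64, 16, 16])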
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other.
Originally ported from here, but adapted to the N-d case.
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
"""
def __init__(
self,
channels,
num_heads=1,
num_head_channels=-1,
use_checkpoint=False,
use_new_attention_order=False,
):
super().__init__()
self.channels = channels
if num_head_channels == -1:
self.num_heads = num_heads
else:
assert (
channels % num_head_channels == 0
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
self.num_heads = channels // num_head_channels
self.use_checkpoint = use_checkpoint
self.norm = normalization(channels)
self.qkv = conv_nd(1, channels, channels * 3, 1)
if use_new_attention_order:
# split qkv before split heads
self.attention = QKVAttention(self.num_heads)
else:
# split heads before split qkv
self.attention = QKVAttentionLegacy(self.num_heads)
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
def forward(self, x):
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
#return pt_checkpoint(self._forward, x) # pytorch
def _forward(self, x):
b, c, *spatial = x.shape
x = x.reshape(b, c, -1)
qkv = self.qkv(self.norm(x))
h = self.attention(qkv)
h = self.proj_out(h)
return (x + h).reshape(b, c, *spatial)
def count_flops_attn(model, _x, y):
"""
A counter for the `thop` package to count the operations in an
attention operation.
Meant to be used like:
macs, params = thop.profile(
model,
inputs=(inputs, timestamps),
custom_ops={QKVAttention: QKVAttention.count_flops},
)
"""
b, c, *spatial = y[0].shape
num_spatial = int(np.prod(spatial))
# We perform two matmuls with the same number of ops.
# The first computes the weight matrix, the second computes
# the combination of the value vectors.
matmul_ops = 2 * b * (num_spatial ** 2) * c
model.total_ops += th.DoubleTensor([matmul_ops])
class QKVAttentionLegacy(nn.Module):
"""
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts", q * scale, k * scale
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v)
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
class QKVAttention(nn.Module):
"""
A module which performs QKV attention and splits in a different order.
"""
def __init__(self, n_heads):
super().__init__()
self.n_heads = n_heads
def forward(self, qkv):
"""
Apply QKV attention.
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
:return: an [N x (H * C) x T] tensor after attention.
"""
bs, width, length = qkv.shape
assert width % (3 * self.n_heads) == 0
ch = width // (3 * self.n_heads)
q, k, v = qkv.chunk(3, dim=1)
scale = 1 / math.sqrt(math.sqrt(ch))
weight = th.einsum(
"bct,bcs->bts",
(q * scale).view(bs * self.n_heads, ch, length),
(k * scale).view(bs * self.n_heads, ch, length),
) # More stable with f16 than dividing afterwards
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
return a.reshape(bs, -1, length)
@staticmethod
def count_flops(model, _x, y):
return count_flops_attn(model, _x, y)
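# --- Equivalence check (not part of the original file) -----------------------------
# The two attention modules only differ in how the fused qkv tensor is laid out across
# heads ([N, H*3*C, T] vs [N, 3*H*C, T]); with a single head the layouts coincide, so
# both should produce the same output. Sizes are illustrative.
if __name__ == '__main__':
    qkv = th.randn(2, 3 * 64, 10)  # N, 3*C, T with H=1
    legacy, new = QKVAttentionLegacy(n_heads=1), QKVAttention(n_heads=1)
    print(th.allclose(legacy(qkv), new(qkv)))  # True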
class UNetModel(nn.Module):
"""
The full UNet model with attention and timestep embedding.
:param in_channels: channels in the input Tensor.
:param model_channels: base channel count for the model.
:param out_channels: channels in the output Tensor.
:param num_res_blocks: number of residual blocks per downsample.
:param attention_resolutions: a collection of downsample rates at which
attention will take place. May be a set, list, or tuple.
For example, if this contains 4, then at 4x downsampling, attention
will be used.
:param dropout: the dropout probability.
:param channel_mult: channel multiplier for each level of the UNet.
:param conv_resample: if True, use learned convolutions for upsampling and
downsampling.
:param dims: determines if the signal is 1D, 2D, or 3D.
:param num_classes: if specified (as an int), then this model will be
class-conditional with `num_classes` classes.
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
:param num_heads: the number of attention heads in each attention layer.
:param num_heads_channels: if specified, ignore num_heads and instead use
a fixed channel width per attention head.
:param num_heads_upsample: works with num_heads to set a different number
of heads for upsampling. Deprecated.
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
:param resblock_updown: use residual blocks for up/downsampling.
:param use_new_attention_order: use a different attention pattern for potentially
increased efficiency.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
num_classes=None,
use_checkpoint=False,
use_fp16=False,
num_heads=-1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
use_spatial_transformer=False, # custom transformer support
transformer_depth=1, # custom transformer support
context_dim=None, # custom transformer support
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
legacy=True,
disable_self_attentions=None,
num_attention_blocks=None
):
super().__init__()
if use_spatial_transformer:
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
if context_dim is not None:
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
from omegaconf.listconfig import ListConfig
if type(context_dim) == ListConfig:
context_dim = list(context_dim)
if num_heads_upsample == -1:
num_heads_upsample = num_heads
if num_heads == -1:
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
if num_head_channels == -1:
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
self.image_size = image_size
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
if isinstance(num_res_blocks, int):
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
else:
if len(num_res_blocks) != len(channel_mult):
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
"as a list/tuple (per-level) with the same length as channel_mult")
self.num_res_blocks = num_res_blocks
#self.num_res_blocks = num_res_blocks
if disable_self_attentions is not None:
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
assert len(disable_self_attentions) == len(channel_mult)
if num_attention_blocks is not None:
assert len(num_attention_blocks) == len(self.num_res_blocks)
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
f"attention will still not be set.") # todo: convert to warning
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.num_classes = num_classes
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
self.predict_codebook_ids = n_embed is not None
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
if self.num_classes is not None:
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
) # 0
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for nr in range(self.num_res_blocks[level]):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions: # always True
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.output_blocks = nn.ModuleList([])
for level, mult in list(enumerate(channel_mult))[::-1]:
for i in range(self.num_res_blocks[level] + 1):
ich = input_block_chans.pop()
layers = [
ResBlock(
ch + ich,
time_embed_dim,
dropout,
out_channels=model_channels * mult,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = model_channels * mult
if ds in attention_resolutions:
if num_head_channels == -1:
dim_head = ch // num_heads
else:
num_heads = ch // num_head_channels
dim_head = num_head_channels
if legacy:
#num_heads = 1
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
if exists(disable_self_attentions):
disabled_sa = disable_self_attentions[level]
else:
disabled_sa = False
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads_upsample,
num_head_channels=dim_head,
use_new_attention_order=use_new_attention_order,
) if not use_spatial_transformer else SpatialTransformer(
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
disable_self_attn=disabled_sa
)
)
if level and i == self.num_res_blocks[level]:
out_ch = ch
layers.append(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
up=True,
)
if resblock_updown
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
)
ds //= 2
self.output_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
)
if self.predict_codebook_ids:
self.id_predictor = nn.Sequential(
normalization(ch),
conv_nd(dims, model_channels, n_embed, 1),
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
)
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
self.output_blocks.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
self.output_blocks.apply(convert_module_to_f32)
    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:param context: conditioning plugged in via crossattn
:param y: an [N] Tensor of labels, if class-conditional.
:return: an [N x C x ...] Tensor of outputs.
"""
assert (y is not None) == (
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)  # [N, model_channels]
        emb = self.time_embed(t_emb)  # [N, time_embed_dim]
if self.num_classes is not None:
assert y.shape == (x.shape[0],)
emb = emb + self.label_emb(y)
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb, context) # conv
hs.append(h)
h = self.middle_block(h, emb, context)
for module in self.output_blocks:
h = th.cat([h, hs.pop()], dim=1)
h = module(h, emb, context)
h = h.type(x.dtype)
if self.predict_codebook_ids:
return self.id_predictor(h)
else:
return self.out(h)
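# Illustrative usage sketch (not part of the original file): the constructor
# keywords below follow the fields referenced in __init__ above, but the exact
# argument set and values are assumptions chosen only to show the tensor
# shapes expected by forward().
#
#   model = UNetModel(in_channels=4, model_channels=320, out_channels=4,
#                     num_res_blocks=2, attention_resolutions=(4, 2, 1),
#                     num_heads=8, use_spatial_transformer=True,
#                     transformer_depth=1, context_dim=768)
#   x = th.randn(2, 4, 32, 32)                # latent batch [N x C x H x W]
#   t = th.randint(0, 1000, (2,))             # one timestep per sample
#   ctx = th.randn(2, 77, 768)                # cross-attention conditioning
#   eps = model(x, timesteps=t, context=ctx)  # -> [2, 4, 32, 32]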
class EncoderUNetModel(nn.Module):
"""
The half UNet model with attention and timestep embedding.
For usage, see UNet.
"""
def __init__(
self,
image_size,
in_channels,
model_channels,
out_channels,
num_res_blocks,
attention_resolutions,
dropout=0,
channel_mult=(1, 2, 4, 8),
conv_resample=True,
dims=2,
use_checkpoint=False,
use_fp16=False,
num_heads=1,
num_head_channels=-1,
num_heads_upsample=-1,
use_scale_shift_norm=False,
resblock_updown=False,
use_new_attention_order=False,
pool="adaptive",
*args,
**kwargs
):
super().__init__()
if num_heads_upsample == -1:
num_heads_upsample = num_heads
self.in_channels = in_channels
self.model_channels = model_channels
self.out_channels = out_channels
self.num_res_blocks = num_res_blocks
self.attention_resolutions = attention_resolutions
self.dropout = dropout
self.channel_mult = channel_mult
self.conv_resample = conv_resample
self.use_checkpoint = use_checkpoint
self.dtype = th.float16 if use_fp16 else th.float32
self.num_heads = num_heads
self.num_head_channels = num_head_channels
self.num_heads_upsample = num_heads_upsample
time_embed_dim = model_channels * 4
self.time_embed = nn.Sequential(
linear(model_channels, time_embed_dim),
nn.SiLU(),
linear(time_embed_dim, time_embed_dim),
)
self.input_blocks = nn.ModuleList(
[
TimestepEmbedSequential(
conv_nd(dims, in_channels, model_channels, 3, padding=1)
)
]
)
self._feature_size = model_channels
input_block_chans = [model_channels]
ch = model_channels
ds = 1
for level, mult in enumerate(channel_mult):
for _ in range(num_res_blocks):
layers = [
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=mult * model_channels,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
)
]
ch = mult * model_channels
if ds in attention_resolutions:
layers.append(
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
)
)
self.input_blocks.append(TimestepEmbedSequential(*layers))
self._feature_size += ch
input_block_chans.append(ch)
if level != len(channel_mult) - 1:
out_ch = ch
self.input_blocks.append(
TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
out_channels=out_ch,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
down=True,
)
if resblock_updown
else Downsample(
ch, conv_resample, dims=dims, out_channels=out_ch
)
)
)
ch = out_ch
input_block_chans.append(ch)
ds *= 2
self._feature_size += ch
self.middle_block = TimestepEmbedSequential(
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
AttentionBlock(
ch,
use_checkpoint=use_checkpoint,
num_heads=num_heads,
num_head_channels=num_head_channels,
use_new_attention_order=use_new_attention_order,
),
ResBlock(
ch,
time_embed_dim,
dropout,
dims=dims,
use_checkpoint=use_checkpoint,
use_scale_shift_norm=use_scale_shift_norm,
),
)
self._feature_size += ch
self.pool = pool
if pool == "adaptive":
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
nn.AdaptiveAvgPool2d((1, 1)),
zero_module(conv_nd(dims, ch, out_channels, 1)),
nn.Flatten(),
)
elif pool == "attention":
assert num_head_channels != -1
self.out = nn.Sequential(
normalization(ch),
nn.SiLU(),
AttentionPool2d(
(image_size // ds), ch, num_head_channels, out_channels
),
)
elif pool == "spatial":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
nn.ReLU(),
nn.Linear(2048, self.out_channels),
)
elif pool == "spatial_v2":
self.out = nn.Sequential(
nn.Linear(self._feature_size, 2048),
normalization(2048),
nn.SiLU(),
nn.Linear(2048, self.out_channels),
)
else:
raise NotImplementedError(f"Unexpected {pool} pooling")
def convert_to_fp16(self):
"""
Convert the torso of the model to float16.
"""
self.input_blocks.apply(convert_module_to_f16)
self.middle_block.apply(convert_module_to_f16)
def convert_to_fp32(self):
"""
Convert the torso of the model to float32.
"""
self.input_blocks.apply(convert_module_to_f32)
self.middle_block.apply(convert_module_to_f32)
def forward(self, x, timesteps):
"""
Apply the model to an input batch.
:param x: an [N x C x ...] Tensor of inputs.
:param timesteps: a 1-D batch of timesteps.
:return: an [N x K] Tensor of outputs.
"""
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
results = []
h = x.type(self.dtype)
for module in self.input_blocks:
h = module(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = self.middle_block(h, emb)
if self.pool.startswith("spatial"):
results.append(h.type(x.dtype).mean(dim=(2, 3)))
h = th.cat(results, axis=-1)
return self.out(h)
else:
h = h.type(x.dtype)
return self.out(h)

View File

@@ -0,0 +1,267 @@
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
# and
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
#
# thanks!
import os
import math
import torch
import torch.nn as nn
import numpy as np
from einops import repeat
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
if schedule == "linear":
betas = (
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
)
elif schedule == "cosine":
timesteps = (
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
)
alphas = timesteps / (1 + cosine_s) * np.pi / 2
alphas = torch.cos(alphas).pow(2)
alphas = alphas / alphas[0]
betas = 1 - alphas[1:] / alphas[:-1]
betas = np.clip(betas, a_min=0, a_max=0.999)
elif schedule == "sqrt_linear":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
elif schedule == "sqrt":
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
else:
raise ValueError(f"schedule '{schedule}' unknown.")
return betas.numpy()
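# Example (illustrative; the values are common LDM defaults, not something this
# function prescribes): a 1000-step "linear" schedule interpolates sqrt(beta)
# linearly and then squares it, so the endpoints land on linear_start/linear_end.
#
#   betas = make_beta_schedule("linear", 1000, linear_start=1e-4, linear_end=2e-2)
#   betas.shape           # -> (1000,)
#   betas[0], betas[-1]   # -> ~1e-4, ~2e-2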
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
if ddim_discr_method == 'uniform':
c = num_ddpm_timesteps // num_ddim_timesteps
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
elif ddim_discr_method == 'quad':
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
else:
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
# add one to get the final alpha values right (the ones from first scale to data during sampling)
steps_out = ddim_timesteps + 1
if verbose:
print(f'Selected timesteps for ddim sampler: {steps_out}')
return steps_out
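# Example (illustrative): with 1000 DDPM steps and 50 DDIM steps, 'uniform'
# discretization keeps every c = 1000 // 50 = 20-th step and shifts by one.
#
#   steps = make_ddim_timesteps('uniform', 50, 1000, verbose=False)
#   steps[:3], steps[-1]   # -> array([ 1, 21, 41]), 981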
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
# select alphas for computing the variance schedule
alphas = alphacums[ddim_timesteps]
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
    # according to the formula provided in https://arxiv.org/abs/2010.02502
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
if verbose:
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
print(f'For the chosen value of eta, which is {eta}, '
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
return sigmas, alphas, alphas_prev
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function,
which defines the cumulative product of (1-beta) over time from t = [0,1].
:param num_diffusion_timesteps: the number of betas to produce.
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
produces the cumulative product of (1-beta) up to that
part of the diffusion process.
:param max_beta: the maximum beta to use; use values lower than 1 to
prevent singularities.
"""
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return np.array(betas)
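# Example (illustrative sketch): the squared-cosine alpha_bar below reproduces
# the cosine schedule of Nichol & Dhariwal; the lambda is an assumption for
# demonstration, not something defined in this file.
#
#   alpha_bar = lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2
#   betas = betas_for_alpha_bar(1000, alpha_bar)   # -> ndarray of shape (1000,)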
def extract_into_tensor(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
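# Example (illustrative): gather one schedule value per batch element and add
# trailing singleton dims so it broadcasts against an image-shaped tensor.
#
#   a = torch.linspace(0, 1, 1000)                     # any per-timestep schedule
#   t = torch.tensor([0, 499, 999])
#   coef = extract_into_tensor(a, t, (3, 4, 64, 64))   # -> shape [3, 1, 1, 1]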
def checkpoint(func, inputs, params, flag):
"""
Evaluate a function without caching intermediate activations, allowing for
reduced memory at the expense of extra compute in the backward pass.
:param func: the function to evaluate.
:param inputs: the argument sequence to pass to `func`.
:param params: a sequence of parameters `func` depends on but does not
explicitly take as arguments.
:param flag: if False, disable gradient checkpointing.
"""
if flag:
args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args)
else:
return func(*inputs)
class CheckpointFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, run_function, length, *args):
ctx.run_function = run_function
ctx.input_tensors = list(args[:length])
ctx.input_params = list(args[length:])
with torch.no_grad():
output_tensors = ctx.run_function(*ctx.input_tensors)
return output_tensors
@staticmethod
def backward(ctx, *output_grads):
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
with torch.enable_grad():
# Fixes a bug where the first op in run_function modifies the
# Tensor storage in place, which is not allowed for detach()'d
# Tensors.
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
output_tensors = ctx.run_function(*shallow_copies)
input_grads = torch.autograd.grad(
output_tensors,
ctx.input_tensors + ctx.input_params,
output_grads,
allow_unused=True,
)
del ctx.input_tensors
del ctx.input_params
del output_tensors
return (None, None) + input_grads
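# Example (illustrative): wrap a block call so its intermediate activations are
# recomputed during backward instead of being stored; `block`, `x` and `emb`
# are placeholders here.
#
#   out = checkpoint(lambda x, emb: block(x, emb), (x, emb),
#                    list(block.parameters()), True)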
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
"""
Create sinusoidal timestep embeddings.
:param timesteps: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an [N x dim] Tensor of positional embeddings.
"""
if not repeat_only:
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).to(device=timesteps.device)
args = timesteps[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
else:
embedding = repeat(timesteps, 'b -> b d', d=dim)
return embedding
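# Example (illustrative): the first half of the output holds cosine terms and
# the second half sine terms over geometrically spaced frequencies.
#
#   t = torch.tensor([0, 10, 999])
#   emb = timestep_embedding(t, dim=320)   # -> shape [3, 320]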
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
def scale_module(module, scale):
"""
Scale the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().mul_(scale)
return module
def mean_flat(tensor):
"""
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def normalization(channels):
"""
Make a standard normalization layer.
:param channels: number of input channels.
:return: an nn.Module for normalization.
"""
return GroupNorm32(32, channels)
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
class SiLU(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
class GroupNorm32(nn.GroupNorm):
def forward(self, x):
return super().forward(x.float()).type(x.dtype)
def conv_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D convolution module.
"""
if dims == 1:
return nn.Conv1d(*args, **kwargs)
elif dims == 2:
return nn.Conv2d(*args, **kwargs)
elif dims == 3:
return nn.Conv3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
def linear(*args, **kwargs):
"""
Create a linear module.
"""
return nn.Linear(*args, **kwargs)
def avg_pool_nd(dims, *args, **kwargs):
"""
Create a 1D, 2D, or 3D average pooling module.
"""
if dims == 1:
return nn.AvgPool1d(*args, **kwargs)
elif dims == 2:
return nn.AvgPool2d(*args, **kwargs)
elif dims == 3:
return nn.AvgPool3d(*args, **kwargs)
raise ValueError(f"unsupported dimensions: {dims}")
class HybridConditioner(nn.Module):
def __init__(self, c_concat_config, c_crossattn_config):
super().__init__()
self.concat_conditioner = instantiate_from_config(c_concat_config)
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
def forward(self, c_concat, c_crossattn):
c_concat = self.concat_conditioner(c_concat)
c_crossattn = self.crossattn_conditioner(c_crossattn)
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

View File

@@ -0,0 +1,92 @@
import torch
import numpy as np
class AbstractDistribution:
def sample(self):
raise NotImplementedError()
def mode(self):
raise NotImplementedError()
class DiracDistribution(AbstractDistribution):
def __init__(self, value):
self.value = value
def sample(self):
return self.value
def mode(self):
return self.value
class DiagonalGaussianDistribution(object):
def __init__(self, parameters, deterministic=False):
self.parameters = parameters
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
self.deterministic = deterministic
self.std = torch.exp(0.5 * self.logvar)
self.var = torch.exp(self.logvar)
if self.deterministic:
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
def sample(self):
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
return x
def kl(self, other=None):
if self.deterministic:
return torch.Tensor([0.])
else:
if other is None:
return 0.5 * torch.sum(torch.pow(self.mean, 2)
+ self.var - 1.0 - self.logvar,
dim=[1, 2, 3])
else:
return 0.5 * torch.sum(
torch.pow(self.mean - other.mean, 2) / other.var
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
dim=[1, 2, 3])
def nll(self, sample, dims=[1,2,3]):
if self.deterministic:
return torch.Tensor([0.])
logtwopi = np.log(2.0 * np.pi)
return 0.5 * torch.sum(
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
dim=dims)
def mode(self):
return self.mean
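# Example (illustrative): `parameters` packs mean and log-variance along the
# channel axis, as produced by a VAE encoder head; the shape is an assumption.
#
#   params = torch.randn(2, 8, 32, 32)               # 4 mean + 4 logvar channels
#   posterior = DiagonalGaussianDistribution(params)
#   z = posterior.sample()                           # -> [2, 4, 32, 32]
#   kl = posterior.kl()                              # -> [2]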
def normal_kl(mean1, logvar1, mean2, logvar2):
"""
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
Compute the KL divergence between two gaussians.
Shapes are automatically broadcasted, so batches can be compared to
scalars, among other use cases.
"""
tensor = None
for obj in (mean1, logvar1, mean2, logvar2):
if isinstance(obj, torch.Tensor):
tensor = obj
break
assert tensor is not None, "at least one argument must be a Tensor"
# Force variances to be Tensors. Broadcasting helps convert scalars to
# Tensors, but it does not work for torch.exp().
logvar1, logvar2 = [
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
for x in (logvar1, logvar2)
]
return 0.5 * (
-1.0
+ logvar2
- logvar1
+ torch.exp(logvar1 - logvar2)
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
)
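# Example (illustrative): identical Gaussians have zero divergence; scalars
# broadcast against tensor arguments as described above.
#
#   normal_kl(torch.zeros(4), torch.zeros(4), 0.0, 0.0)   # -> zeros of shape [4]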

View File

@@ -0,0 +1 @@
from .clip import *

View File

@@ -0,0 +1,200 @@
import hashlib
import os
import urllib
import warnings
from typing import Any, Union, List
from pkg_resources import packaging
import torch
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import build_model
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
__all__ = ["available_models", "load"]
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
def _download(url: str, root: str):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
return download_target
def _convert_image_to_rgb(image):
return image.convert("RGB")
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
_convert_image_to_rgb,
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model or more hackable non-JIT model (default).
download_root: str
path to download the model files; by default, it uses "~/.cache/clip"
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
with open(model_path, 'rb') as opened_file:
try:
# loading JIT archive
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(opened_file, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def _node_get(node: torch._C.Node, key: str):
"""Gets attributes of a node which is polymorphic over return type.
From https://github.com/pytorch/pytorch/pull/82628
"""
sel = node.kindOf(key)
return getattr(node, sel)(key)
def patch_device(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
try:
graphs = [module.graph] if hasattr(module, "graph") else []
except RuntimeError:
graphs = []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if _node_get(inputs[i].node(), "value") == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, _transform(model.input_resolution.item())
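# Example (illustrative sketch, assuming the checkpoint can be downloaded or is
# already cached and that "example.jpg" is a placeholder image path):
#
#   model, preprocess = load("ViT-L/14", device="cuda", jit=False)
#   image = preprocess(Image.open("example.jpg")).unsqueeze(0).to("cuda")
#   with torch.no_grad():
#       feats = model.encode_image(image)   # -> [1, 768] for ViT-L/14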

View File

@@ -0,0 +1,436 @@
from collections import OrderedDict
from typing import Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.relu2 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu3 = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu1(self.bn1(self.conv1(x)))
out = self.relu2(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu3(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x[:1], key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x.squeeze(0)
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
def stem(x):
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.attnpool(x)
return x
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisionTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def forward(self, x: torch.Tensor):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x[:, 0, :])
if self.proj is not None:
x = x @ self.proj
return x
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisionTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=1, keepdim=True)
text_features = text_features / text_features.norm(dim=1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logits_per_image.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
convert_weights(model)
model.load_state_dict(state_dict)
return model.eval()

View File

@@ -0,0 +1,132 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
    Returns a list of utf-8 bytes and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
                except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
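# Example (illustrative, assuming the bundled bpe_simple_vocab_16e6.txt.gz sits
# next to this file): encode lower-cases and BPE-splits the text, and decode
# maps ids back, turning '</w>' markers into spaces.
#
#   tok = SimpleTokenizer()
#   ids = tok.encode("Hello, world!")
#   tok.decode(ids)   # -> roughly "hello , world ! "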

View File

@@ -0,0 +1,551 @@
import torch
import torch.nn as nn
import numpy as np
from functools import partial
import kornia
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
from modelscope.models.cv.image_to_3d.ldm.util import default
# import clip
from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip
class AbstractEncoder(nn.Module):
def __init__(self):
super().__init__()
def encode(self, *args, **kwargs):
raise NotImplementedError
class IdentityEncoder(AbstractEncoder):
def encode(self, x):
return x
class FaceClipEncoder(AbstractEncoder):
def __init__(self, augment=True, retreival_key=None):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
self.augment = augment
self.retreival_key = retreival_key
def forward(self, img):
encodings = []
with torch.no_grad():
x_offset = 125
if self.retreival_key:
                # Assumes retrieved images are packed into the second half of the channels
face = img[:,3:,190:440,x_offset:(512-x_offset)]
other = img[:,:3,...].clone()
else:
face = img[:,:,190:440,x_offset:(512-x_offset)]
other = img.clone()
if self.augment:
face = K.RandomHorizontalFlip()(face)
other[:,:,190:440,x_offset:(512-x_offset)] *= 0
encodings = [
self.encoder.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
return self(img)
class FaceIdClipEncoder(AbstractEncoder):
def __init__(self):
super().__init__()
self.encoder = FrozenCLIPImageEmbedder()
for p in self.encoder.parameters():
p.requires_grad = False
self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)
def forward(self, img):
encodings = []
with torch.no_grad():
face = kornia.geometry.resize(img, (256, 256),
interpolation='bilinear', align_corners=True)
other = img.clone()
other[:,:,184:452,122:396] *= 0
encodings = [
self.id.encode(face),
self.encoder.encode(other),
]
return torch.cat(encodings, dim=1)
def encode(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
return self(img)
class ClassEmbedder(nn.Module):
def __init__(self, embed_dim, n_classes=1000, key='class'):
super().__init__()
self.key = key
self.embedding = nn.Embedding(n_classes, embed_dim)
def forward(self, batch, key=None):
if key is None:
key = self.key
# this is for use in crossattn
c = batch[key][:, None]
c = self.embedding(c)
return c
class TransformerEmbedder(AbstractEncoder):
"""Some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
super().__init__()
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer))
def forward(self, tokens):
tokens = tokens.to(self.device) # meh
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, x):
return self(x)
class BERTTokenizer(AbstractEncoder):
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
def __init__(self, device="cuda", vq_interface=True, max_length=77):
super().__init__()
        from transformers import BertTokenizerFast # TODO: add to requirements
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
self.device = device
self.vq_interface = vq_interface
self.max_length = max_length
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
return tokens
@torch.no_grad()
def encode(self, text):
tokens = self(text)
if not self.vq_interface:
return tokens
return None, None, [None, None, tokens]
def decode(self, text):
return text
class BERTEmbedder(AbstractEncoder):
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
super().__init__()
self.use_tknz_fn = use_tokenizer
if self.use_tknz_fn:
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
self.device = device
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
attn_layers=Encoder(dim=n_embed, depth=n_layer),
emb_dropout=embedding_dropout)
def forward(self, text):
if self.use_tknz_fn:
tokens = self.tknz_fn(text)#.to(self.device)
else:
tokens = text
z = self.transformer(tokens, return_embeddings=True)
return z
def encode(self, text):
# output of length 77
return self(text)
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
def disabled_train(self, mode=True):
"""Overwrite model.train with this function to make sure train/eval mode
does not change anymore."""
return self
class FrozenT5Embedder(AbstractEncoder):
"""Uses the T5 transformer encoder for text"""
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
super().__init__()
self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
import kornia.augmentation as K
class FrozenFaceEncoder(AbstractEncoder):
def __init__(self, model_path, augment=False):
super().__init__()
self.loss_fn = IDFeatures(model_path)
# face encoder is frozen
for p in self.loss_fn.parameters():
p.requires_grad = False
# Mapper is trainable
self.mapper = torch.nn.Linear(512, 768)
p = 0.25
if augment:
self.augment = K.AugmentationSequential(
K.RandomHorizontalFlip(p=0.5),
K.RandomEqualize(p=p),
# K.RandomPlanckianJitter(p=p),
# K.RandomPlasmaBrightness(p=p),
# K.RandomPlasmaContrast(p=p),
# K.ColorJiggle(0.02, 0.2, 0.2, p=p),
)
else:
self.augment = False
def forward(self, img):
if isinstance(img, list):
# Uncondition
return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
if self.augment is not None:
# Transforms require 0-1
img = self.augment((img + 1)/2)
img = 2*img - 1
feat = self.loss_fn(img, crop=True)
feat = self.mapper(feat.unsqueeze(1))
return feat
def encode(self, img):
return self(img)
class FrozenCLIPEmbedder(AbstractEncoder):
"""Uses the CLIP transformer encoder for text (from huggingface)"""
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
super().__init__()
self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
self.device = device
self.max_length = max_length # TODO: typical value?
self.freeze()
def freeze(self):
self.transformer = self.transformer.eval()
#self.train = disabled_train
for param in self.parameters():
param.requires_grad = False
def forward(self, text):
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
tokens = batch_encoding["input_ids"].to(self.device)
outputs = self.transformer(input_ids=tokens)
z = outputs.last_hidden_state
return z
def encode(self, text):
return self(text)
import torch.nn.functional as F
from transformers import CLIPVisionModel
class ClipImageProjector(AbstractEncoder):
"""
Uses the CLIP image encoder.
"""
def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
super().__init__()
self.model = CLIPVisionModel.from_pretrained(version)
self.model.train()
self.max_length = max_length # TODO: typical value?
self.antialias = True
self.mapper = torch.nn.Linear(1024, 768)
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
null_cond = self.get_null_cond(version, max_length)
self.register_buffer('null_cond', null_cond)
@torch.no_grad()
def get_null_cond(self, version, max_length):
device = self.mean.device
embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
null_cond = embedder([""])
return null_cond
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
if isinstance(x, list):
return self.null_cond
# x is assumed to be in range [-1,1]
x = self.preprocess(x)
outputs = self.model(pixel_values=x)
last_hidden_state = outputs.last_hidden_state
last_hidden_state = self.mapper(last_hidden_state)
return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])
def encode(self, im):
return self(im)
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
super().__init__()
self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
self.projection = torch.nn.Linear(768, 768)
def forward(self, text):
z = self.embedder(text)
return self.projection(z)
def encode(self, text):
return self(text)
class FrozenCLIPImageEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=False,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
def preprocess(self, x):
# Expects inputs in the range -1, 1
x = kornia.geometry.resize(x, (224, 224),
interpolation='bicubic',align_corners=True,
antialias=self.antialias)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, 768, device=device)
return self.model.encode_image(self.preprocess(x)).float()
def encode(self, im):
return self(im).unsqueeze(1)
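# Example (illustrative, assuming the ViT-L/14 weights can be loaded): inputs
# are expected in [-1, 1]; encode() returns one CLIP token per image.
#
#   embedder = FrozenCLIPImageEmbedder(device='cpu')
#   img = torch.rand(2, 3, 512, 512) * 2 - 1
#   c = embedder.encode(img)   # -> [2, 1, 768]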
from torchvision import transforms
import random
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
"""
Uses the CLIP image encoder.
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
"""
def __init__(
self,
model='ViT-L/14',
jit=False,
device='cpu',
antialias=True,
max_crops=5,
):
super().__init__()
self.model, _ = clip.load(name=model, device=device, jit=jit)
# We don't use the text part so delete it
del self.model.transformer
self.antialias = antialias
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
self.max_crops = max_crops
def preprocess(self, x):
# Expects inputs in the range -1, 1
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
max_crops = self.max_crops
patches = []
crops = [randcrop(x) for _ in range(max_crops)]
patches.extend(crops)
x = torch.cat(patches, dim=0)
x = (x + 1.) / 2.
# renormalize according to clip
x = kornia.enhance.normalize(x, self.mean, self.std)
return x
def forward(self, x):
# x is assumed to be in range [-1,1]
if isinstance(x, list):
# [""] denotes condition dropout for ucg
device = self.model.visual.conv1.weight.device
return torch.zeros(1, self.max_crops, 768, device=device)
batch_tokens = []
for im in x:
patches = self.preprocess(im.unsqueeze(0))
tokens = self.model.encode_image(patches).float()
for t in tokens:
if random.random() < 0.1:
t *= 0
batch_tokens.append(tokens.unsqueeze(0))
return torch.cat(batch_tokens, dim=0)
def encode(self, im):
return self(im)
class SpatialRescaler(nn.Module):
def __init__(self,
n_stages=1,
method='bilinear',
multiplier=0.5,
in_channels=3,
out_channels=None,
bias=False):
super().__init__()
self.n_stages = n_stages
assert self.n_stages >= 0
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
self.multiplier = multiplier
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
self.remap_output = out_channels is not None
if self.remap_output:
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
def forward(self,x):
for stage in range(self.n_stages):
x = self.interpolator(x, scale_factor=self.multiplier)
if self.remap_output:
x = self.channel_mapper(x)
return x
def encode(self, x):
return self(x)
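# Usage sketch, added for illustration: SpatialRescaler halves the spatial resolution
# n_stages times (multiplier=0.5) and can remap channels with a 1x1 convolution.
def _demo_spatial_rescaler():
    rescaler = SpatialRescaler(n_stages=2, multiplier=0.5, in_channels=3, out_channels=4)
    x = torch.randn(1, 3, 64, 64)
    return rescaler(x).shape   # expected: (1, 4, 16, 16)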
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
class LowScaleEncoder(nn.Module):
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
scale_factor=1.0):
super().__init__()
self.max_noise_level = max_noise_level
self.model = instantiate_from_config(model_config)
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
linear_end=linear_end)
self.out_size = output_size
self.scale_factor = scale_factor
def register_schedule(self, beta_schedule="linear", timesteps=1000,
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
cosine_s=cosine_s)
alphas = 1. - betas
alphas_cumprod = np.cumprod(alphas, axis=0)
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
timesteps, = betas.shape
self.num_timesteps = int(timesteps)
self.linear_start = linear_start
self.linear_end = linear_end
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
to_torch = partial(torch.tensor, dtype=torch.float32)
self.register_buffer('betas', to_torch(betas))
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
# calculations for diffusion q(x_t | x_{t-1}) and others
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
def forward(self, x):
z = self.model.encode(x).sample()
z = z * self.scale_factor
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
z = self.q_sample(z, noise_level)
if self.out_size is not None:
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
return z, noise_level
def decode(self, z):
z = z / self.scale_factor
return self.model.decode(z)
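# Added note: q_sample above implements the closed-form forward diffusion
#     x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I),
# so LowScaleEncoder returns a latent corrupted to a random level in
# [0, max_noise_level) together with that level for downstream conditioning.
# A minimal standalone sketch of the same augmentation (the linear-in-sqrt beta
# schedule below is an assumption chosen to mirror make_beta_schedule("linear", ...)):
def _demo_noise_augmentation(timesteps=1000, linear_start=1e-4, linear_end=2e-2):
    betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, timesteps) ** 2
    alphas_cumprod = torch.cumprod(1. - betas, dim=0)
    x0 = torch.randn(1, 4, 64, 64)                  # stand-in latent
    t = torch.randint(0, 250, (1,))                 # random noise level, as in forward()
    eps = torch.randn_like(x0)
    xt = (alphas_cumprod[t].sqrt().view(-1, 1, 1, 1) * x0
          + (1. - alphas_cumprod[t]).sqrt().view(-1, 1, 1, 1) * eps)
    return xt, t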
if __name__ == "__main__":
from modelscope.models.cv.image_to_3d.ldm.util import count_params
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
model = FrozenCLIPEmbedder().cuda()
count_params(model, True)
z = model(sentences)
print(z.shape)
print("done.")

View File

@@ -0,0 +1,641 @@
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
import torch
from torch import nn, einsum
import torch.nn.functional as F
from functools import partial
from inspect import isfunction
from collections import namedtuple
from einops import rearrange, repeat, reduce
# constants
DEFAULT_DIM_HEAD = 64
Intermediates = namedtuple('Intermediates', [
'pre_softmax_attn',
'post_softmax_attn'
])
LayerIntermediates = namedtuple('LayerIntermediates', [
'hiddens',
'attn_intermediates'
])
class AbsolutePositionalEmbedding(nn.Module):
def __init__(self, dim, max_seq_len):
super().__init__()
self.emb = nn.Embedding(max_seq_len, dim)
self.init_()
def init_(self):
nn.init.normal_(self.emb.weight, std=0.02)
def forward(self, x):
n = torch.arange(x.shape[1], device=x.device)
return self.emb(n)[None, :, :]
class FixedPositionalEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer('inv_freq', inv_freq)
def forward(self, x, seq_dim=1, offset=0):
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
return emb[None, :, :]
# helpers
def exists(val):
return val is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def always(val):
def inner(*args, **kwargs):
return val
return inner
def not_equals(val):
def inner(x):
return x != val
return inner
def equals(val):
def inner(x):
return x == val
return inner
def max_neg_value(tensor):
return -torch.finfo(tensor.dtype).max
# keyword argument helpers
def pick_and_pop(keys, d):
values = list(map(lambda key: d.pop(key), keys))
return dict(zip(keys, values))
def group_dict_by_key(cond, d):
return_val = [dict(), dict()]
for key in d.keys():
match = bool(cond(key))
ind = int(not match)
return_val[ind][key] = d[key]
return (*return_val,)
def string_begins_with(prefix, string):
return string.startswith(prefix)
def group_by_key_prefix(prefix, d):
return group_dict_by_key(partial(string_begins_with, prefix), d)
def groupby_prefix_and_trim(prefix, d):
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
return kwargs_without_prefix, kwargs
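# Illustration, added: these helpers let AttentionLayers route prefixed kwargs to the
# right sub-module, e.g. attn_dim_head becomes Attention's dim_head and ff_mult
# becomes FeedForward's mult.
def _demo_kwargs_routing():
    kwargs = dict(attn_dim_head=32, ff_mult=2, depth=4)
    ff_kwargs, rest = groupby_prefix_and_trim('ff_', kwargs)
    attn_kwargs, rest = groupby_prefix_and_trim('attn_', rest)
    return ff_kwargs, attn_kwargs, rest   # ({'mult': 2}, {'dim_head': 32}, {'depth': 4})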
# classes
class Scale(nn.Module):
def __init__(self, value, fn):
super().__init__()
self.value = value
self.fn = fn
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.value, *rest)
class Rezero(nn.Module):
def __init__(self, fn):
super().__init__()
self.fn = fn
self.g = nn.Parameter(torch.zeros(1))
def forward(self, x, **kwargs):
x, *rest = self.fn(x, **kwargs)
return (x * self.g, *rest)
class ScaleNorm(nn.Module):
def __init__(self, dim, eps=1e-5):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(1))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class RMSNorm(nn.Module):
def __init__(self, dim, eps=1e-8):
super().__init__()
self.scale = dim ** -0.5
self.eps = eps
self.g = nn.Parameter(torch.ones(dim))
def forward(self, x):
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
return x / norm.clamp(min=self.eps) * self.g
class Residual(nn.Module):
def forward(self, x, residual):
return x + residual
class GRUGating(nn.Module):
def __init__(self, dim):
super().__init__()
self.gru = nn.GRUCell(dim, dim)
def forward(self, x, residual):
gated_output = self.gru(
rearrange(x, 'b n d -> (b n) d'),
rearrange(residual, 'b n d -> (b n) d')
)
return gated_output.reshape_as(x)
# feedforward
class GEGLU(nn.Module):
def __init__(self, dim_in, dim_out):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out * 2)
def forward(self, x):
x, gate = self.proj(x).chunk(2, dim=-1)
return x * F.gelu(gate)
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
super().__init__()
inner_dim = int(dim * mult)
dim_out = default(dim_out, dim)
project_in = nn.Sequential(
nn.Linear(dim, inner_dim),
nn.GELU()
) if not glu else GEGLU(dim, inner_dim)
self.net = nn.Sequential(
project_in,
nn.Dropout(dropout),
nn.Linear(inner_dim, dim_out)
)
def forward(self, x):
return self.net(x)
# attention.
class Attention(nn.Module):
def __init__(
self,
dim,
dim_head=DEFAULT_DIM_HEAD,
heads=8,
causal=False,
mask=None,
talking_heads=False,
sparse_topk=None,
use_entmax15=False,
num_mem_kv=0,
dropout=0.,
on_attn=False
):
super().__init__()
if use_entmax15:
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
self.scale = dim_head ** -0.5
self.heads = heads
self.causal = causal
self.mask = mask
inner_dim = dim_head * heads
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.dropout = nn.Dropout(dropout)
# talking heads
self.talking_heads = talking_heads
if talking_heads:
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
# explicit topk sparse attention
self.sparse_topk = sparse_topk
# entmax
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
self.attn_fn = F.softmax
# add memory key / values
self.num_mem_kv = num_mem_kv
if num_mem_kv > 0:
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
# attention on attention
self.attn_on_attn = on_attn
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
rel_pos=None,
sinusoidal_emb=None,
prev_attn=None,
mem=None
):
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
kv_input = default(context, x)
q_input = x
k_input = kv_input
v_input = kv_input
if exists(mem):
k_input = torch.cat((mem, k_input), dim=-2)
v_input = torch.cat((mem, v_input), dim=-2)
if exists(sinusoidal_emb):
# in shortformer, the query would start at a position offset depending on the past cached memory
offset = k_input.shape[-2] - q_input.shape[-2]
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
k_input = k_input + sinusoidal_emb(k_input)
q = self.to_q(q_input)
k = self.to_k(k_input)
v = self.to_v(v_input)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
input_mask = None
if any(map(exists, (mask, context_mask))):
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
k_mask = q_mask if not exists(context) else context_mask
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
q_mask = rearrange(q_mask, 'b i -> b () i ()')
k_mask = rearrange(k_mask, 'b j -> b () () j')
input_mask = q_mask * k_mask
if self.num_mem_kv > 0:
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
k = torch.cat((mem_k, k), dim=-2)
v = torch.cat((mem_v, v), dim=-2)
if exists(input_mask):
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
mask_value = max_neg_value(dots)
if exists(prev_attn):
dots = dots + prev_attn
pre_softmax_attn = dots
if talking_heads:
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
if exists(rel_pos):
dots = rel_pos(dots)
if exists(input_mask):
dots.masked_fill_(~input_mask, mask_value)
del input_mask
if self.causal:
i, j = dots.shape[-2:]
r = torch.arange(i, device=device)
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
mask = F.pad(mask, (j - i, 0), value=False)
dots.masked_fill_(mask, mask_value)
del mask
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
top, _ = dots.topk(self.sparse_topk, dim=-1)
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
mask = dots < vk
dots.masked_fill_(mask, mask_value)
del mask
attn = self.attn_fn(dots, dim=-1)
post_softmax_attn = attn
attn = self.dropout(attn)
if talking_heads:
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
out = einsum('b h i j, b h j d -> b h i d', attn, v)
out = rearrange(out, 'b h n d -> b n (h d)')
intermediates = Intermediates(
pre_softmax_attn=pre_softmax_attn,
post_softmax_attn=post_softmax_attn
)
return self.to_out(out), intermediates
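# Usage sketch, added for illustration: self-attention over a (batch, seq, dim) tensor;
# the module returns the output plus an Intermediates record holding the pre- and
# post-softmax attention maps.
def _demo_attention():
    attn = Attention(dim=64, heads=4, dim_head=16)
    x = torch.randn(2, 10, 64)
    out, inter = attn(x)
    return out.shape, inter.post_softmax_attn.shape   # (2, 10, 64), (2, 4, 10, 10)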
class AttentionLayers(nn.Module):
def __init__(
self,
dim,
depth,
heads=8,
causal=False,
cross_attend=False,
only_cross=False,
use_scalenorm=False,
use_rmsnorm=False,
use_rezero=False,
rel_pos_num_buckets=32,
rel_pos_max_distance=128,
position_infused_attn=False,
custom_layers=None,
sandwich_coef=None,
par_ratio=None,
residual_attn=False,
cross_residual_attn=False,
macaron=False,
pre_norm=True,
gate_residual=False,
**kwargs
):
super().__init__()
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
self.dim = dim
self.depth = depth
self.layers = nn.ModuleList([])
self.has_pos_emb = position_infused_attn
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
self.rotary_pos_emb = always(None)
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must not exceed the relative position max distance'
self.rel_pos = None
self.pre_norm = pre_norm
self.residual_attn = residual_attn
self.cross_residual_attn = cross_residual_attn
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
norm_class = RMSNorm if use_rmsnorm else norm_class
norm_fn = partial(norm_class, dim)
norm_fn = nn.Identity if use_rezero else norm_fn
branch_fn = Rezero if use_rezero else None
if cross_attend and not only_cross:
default_block = ('a', 'c', 'f')
elif cross_attend and only_cross:
default_block = ('c', 'f')
else:
default_block = ('a', 'f')
if macaron:
default_block = ('f',) + default_block
if exists(custom_layers):
layer_types = custom_layers
elif exists(par_ratio):
par_depth = depth * len(default_block)
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
default_block = tuple(filter(not_equals('f'), default_block))
par_attn = par_depth // par_ratio
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
par_width = (depth_cut + depth_cut // par_attn) // par_attn
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
par_block = default_block + ('f',) * (par_width - len(default_block))
par_head = par_block * par_attn
layer_types = par_head + ('f',) * (par_depth - len(par_head))
elif exists(sandwich_coef):
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be at least 1 and at most the depth'
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
else:
layer_types = default_block * depth
self.layer_types = layer_types
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
for layer_type in self.layer_types:
if layer_type == 'a':
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
elif layer_type == 'c':
layer = Attention(dim, heads=heads, **attn_kwargs)
elif layer_type == 'f':
layer = FeedForward(dim, **ff_kwargs)
layer = layer if not macaron else Scale(0.5, layer)
else:
raise Exception(f'invalid layer type {layer_type}')
if isinstance(layer, Attention) and exists(branch_fn):
layer = branch_fn(layer)
if gate_residual:
residual_fn = GRUGating(dim)
else:
residual_fn = Residual()
self.layers.append(nn.ModuleList([
norm_fn(),
layer,
residual_fn
]))
def forward(
self,
x,
context=None,
mask=None,
context_mask=None,
mems=None,
return_hiddens=False
):
hiddens = []
intermediates = []
prev_attn = None
prev_cross_attn = None
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
is_last = ind == (len(self.layers) - 1)
if layer_type == 'a':
hiddens.append(x)
layer_mem = mems.pop(0)
residual = x
if self.pre_norm:
x = norm(x)
if layer_type == 'a':
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
prev_attn=prev_attn, mem=layer_mem)
elif layer_type == 'c':
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
elif layer_type == 'f':
out = block(x)
x = residual_fn(out, residual)
if layer_type in ('a', 'c'):
intermediates.append(inter)
if layer_type == 'a' and self.residual_attn:
prev_attn = inter.pre_softmax_attn
elif layer_type == 'c' and self.cross_residual_attn:
prev_cross_attn = inter.pre_softmax_attn
if not self.pre_norm and not is_last:
x = norm(x)
if return_hiddens:
intermediates = LayerIntermediates(
hiddens=hiddens,
attn_intermediates=intermediates
)
return x, intermediates
return x
class Encoder(AttentionLayers):
def __init__(self, **kwargs):
assert 'causal' not in kwargs, 'cannot set causality on encoder'
super().__init__(causal=False, **kwargs)
class TransformerWrapper(nn.Module):
def __init__(
self,
*,
num_tokens,
max_seq_len,
attn_layers,
emb_dim=None,
max_mem_len=0.,
emb_dropout=0.,
num_memory_tokens=None,
tie_embedding=False,
use_pos_emb=True
):
super().__init__()
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
dim = attn_layers.dim
emb_dim = default(emb_dim, dim)
self.max_seq_len = max_seq_len
self.max_mem_len = max_mem_len
self.num_tokens = num_tokens
self.token_emb = nn.Embedding(num_tokens, emb_dim)
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
self.emb_dropout = nn.Dropout(emb_dropout)
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
self.attn_layers = attn_layers
self.norm = nn.LayerNorm(dim)
self.init_()
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
# memory tokens (like [cls]) from Memory Transformers paper
num_memory_tokens = default(num_memory_tokens, 0)
self.num_memory_tokens = num_memory_tokens
if num_memory_tokens > 0:
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
# let funnel encoder know number of memory tokens, if specified
if hasattr(attn_layers, 'num_memory_tokens'):
attn_layers.num_memory_tokens = num_memory_tokens
def init_(self):
nn.init.normal_(self.token_emb.weight, std=0.02)
def forward(
self,
x,
return_embeddings=False,
mask=None,
return_mems=False,
return_attn=False,
mems=None,
**kwargs
):
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
x = self.token_emb(x)
x += self.pos_emb(x)
x = self.emb_dropout(x)
x = self.project_emb(x)
if num_mem > 0:
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
x = torch.cat((mem, x), dim=1)
# auto-handle masking after appending memory tokens
if exists(mask):
mask = F.pad(mask, (num_mem, 0), value=True)
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
x = self.norm(x)
mem, x = x[:, :num_mem], x[:, num_mem:]
out = self.to_logits(x) if not return_embeddings else x
if return_mems:
hiddens = intermediates.hiddens
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
return out, new_mems
if return_attn:
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
return out, attn_maps
return out
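# Usage sketch, added for illustration: a small encoder-only transformer of the kind
# used for BERT-style conditioning; all sizes below are arbitrary.
def _demo_transformer_wrapper():
    model = TransformerWrapper(
        num_tokens=1000,
        max_seq_len=77,
        attn_layers=Encoder(dim=128, depth=2, heads=4),
    )
    tokens = torch.randint(0, 1000, (2, 77))
    z = model(tokens, return_embeddings=True)
    return z.shape   # expected: (2, 77, 128)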

View File

@@ -0,0 +1,121 @@
# https://github.com/eladrich/pixel2style2pixel
from collections import namedtuple
import torch
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
"""
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Flatten(Module):
def forward(self, input):
return input.view(input.size(0), -1)
def l2_norm(input, axis=1):
norm = torch.norm(input, 2, axis, True)
output = torch.div(input, norm)
return output
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
""" A named tuple describing a ResNet block. """
def get_block(in_channel, depth, num_units, stride=2):
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
def get_blocks(num_layers):
if num_layers == 50:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=4),
get_block(in_channel=128, depth=256, num_units=14),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 100:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=13),
get_block(in_channel=128, depth=256, num_units=30),
get_block(in_channel=256, depth=512, num_units=3)
]
elif num_layers == 152:
blocks = [
get_block(in_channel=64, depth=64, num_units=3),
get_block(in_channel=64, depth=128, num_units=8),
get_block(in_channel=128, depth=256, num_units=36),
get_block(in_channel=256, depth=512, num_units=3)
]
else:
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
return blocks
class SEModule(Module):
def __init__(self, channels, reduction):
super(SEModule, self).__init__()
self.avg_pool = AdaptiveAvgPool2d(1)
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
self.relu = ReLU(inplace=True)
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
self.sigmoid = Sigmoid()
def forward(self, x):
module_input = x
x = self.avg_pool(x)
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.sigmoid(x)
return module_input * x
class bottleneck_IR(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut
class bottleneck_IR_SE(Module):
def __init__(self, in_channel, depth, stride):
super(bottleneck_IR_SE, self).__init__()
if in_channel == depth:
self.shortcut_layer = MaxPool2d(1, stride)
else:
self.shortcut_layer = Sequential(
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
BatchNorm2d(depth)
)
self.res_layer = Sequential(
BatchNorm2d(in_channel),
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
PReLU(depth),
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
BatchNorm2d(depth),
SEModule(depth, 16)
)
def forward(self, x):
shortcut = self.shortcut_layer(x)
res = self.res_layer(x)
return res + shortcut

View File

@@ -0,0 +1,23 @@
# https://github.com/eladrich/pixel2style2pixel
import torch
from torch import nn
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone
class IDFeatures(nn.Module):
def __init__(self, model_path):
super(IDFeatures, self).__init__()
print('Loading ResNet ArcFace')
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
self.facenet.eval()
def forward(self, x, crop=False):
# Not sure of the image range here
if crop:
x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
x = x[:, :, 35:223, 32:220]
x = self.face_pool(x)
x_feats = self.facenet(x)
return x_feats
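# Usage sketch, added for illustration; the checkpoint path below is a placeholder,
# not a file shipped with this repo. IDFeatures pools inputs to 112x112 (optionally
# cropping the face region first) and returns L2-normalized identity embeddings.
def _demo_id_features(model_path='path/to/model_ir_se50.pth'):
    id_model = IDFeatures(model_path)
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        feats = id_model(x, crop=True)
    return feats.shape   # expected: (1, 512)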

View File

@@ -0,0 +1,86 @@
# https://github.com/eladrich/pixel2style2pixel
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
"""
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
"""
class Backbone(Module):
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
super(Backbone, self).__init__()
assert input_size in [112, 224], "input_size should be 112 or 224"
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
blocks = get_blocks(num_layers)
if mode == 'ir':
unit_module = bottleneck_IR
elif mode == 'ir_se':
unit_module = bottleneck_IR_SE
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
BatchNorm2d(64),
PReLU(64))
if input_size == 112:
self.output_layer = Sequential(BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 7 * 7, 512),
BatchNorm1d(512, affine=affine))
else:
self.output_layer = Sequential(BatchNorm2d(512),
Dropout(drop_ratio),
Flatten(),
Linear(512 * 14 * 14, 512),
BatchNorm1d(512, affine=affine))
modules = []
for block in blocks:
for bottleneck in block:
modules.append(unit_module(bottleneck.in_channel,
bottleneck.depth,
bottleneck.stride))
self.body = Sequential(*modules)
def forward(self, x):
x = self.input_layer(x)
x = self.body(x)
x = self.output_layer(x)
return l2_norm(x)
def IR_50(input_size):
"""Constructs a ir-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_101(input_size):
"""Constructs a ir-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_152(input_size):
"""Constructs a ir-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
return model
def IR_SE_50(input_size):
"""Constructs a ir_se-50 model."""
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_101(input_size):
"""Constructs a ir_se-101 model."""
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
return model
def IR_SE_152(input_size):
"""Constructs a ir_se-152 model."""
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
return model
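# Usage sketch, added for illustration: the IR-SE-50 backbone expects 112x112 face
# crops and returns L2-normalized 512-d identity embeddings.
def _demo_ir_se_50():
    import torch   # torch is not imported at module level in this file
    net = IR_SE_50(input_size=112).eval()
    faces = torch.randn(2, 3, 112, 112)
    with torch.no_grad():
        feats = net(faces)
    return feats.shape   # expected: (2, 512)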

View File

@@ -0,0 +1,276 @@
import importlib
import os
import time
from inspect import isfunction
import cv2
import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont
from torch import optim
def pil_rectangle_crop(im):
width, height = im.size # Get dimensions
if width <= height:
left = 0
right = width
top = (height - width)/2
bottom = (height + width)/2
else:
top = 0
bottom = height
left = (width - height) / 2
right = (width + height) / 2
# Crop the center of the image
im = im.crop((left, top, right, bottom))
return im
def add_margin(pil_img, color=0, size=256):
width, height = pil_img.size
result = Image.new(pil_img.mode, (size, size), color)
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
return result
def create_carvekit_interface():
from carvekit.api.high import HiInterface
# Check doc strings for more information
interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
batch_size_seg=5,
batch_size_matting=1,
device='cuda' if torch.cuda.is_available() else 'cpu',
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
matting_mask_size=2048,
trimap_prob_threshold=231,
trimap_dilation=30,
trimap_erosion_iters=5,
fp16=False)
return interface
def load_and_preprocess(interface, input_im):
'''
:param input_im (PIL Image).
:return image (H, W, 3) array in [0, 1].
'''
# See https://github.com/Ir1d/image-background-remove-tool
image = input_im.convert('RGB')
image_without_background = interface([image])[0]
image_without_background = np.array(image_without_background)
est_seg = image_without_background > 127
image = np.array(image)
foreground = est_seg[:, : , -1].astype(np.bool_)
image[~foreground] = [255., 255., 255.]
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
image = image[y:y+h, x:x+w, :]
image = PIL.Image.fromarray(np.array(image))
# resize image so its longest edge is at most 200 px, then pad to a 256x256 white canvas
image.thumbnail([200, 200], Image.LANCZOS)
image = add_margin(image, (255, 255, 255), size=256)
image = np.array(image)
return image
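# Usage sketch, added for illustration; requires the optional carvekit dependency and
# an input image on disk (the path below is a placeholder). The result is the white-
# background, foreground-centered 256x256 uint8 array expected by the image-to-3D
# pipeline.
def _demo_background_removal(image_path='path/to/object.png'):
    interface = create_carvekit_interface()
    input_im = Image.open(image_path)
    out = load_and_preprocess(interface, input_im)
    return out.shape   # expected: (256, 256, 3)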
def log_txt_as_img(wh, xc, size=10):
# wh a tuple of (width, height)
# xc a list of captions to plot
b = len(xc)
txts = list()
for bi in range(b):
txt = Image.new("RGB", wh, color="white")
draw = ImageDraw.Draw(txt)
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
nc = int(40 * (wh[0] / 256))
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
try:
draw.text((0, 0), lines, fill="black", font=font)
except UnicodeEncodeError:
print("Cant encode string for logging. Skipping.")
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
txts.append(txt)
txts = np.stack(txts)
txts = torch.tensor(txts)
return txts
def ismap(x):
if not isinstance(x, torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] > 3)
def isimage(x):
if not isinstance(x,torch.Tensor):
return False
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def mean_flat(tensor):
"""
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
Take the mean over all non-batch dimensions.
"""
return tensor.mean(dim=list(range(1, len(tensor.shape))))
def count_params(model, verbose=False):
total_params = sum(p.numel() for p in model.parameters())
if verbose:
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
return total_params
def instantiate_from_config(config):
if not "target" in config:
if config == '__is_first_stage__':
return None
elif config == "__is_unconditional__":
return None
raise KeyError("Expected key `target` to instantiate.")
return get_obj_from_str(config["target"])(**config.get("params", dict()))
def get_obj_from_str(string, reload=False):
module, cls = string.rsplit(".", 1)
print(module)
if reload:
module_imp = importlib.import_module(module)
importlib.reload(module_imp)
return getattr(importlib.import_module(module, package=None), cls)
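# Illustration, added: instantiate_from_config builds objects from the config dicts
# used throughout the LDM code, i.e. {"target": "<import path>", "params": {...}}.
def _demo_instantiate_from_config():
    config = {
        'target': 'torch.nn.Conv2d',
        'params': {'in_channels': 3, 'out_channels': 8, 'kernel_size': 3},
    }
    conv = instantiate_from_config(config)
    return type(conv).__name__   # 'Conv2d'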
class AdamWwithEMAandWings(optim.Optimizer):
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
ema_power=1., param_names=()):
"""AdamW that saves EMA versions of the parameters."""
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
if not 0.0 <= weight_decay:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
if not 0.0 <= ema_decay <= 1.0:
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
ema_power=ema_power, param_names=param_names)
super().__init__(params, defaults)
def __setstate__(self, state):
super().__setstate__(state)
for group in self.param_groups:
group.setdefault('amsgrad', False)
@torch.no_grad()
def step(self, closure=None):
"""Performs a single optimization step.
Args:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
with torch.enable_grad():
loss = closure()
for group in self.param_groups:
params_with_grad = []
grads = []
exp_avgs = []
exp_avg_sqs = []
ema_params_with_grad = []
state_sums = []
max_exp_avg_sqs = []
state_steps = []
amsgrad = group['amsgrad']
beta1, beta2 = group['betas']
ema_decay = group['ema_decay']
ema_power = group['ema_power']
for p in group['params']:
if p.grad is None:
continue
params_with_grad.append(p)
if p.grad.is_sparse:
raise RuntimeError('AdamW does not support sparse gradients')
grads.append(p.grad)
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
if amsgrad:
# Maintains max of all exp. moving avg. of sq. grad. values
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
# Exponential moving average of parameter values
state['param_exp_avg'] = p.detach().float().clone()
exp_avgs.append(state['exp_avg'])
exp_avg_sqs.append(state['exp_avg_sq'])
ema_params_with_grad.append(state['param_exp_avg'])
if amsgrad:
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
# update the steps for each param group update
state['step'] += 1
# record the step after step update
state_steps.append(state['step'])
optim._functional.adamw(params_with_grad,
grads,
exp_avgs,
exp_avg_sqs,
max_exp_avg_sqs,
state_steps,
amsgrad=amsgrad,
beta1=beta1,
beta2=beta2,
lr=group['lr'],
weight_decay=group['weight_decay'],
eps=group['eps'],
maximize=False)
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
return loss

View File

@@ -28,6 +28,8 @@ class DocumentGroundedDialogRetrievalModel(TorchModel):
map_location='cpu')
compatible_position_ids(state_dict,
'ctx_encoder.encoder.embeddings.position_ids')
compatible_position_ids(state_dict,
'qry_encoder.encoder.embeddings.position_ids')
self.model.load_state_dict(state_dict)
def forward(self, input: Dict[str, Tensor], gck_segment=32):

View File

@@ -49,6 +49,7 @@ class MsModelMixin:
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
"""
model_dir = kwargs.pop('model_dir', None)
device = kwargs.pop('device', None)
if model_dir is None:
config = LlamaConfig(**kwargs)
model = cls(config)
@@ -56,7 +57,8 @@ class MsModelMixin:
model = super(MsModelMixin, cls).from_pretrained(
pretrained_model_name_or_path=model_dir, **kwargs)
model.model_dir = model_dir
return model
return model if 'device_map' in kwargs \
or device is None else model.to(device)
class LlamaPreTrainedModel(MsModelMixin, LlamaPreTrainedModelHF, TorchModel):

View File

@@ -375,6 +375,8 @@ class ProtoNet(nn.Module):
input_ids = torch.IntTensor(input_ids)
if not isinstance(input_mask, Tensor):
input_mask = torch.IntTensor(input_mask)
input_ids = input_ids.to(self.bert.device)
input_mask = input_mask.to(self.bert.device)
rst = self.bert(input_ids, input_mask)
last_hidden_states = rst.last_hidden_state
if len(input_mask.shape) == 2:

View File

@@ -69,6 +69,7 @@ class OutputKeys(object):
PCD12 = 'pcd12'
PCD12_ALIGN = 'pcd12_align'
TBOUNDS = 'tbounds'
MV_IMGS = 'MViews'
OutputTypes = {
@@ -132,6 +133,7 @@ OutputTypes = {
OutputKeys.PCD12: np.ndarray,
OutputKeys.PCD12_ALIGN: np.ndarray,
OutputKeys.TBOUNDS: Dict,
OutputKeys.MV_IMGS: List[np.ndarray],
}
OutputTypeSchema = {
@@ -426,6 +428,15 @@ OutputTypeSchema = {
OutputKeys.TBOUNDS: {
'type': 'object'
},
OutputKeys.MV_IMGS: {
'type': 'array',
'items': {
'type': 'array',
'items': {
'type': 'number'
}
}
},
}
TASK_OUTPUTS = {
@@ -1632,6 +1643,7 @@ TASK_OUTPUTS = {
# "output_imgs": np.ndarray list with shape [[height, width, 3], ...]
# }
Tasks.image_view_transform: [OutputKeys.OUTPUT_IMGS],
Tasks.image_to_3d: [OutputKeys.MV_IMGS]
}

View File

@@ -247,8 +247,10 @@ TASK_INPUTS = {
InputType.VIDEO,
# image generation task result for a single image
Tasks.image_to_image_generation:
InputType.IMAGE,
Tasks.image_to_image_generation: [
InputType.IMAGE,
(InputType.IMAGE, InputType.IMAGE, InputType.IMAGE, InputType.IMAGE)
],
Tasks.image_to_image_translation:
InputType.IMAGE,
Tasks.image_style_transfer: {

Some files were not shown because too many files have changed in this diff.