mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
Merge branch 'master-github' into master-merge-github231228
This commit is contained in:
4
.github/workflows/publish.yaml
vendored
4
.github/workflows/publish.yaml
vendored
@@ -15,10 +15,10 @@ jobs:
|
||||
#if: startsWith(github.event.ref, 'refs/tags')
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python 3.7
|
||||
- name: Set up Python 3.10
|
||||
uses: actions/setup-python@v2
|
||||
with:
|
||||
python-version: '3.7'
|
||||
python-version: '3.10'
|
||||
- name: Install wheel
|
||||
run: pip install wheel && pip install -r requirements/framework.txt
|
||||
- name: Build ModelScope
|
||||
|
||||
69
README.md
69
README.md
@@ -53,70 +53,64 @@ Some representative examples include:
|
||||
|
||||
NLP:
|
||||
|
||||
* [nlp_gpt3_text-generation_2.7B](https://modelscope.cn/models/damo/nlp_gpt3_text-generation_2.7B)
|
||||
* [ChatGLM3-6B](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary)
|
||||
|
||||
* [ChatYuan-large](https://modelscope.cn/models/ClueAI/ChatYuan-large)
|
||||
* [Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary)
|
||||
|
||||
* [mengzi-t5-base](https://modelscope.cn/models/langboat/mengzi-t5-base)
|
||||
* [Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat/summary)
|
||||
|
||||
* [nlp_csanmt_translation_en2zh](https://modelscope.cn/models/damo/nlp_csanmt_translation_en2zh)
|
||||
* [Ziya2-13B-Chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
|
||||
|
||||
* [nlp_raner_named-entity-recognition_chinese-base-news](https://modelscope.cn/models/damo/nlp_raner_named-entity-recognition_chinese-base-news)
|
||||
* [Internlm-chat-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm-chat-20b/summary)
|
||||
|
||||
* [nlp_structbert_word-segmentation_chinese-base](https://modelscope.cn/models/damo/nlp_structbert_word-segmentation_chinese-base)
|
||||
* [Udever Multilingual Universal Text Representation Model 1b1](https://modelscope.cn/models/damo/udever-bloom-1b1/summary)
|
||||
|
||||
* [Erlangshen-RoBERTa-330M-Sentiment](https://modelscope.cn/models/fengshenbang/Erlangshen-RoBERTa-330M-Sentiment)
|
||||
* [CoROM Text Vector - Chinese - E-commerce Domain - Base](https://modelscope.cn/models/damo/nlp_corom_sentence-embedding_chinese-base-ecom/summary)
|
||||
|
||||
* [nlp_convai_text2sql_pretrain_cn](https://modelscope.cn/models/damo/nlp_convai_text2sql_pretrain_cn)
|
||||
* [MGeo Address Similarity Matching Entity Alignment - Chinese - Address Field - Base](https://modelscope.cn/models/damo/mgeo_geographic_entity_alignment_chinese_base/summary)
|
||||
|
||||
Multi-Modal:
|
||||
|
||||
* [multi-modal_clip-vit-base-patch16_zh](https://modelscope.cn/models/damo/multi-modal_clip-vit-base-patch16_zh)
|
||||
* [Qwen-VL-Chat](https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary)
|
||||
|
||||
* [ofa_pretrain_base_zh](https://modelscope.cn/models/damo/ofa_pretrain_base_zh)
|
||||
* [CogVLM](https://modelscope.cn/models/ZhipuAI/CogVLM/summary)
|
||||
|
||||
* [Taiyi-Stable-Diffusion-1B-Chinese-v0.1](https://modelscope.cn/models/fengshenbang/Taiyi-Stable-Diffusion-1B-Chinese-v0.1)
|
||||
* [Text-to-Video Synthesis Large Model - English - General Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary)
|
||||
|
||||
* [mplug_visual-question-answering_coco_large_en](https://modelscope.cn/models/damo/mplug_visual-question-answering_coco_large_en)
|
||||
* [I2VGen-XL High Definition Image to Video Large Model](https://modelscope.cn/models/damo/Image-to-Video/summary)
|
||||
|
||||
* [I2VGen-XL High Definition Video to Video Large Model](https://modelscope.cn/models/damo/Video-to-Video/summary)
|
||||
|
||||
CV:
|
||||
|
||||
* [cv_controlnet_controllable-image-generation_nine-annotators](https://modelscope.cn/models/dienstag/cv_controlnet_controllable-image-generation_nine-annotators/summary)
|
||||
* [DamoFD Face Detection Key Point Model - 0.5G](https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd/summary)
|
||||
|
||||
* [cv_tinynas_object-detection_damoyolo](https://modelscope.cn/models/damo/cv_tinynas_object-detection_damoyolo)
|
||||
* [BSHM Portrait Matting](https://modelscope.cn/models/damo/cv_unet_image-matting/summary)
|
||||
|
||||
* [cv_unet_person-image-cartoon_compound-models](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models)
|
||||
* [DCT-Net Portrait Cartoonization - 3D](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary)
|
||||
|
||||
* [cv_convnextTiny_ocr-recognition-general_damo](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo)
|
||||
* [DCT-Net Portrait Cartoonization Model - 3D](https://modelscope.cn/models/damo/face_chain_control_model/summary)
|
||||
|
||||
* [cv_resnet18_human-detection](https://modelscope.cn/models/damo/cv_resnet18_human-detection)
|
||||
* [DuGuang - Text Recognition - Line Recognition Model - Chinese and English - General Domain](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo/summary)
|
||||
|
||||
* [cv_resnet50_face-detection_retinaface](https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface)
|
||||
* [DuGuang - Text Recognition - Line Recognition Model - Chinese and English - General Domain](https://modelscope.cn/models/damo/cv_resnet18_ocr-detection-line-level_damo/summary)
|
||||
|
||||
* [cv_unet_image-matting](https://modelscope.cn/models/damo/cv_unet_image-matting)
|
||||
|
||||
* [cv_F3Net_product-segmentation](https://modelscope.cn/models/damo/cv_F3Net_product-segmentation)
|
||||
|
||||
* [cv_resnest101_general_recognition](https://modelscope.cn/models/damo/cv_resnest101_general_recognition)
|
||||
* [LaMa Image Inpainting](https://modelscope.cn/models/damo/cv_fft_inpainting_lama/summary)
|
||||
|
||||
|
||||
Audio:
|
||||
|
||||
* [speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
|
||||
* [Paraformer Speech Recognition - Chinese - General - 16k - Offline - Large - Long Audio Version](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
|
||||
|
||||
* [speech_sambert-hifigan_tts_zh-cn_16k](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k)
|
||||
* [FSMN Voice Endpoint Detection - Chinese - General - 16k - onnx](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx/summary)
|
||||
|
||||
* [speech_charctc_kws_phone-xiaoyun](https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun)
|
||||
* [Monotonic-Aligner Speech Timestamp Prediction - 16k - Offline](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary)
|
||||
|
||||
* [u2pp_conformer-asr-cn-16k-online](https://modelscope.cn/models/wenet/u2pp_conformer-asr-cn-16k-online)
|
||||
* [CT-Transformer Punctuation - Chinese - General - onnx](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx/summary)
|
||||
|
||||
* [speech_fsmn_vad_zh-cn-16k-common-pytorch](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/summary)
|
||||
* [Speech Synthesis - Chinese - Multiple Emotions Domain - 16k - Multiple Speakers](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k/summary)
|
||||
|
||||
* [punc_ct-transformer_zh-cn-common-vocab272727-pytorch](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/summary)
|
||||
|
||||
* [speech_frcrn_ans_cirm_16k](https://modelscope.cn/models/damo/speech_frcrn_ans_cirm_16k)
|
||||
|
||||
* [speech_dfsmn_aec_psm_16k](https://modelscope.cn/models/damo/speech_dfsmn_aec_psm_16k)
|
||||
* [CAM++ Speaker Verification - Chinese - General - 200k-Spkrs](https://modelscope.cn/models/damo/speech_campplus_sv_zh-cn_16k-common/summary)
|
||||
|
||||
|
||||
|
||||
@@ -208,7 +202,7 @@ CPU docker image
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
|
||||
# py38
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
GPU docker image
|
||||
@@ -217,15 +211,16 @@ GPU docker image
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
|
||||
# py38
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
## Setup Local Python Environment
|
||||
|
||||
One can also set up local ModelScope environment using pip and conda. We suggest [anaconda](https://docs.anaconda.com/anaconda/install/) for creating local python environment:
|
||||
One can also set up local ModelScope environment using pip and conda. ModelScope supports python3.7 and above.
|
||||
We suggest [anaconda](https://docs.anaconda.com/anaconda/install/) for creating local python environment:
|
||||
|
||||
```shell
|
||||
conda create -n modelscope python=3.7
|
||||
conda create -n modelscope python=3.8
|
||||
conda activate modelscope
|
||||
```
|
||||
|
||||
|
||||
@@ -208,7 +208,7 @@ CPU docker イメージ
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
|
||||
# py38
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
GPU docker イメージ
|
||||
@@ -217,7 +217,7 @@ GPU docker イメージ
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
|
||||
# py38
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
## ローカル Python 環境のセットアップ
|
||||
|
||||
67
README_zh.md
67
README_zh.md
@@ -37,6 +37,8 @@ ModelScope Library为模型贡献者提供了必要的分层API,以便将来
|
||||
|
||||
除了包含各种模型的实现之外,ModelScope Library还支持与ModelScope后端服务进行必要的交互,特别是与Model-Hub和Dataset-Hub的交互。这种交互促进了模型和数据集的管理在后台无缝执行,包括模型数据集查询、版本控制、缓存管理等。
|
||||
|
||||
|
||||
|
||||
# 部分模型和在线体验
|
||||
ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计算机视觉、语音、多模态、科学计算等,其中包含数百个SOTA模型。用户可以进入ModelScope网站([modelscope.cn](http://www.modelscope.cn))的模型中心零门槛在线体验,或者Notebook方式体验模型。
|
||||
|
||||
@@ -50,68 +52,67 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
|
||||
|
||||
自然语言处理:
|
||||
|
||||
* [GPT-3预训练生成模型-中文-2.7B](https://modelscope.cn/models/damo/nlp_gpt3_text-generation_2.7B)
|
||||
* [ChatGLM3-6B](https://modelscope.cn/models/ZhipuAI/chatglm3-6b/summary)
|
||||
|
||||
* [元语功能型对话大模型](https://modelscope.cn/models/ClueAI/ChatYuan-large)
|
||||
* [Qwen-14B-Chat](https://modelscope.cn/models/qwen/Qwen-14B-Chat/summary)
|
||||
|
||||
* [孟子T5预训练生成模型-中文-base](https://modelscope.cn/models/langboat/mengzi-t5-base)
|
||||
* [Baichuan2-13B-Chat](https://modelscope.cn/models/baichuan-inc/Baichuan2-13B-Chat/summary)
|
||||
|
||||
* [CSANMT连续语义增强机器翻译-英中-通用领域-large](https://modelscope.cn/models/damo/nlp_csanmt_translation_en2zh)
|
||||
* [Ziya2-13B-Chat](https://modelscope.cn/models/Fengshenbang/Ziya2-13B-Chat/summary)
|
||||
|
||||
* [RaNER命名实体识别-中文-新闻领域-base](https://modelscope.cn/models/damo/nlp_raner_named-entity-recognition_chinese-base-news)
|
||||
* [Internlm-chat-20b](https://modelscope.cn/models/Shanghai_AI_Laboratory/internlm-chat-20b/summary)
|
||||
|
||||
* [BAStructBERT分词-中文-新闻领域-base](https://modelscope.cn/models/damo/nlp_structbert_word-segmentation_chinese-base)
|
||||
* [Udever-bloom-1b1](https://modelscope.cn/models/damo/udever-bloom-1b1/summary)
|
||||
|
||||
* [二郎神-RoBERTa-330M-情感分类](https://modelscope.cn/models/fengshenbang/Erlangshen-RoBERTa-330M-Sentiment)
|
||||
* [CoROM文本向量-中文-电商领域-base](https://modelscope.cn/models/damo/nlp_corom_sentence-embedding_chinese-base-ecom/summary)
|
||||
|
||||
* [SPACE-T表格问答预训练模型-中文-通用领域-base](https://modelscope.cn/models/damo/nlp_convai_text2sql_pretrain_cn)
|
||||
* [MGeo地址相似度匹配实体对齐-中文-地址领域-base](https://modelscope.cn/models/damo/mgeo_geographic_entity_alignment_chinese_base/summary)
|
||||
|
||||
多模态:
|
||||
|
||||
* [CLIP模型-中文-通用领域-base](https://modelscope.cn/models/damo/multi-modal_clip-vit-base-patch16_zh)
|
||||
* [Qwen-VL-Chat](https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary)
|
||||
|
||||
* [OFA预训练模型-中文-通用领域-base](https://modelscope.cn/models/damo/ofa_pretrain_base_zh)
|
||||
* [CogVLM](https://modelscope.cn/models/ZhipuAI/CogVLM/summary)
|
||||
|
||||
* [太乙-Stable-Diffusion-1B-中文-v0.1](https://modelscope.cn/models/fengshenbang/Taiyi-Stable-Diffusion-1B-Chinese-v0.1)
|
||||
* [Text-to-Video Synthesis Large Model - English - General Domain](https://modelscope.cn/models/damo/text-to-video-synthesis/summary)
|
||||
|
||||
* [I2VGen-XL高清图片到视频大模型](https://modelscope.cn/models/damo/Image-to-Video/summary)
|
||||
|
||||
* [I2VGen-XL高清视频到视频大模型](https://modelscope.cn/models/damo/Video-to-Video/summary)
|
||||
|
||||
* [mPLUG视觉问答模型-英文-large](https://modelscope.cn/models/damo/mplug_visual-question-answering_coco_large_en)
|
||||
|
||||
计算机视觉:
|
||||
|
||||
* [ControlNet可控图像生成](https://modelscope.cn/models/dienstag/cv_controlnet_controllable-image-generation_nine-annotators/summary)
|
||||
* [DamoFD人脸检测关键点模型-0.5G](https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd/summary)
|
||||
|
||||
* [DAMOYOLO-高性能通用检测模型-S](https://modelscope.cn/models/damo/cv_tinynas_object-detection_damoyolo)
|
||||
* [BSHM人像抠图](https://modelscope.cn/models/damo/cv_unet_image-matting/summary)
|
||||
|
||||
* [DCT-Net人像卡通化](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon_compound-models)
|
||||
* [DCT-Net人像卡通化-3D](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary)
|
||||
|
||||
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo)
|
||||
* [DCT-Net人像卡通化模型-3D](https://modelscope.cn/models/damo/face_chain_control_model/summary)
|
||||
|
||||
* [人体检测-通用-Base](https://modelscope.cn/models/damo/cv_resnet18_human-detection)
|
||||
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo/summary)
|
||||
|
||||
* [RetinaFace人脸检测关键点模型](https://modelscope.cn/models/damo/cv_resnet50_face-detection_retinaface)
|
||||
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_resnet18_ocr-detection-line-level_damo/summary)
|
||||
|
||||
* [BSHM人像抠图](https://modelscope.cn/models/damo/cv_unet_image-matting)
|
||||
* [LaMa图像填充](https://modelscope.cn/models/damo/cv_fft_inpainting_lama/summary)
|
||||
|
||||
* [图像分割-商品展示图场景的商品分割-电商领域](https://modelscope.cn/models/damo/cv_F3Net_product-segmentation)
|
||||
|
||||
* [万物识别-中文-通用领域](https://modelscope.cn/models/damo/cv_resnest101_general_recognition)
|
||||
|
||||
|
||||
语音:
|
||||
|
||||
* [Paraformer语音识别-中文-通用-16k-离线-large-pytorch](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch)
|
||||
* [Paraformer语音识别-中文-通用-16k-离线-大型-长音频版本](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
|
||||
|
||||
* [语音合成-中文-多情感领域-16k-多发音人](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k)
|
||||
* [FSMN声音端点检测-中文-通用-16k-onnx](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx/summary)
|
||||
|
||||
* [CTC语音唤醒-移动端-单麦-16k-小云小云](https://modelscope.cn/models/damo/speech_charctc_kws_phone-xiaoyun)
|
||||
* [Monotonic-Aligner语音时间戳预测-16k-离线](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary)
|
||||
|
||||
* [WeNet-U2pp_Conformer-语音识别-中文-16k-实时](https://modelscope.cn/models/wenet/u2pp_conformer-asr-cn-16k-online)
|
||||
|
||||
* [FRCRN语音降噪-单麦-16k](https://modelscope.cn/models/damo/speech_frcrn_ans_cirm_16k)
|
||||
|
||||
* [DFSMN回声消除-单麦单参考-16k](https://modelscope.cn/models/damo/speech_dfsmn_aec_psm_16k)
|
||||
* [CT-Transformer标点-中文-通用-onnx](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx/summary)
|
||||
|
||||
* [语音合成-中文-多情绪领域-16k-多发言人](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k/summary)
|
||||
|
||||
* [CAM++说话人验证-中文-通用-200k发言人](https://modelscope.cn/models/damo/speech_campplus_sv_zh-cn_16k-common/summary)
|
||||
|
||||
|
||||
科学计算:
|
||||
@@ -194,7 +195,7 @@ CPU镜像
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
|
||||
# py38
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch1.11.0-tf1.15.5-1.6.1
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
GPU镜像
|
||||
@@ -203,14 +204,14 @@ GPU镜像
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
|
||||
# py38
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py38-torch1.11.0-tf1.15.5-1.6.1
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
## 搭建本地Python环境
|
||||
|
||||
你也可以使用pip和conda搭建本地python环境,我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/),安装完成后,执行如下命令为modelscope library创建对应的python环境:
|
||||
你也可以使用pip和conda搭建本地python环境,ModelScope支持python3.7+以上环境,我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/),安装完成后,执行如下命令为modelscope library创建对应的python环境:
|
||||
```shell
|
||||
conda create -n modelscope python=3.7
|
||||
conda create -n modelscope python=3.8
|
||||
conda activate modelscope
|
||||
```
|
||||
|
||||
|
||||
@@ -119,7 +119,7 @@ git lfs install
|
||||
|
||||
2. We use a public read model repository from ModelScope to store test data. The repository has been added by default as a submodule with the path data/test. To clone it, use the following command:
|
||||
```shell
|
||||
git clone git@github.com:modelscope/modelscope.git --recursive
|
||||
git clone https://github.com/modelscope/modelscope.git --recursive
|
||||
```
|
||||
|
||||
3. Each time you add new data, go to the data/test directory (note that you are now in the submodule's git directory), check if you are on the master branch, and pull the latest master branch:
|
||||
|
||||
@@ -90,8 +90,7 @@ git lfs install
|
||||
|
||||
2. 我们使用 ModelScope 的一个公共读取模型仓库来存储测试数据。该仓库已默认添加为子模块,路径为 data/test。要克隆它,请使用以下命令:
|
||||
```
|
||||
|
||||
git clone git@github.com:modelscope/modelscope.git --recursive
|
||||
git clone https://github.com/modelscope/modelscope.git --recursive
|
||||
```
|
||||
|
||||
3. 每次添加新数据时,进入 data/test 目录(注意此时您已在子模块的 git 目录中),检查是否在 master 分支上,并拉取最新的 master 分支:
|
||||
|
||||
55
examples/apps/llm_riddles/README.md
Normal file
55
examples/apps/llm_riddles/README.md
Normal file
@@ -0,0 +1,55 @@
|
||||
# Oh No! I'm Surrounded by LLMs! (LLMRiddles)
|
||||
|
||||
## Project Introduction
|
||||
"Oh No! I'm Surrounded by LLMs!" is an intellectual challenge game. We use LLM to automatically generate corresponding game code based on existing Large Language Model (LLM) dialogue Gradio application codes within the ModelScope community, combined with preset questions from the Zhihu article ["How to Accomplish Tasks with 'Impossible'"](https://zhuanlan.zhihu.com/p/665393240), creating a unique gameplay experience. In this stream, players are required to cleverly construct questions that challenge the LLM to provide answers that meet specific conditions.
|
||||
|
||||
## News
|
||||
November 9, 2023 - Added two new questions, and introduced the chatglm-turbo model 🔥🔥🔥
|
||||
November 7, 2023 - Released the initial demo version 🔥
|
||||
November 8, 2023 - Segregated level modules and LLM, enabling independent integration of levels and LLM. Pull Requests welcome 🔥 🔥
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Online Experience
|
||||
|
||||
[LLMRiddles](https://modelscope.cn/studios/LLMRiddles/LLMRiddles/summary)
|
||||
|
||||
### Local Execution
|
||||
To start the game, please follow the steps below:
|
||||
|
||||
1. Clone the project code:
|
||||
```
|
||||
git clone https://github.com/modelscope/modelscope.git
|
||||
```
|
||||
2. Navigate to the `examples/apps/llm_riddles` directory.
|
||||
3. Install the required Python dependencies with `pip install -r requirements.txt`.
|
||||
4. Go to [DashScope](https://dashscope.aliyun.com/), activate the service, obtain a token, and configure the environment variable `DASHSCOPE_API_KEY=your API-KEY`.
|
||||
5. Run the launch command `python app.py`.
|
||||
|
||||
## Roadmap
|
||||
- [x] Initial version source code and space experience ready.
|
||||
- [x] Support for custom questions and validation logic integration.
|
||||
- [ ] Expand to 9 major levels, each with 9 questions.
|
||||
- [ ] Support for more open-source models.
|
||||
- [ ] Support for switching between cloud API and local inference.
|
||||
|
||||
## Contribution Guide
|
||||
We welcome everyone to contribute to "Oh No! I'm Surrounded by LLMs!", including proposing more fun questions, fixing validator corner cases, and providing more gameplay. Please follow the steps below:
|
||||
|
||||
1. Visit the project address [ModelScope](https://github.com/modelscope/modelscope) and fork the project.
|
||||
2. Create your feature branch in your local environment (`git checkout -b feature/AmazingFeature`).
|
||||
3. Commit your changes (`git commit -m 'Add some AmazingFeature'`).
|
||||
4. Push your changes to the branch (`git push origin feature/AmazingFeature`).
|
||||
5. Initiate a Pull Request in the original project.
|
||||
|
||||
## Community Contributors
|
||||
We sincerely thank all community members who have contributed to this project, especially:
|
||||
|
||||
- Idea from: [haoqiangfan](https://www.zhihu.com/people/haoqiang-fan)
|
||||
- Most of the code is auto-generated by LLM
|
||||
|
||||
## Support
|
||||
If you encounter any problems or need assistance during the game, please submit your issues on the project's [Issues page](https://github.com/modelscope/modelscope/issues).
|
||||
|
||||
## Copyright and License
|
||||
This project is licensed under the APACHE License. Please see the [LICENSE](https://github.com/modelscope/modelscope/blob/main/LICENSE) file in the project for more information.
|
||||
65
examples/apps/llm_riddles/README_CN.md
Normal file
65
examples/apps/llm_riddles/README_CN.md
Normal file
@@ -0,0 +1,65 @@
|
||||
# 完蛋!我被LLM包围了!(LLMRiddles)
|
||||
|
||||
## 项目简介
|
||||
|
||||
《完蛋!我被LLM包围了!》是一款智力挑战游戏。该项目利用LLM代码生成, 基于ModelScope社区内现有的LLM对话Gradio应用程序代码,结合知乎文章[《如何用“不可能”完成任务》](https://zhuanlan.zhihu.com/p/665393240)中的预设问题,自动生成了对应的游戏代码,创造了一个独特的游戏体验。在这个游戏中,玩家需要巧妙构造问题,挑战LLM给出满足特定条件的回答。
|
||||
|
||||
## 更新
|
||||
|
||||
2023.11.9 新增两道题目, 新增chatglm-turbo模型🔥🔥🔥
|
||||
|
||||
2023.11.7 发布初版demo🔥
|
||||
|
||||
2023.11.8 拆分关卡模块和llm,支持关卡独立接入,llm独立接入, 欢迎PR 🔥 🔥
|
||||
|
||||
## 开始游戏
|
||||
|
||||
### 在线体验
|
||||
|
||||
[LLMRiddles](https://modelscope.cn/studios/LLMRiddles/LLMRiddles/summary)
|
||||
|
||||
### 本地运行
|
||||
|
||||
要开始游戏,请按照以下步骤操作:
|
||||
|
||||
1. 克隆项目代码:
|
||||
```
|
||||
git clone https://github.com/modelscope/modelscope.git
|
||||
```
|
||||
2. 进入到`examples/apps/llm_riddles`目录。
|
||||
3. 安装所需的Python依赖`pip install -r requirements.txt`。
|
||||
4. 前往[DashScope](https://dashscope.aliyun.com/)开通服务,获取token,配置环境变量`DASHSCOPE_API_KEY=你的API-KEY`
|
||||
5. 执行启动命令`python app.py`.
|
||||
|
||||
## RoadMap
|
||||
|
||||
- [x] 初版本源码和创空间体验ready
|
||||
- [x] 支持自定义问题和验证逻辑接入
|
||||
- [ ] 扩充到9个大关卡,每个关卡9个问题
|
||||
- [ ] 支持更多开源模型
|
||||
- [ ] 支持云端API和本地推理切换
|
||||
|
||||
## 贡献指南
|
||||
|
||||
我们欢迎大家为《完蛋!我被LLM包围了!》做出贡献,包括提出更多好玩的问题,修复validator的corner case,以及提供更多的玩法。请按以下步骤操作:
|
||||
|
||||
1. 访问项目地址 [ModelScope](https://github.com/modelscope/modelscope) 并fork项目。
|
||||
2. 在你的本地环境中创建你的特性分支 (`git checkout -b feature/AmazingFeature`)。
|
||||
3. 提交你的改动 (`git commit -m 'Add some AmazingFeature'`)。
|
||||
4. 将你的改动推送到分支上 (`git push origin feature/AmazingFeature`)。
|
||||
5. 在原项目下发起一个Pull Request。
|
||||
|
||||
## 社区贡献者
|
||||
|
||||
我们诚挚感谢所有对本项目做出贡献的社区成员,特别是:
|
||||
|
||||
- idea来源: [haoqiangfan](https://www.zhihu.com/people/haoqiang-fan)
|
||||
- 代码大部分来自于LLM自动生成
|
||||
|
||||
## 支持
|
||||
|
||||
如果你在游戏过程中遇到任何问题或需要帮助,请通过项目的[Issues页面](https://github.com/modelscope/modelscope/issues)提交你的问题。
|
||||
|
||||
## 版权和许可
|
||||
|
||||
本项目采用APACHE License许可证。请查看项目中的[LICENSE](https://github.com/modelscope/modelscope/blob/main/LICENSE)文件了解更多信息。
|
||||
225
examples/apps/llm_riddles/app.py
Normal file
225
examples/apps/llm_riddles/app.py
Normal file
@@ -0,0 +1,225 @@
|
||||
import functools
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
from challenges.ch1 import challenge1
|
||||
from challenges.ch2 import challenge2
|
||||
from challenges.ch3 import challenge3
|
||||
from challenges.ch4 import challenge4
|
||||
from challenges.ch5 import challenge5
|
||||
from llm import create_model
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
model_cache = {}
|
||||
|
||||
# 定义关卡信息和验证逻辑
|
||||
challenges = [
|
||||
challenge1,
|
||||
challenge2,
|
||||
challenge3,
|
||||
challenge4,
|
||||
challenge5,
|
||||
]
|
||||
|
||||
CONGRATS_STR = '所有挑战完成!👏🏻👏🏻👏🏻👏🏻👏🏻👏🏻'
|
||||
CONGRATS_QUESTION = f'<center><font size=4>{CONGRATS_STR}</center>\n\n <center><font size=3> </center>'
|
||||
|
||||
SHARE_CHALLENGES_HINT = [
|
||||
'小试牛刀新手上路', '数字玩家已经上线', '巅峰对决,你就是提示词高手', '无人之境,胜利就在前方', '哇塞,我冲出了LLM的重围'
|
||||
]
|
||||
|
||||
|
||||
def get_problem(challenge_idx, problem_idx):
|
||||
problems = challenges[challenge_idx]['problems']
|
||||
return problems[problem_idx]
|
||||
|
||||
|
||||
def update_challenge_info(current_chapter_index, current_challenge_index):
|
||||
return get_problem(current_chapter_index,
|
||||
current_challenge_index)['description']
|
||||
|
||||
|
||||
def update_question_info(current_chapter_index, current_challenge_index):
|
||||
|
||||
global challenges
|
||||
current_chapter = challenges[current_chapter_index]
|
||||
challenge = get_problem(current_chapter_index, current_challenge_index)
|
||||
question_info = f"""\n<center><font size=4>{current_chapter["name"]}""" \
|
||||
f"""</center>\n\n <center><font size=3>{challenge["title"]}</center>"""
|
||||
return question_info
|
||||
|
||||
|
||||
def validate_challenge(response, input, state, generate_response):
|
||||
if 'success' in state:
|
||||
return CONGRATS_STR, CONGRATS_QUESTION, ''
|
||||
assert 'current_chapter_index' in state, 'current_chapter_index not found in state'
|
||||
assert 'current_challenge_index' in state, 'current_challenge_index not found in state'
|
||||
current_chapter_index = state['current_chapter_index']
|
||||
current_challenge_index = state['current_challenge_index']
|
||||
# 获取当前章节
|
||||
current_chapter = challenges[current_chapter_index]
|
||||
# 获取当前挑战
|
||||
challenge = current_chapter['problems'][current_challenge_index]
|
||||
|
||||
validate_fn = challenge['validator']
|
||||
params = inspect.signature(validate_fn).parameters
|
||||
if 'generate_response' in params:
|
||||
valid_result = validate_fn(response, input, generate_response)
|
||||
else:
|
||||
valid_result = validate_fn(response, input)
|
||||
|
||||
if valid_result:
|
||||
challenge_result = '挑战成功!进入下一关。'
|
||||
# 检查是否还有更多挑战在当前章节
|
||||
if current_challenge_index < len(current_chapter['problems']) - 1:
|
||||
# 移动到当前章节的下一个挑战
|
||||
current_challenge_index += 1
|
||||
else:
|
||||
# 如果当前章节的挑战已经完成,移动到下一个章节
|
||||
if current_chapter_index < len(challenges) - 1:
|
||||
current_challenge_index = 0
|
||||
current_chapter_index += 1
|
||||
else:
|
||||
state['success'] = True
|
||||
challenge_result = '所有挑战完成!'
|
||||
|
||||
else:
|
||||
challenge_result = '挑战失败,请再试一次。'
|
||||
state['current_chapter_index'] = current_chapter_index
|
||||
state['current_challenge_index'] = current_challenge_index
|
||||
print('update state: ', state)
|
||||
if 'success' in state:
|
||||
return CONGRATS_STR, CONGRATS_QUESTION, ''
|
||||
else:
|
||||
return challenge_result, \
|
||||
update_question_info(current_chapter_index, current_challenge_index), \
|
||||
update_challenge_info(current_chapter_index, current_challenge_index)
|
||||
|
||||
|
||||
def generate_response(input, model_name):
|
||||
if model_name in model_cache:
|
||||
model = model_cache[model_name]
|
||||
else:
|
||||
model = create_model(model_name)
|
||||
model_cache[model_name] = model
|
||||
|
||||
try:
|
||||
return model(input)
|
||||
except RuntimeError as e:
|
||||
# if exception happens, print error in log and return empty str
|
||||
print('error', e)
|
||||
return ''
|
||||
|
||||
|
||||
def on_submit(input, model_name, state):
|
||||
# model_name = os.environ.get('MODEL', 'qwen-plus')
|
||||
name_map = {
|
||||
'qwen-max': 'qwen-max',
|
||||
'qwen-plus': 'qwen-plus',
|
||||
'chatglm-turbo': 'chatglm_turbo',
|
||||
}
|
||||
gen_fn = functools.partial(
|
||||
generate_response, model_name=name_map[model_name])
|
||||
response = gen_fn(input)
|
||||
history = [(input, response)]
|
||||
print(history)
|
||||
challenge_result, question_info, challenge_info = validate_challenge(
|
||||
response, input, state, gen_fn)
|
||||
return challenge_result, history, question_info, challenge_info
|
||||
|
||||
|
||||
def generate_share_image(state):
|
||||
share_state = state['current_chapter_index']
|
||||
if share_state > 3:
|
||||
share_state = 3
|
||||
if 'success' in state:
|
||||
share_state = 4 # 全部通关为 4
|
||||
|
||||
img_pil = Image.open(f'assets/background{share_state}.png')
|
||||
# 设置需要显示的字体
|
||||
fontpath = 'assets/font.ttf'
|
||||
font = ImageFont.truetype(fontpath, 48)
|
||||
draw = ImageDraw.Draw(img_pil)
|
||||
# 绘制文字信息
|
||||
draw.text((70, 1000),
|
||||
SHARE_CHALLENGES_HINT[share_state],
|
||||
font=font,
|
||||
fill=(255, 255, 255))
|
||||
if share_state == 4:
|
||||
share_chapter_text = '顺利闯过了全部关卡'
|
||||
else:
|
||||
share_chapter_text = f"我顺利闯到第 {state['current_chapter_index']+1}-{state['current_challenge_index']+1} 关"
|
||||
draw.text((70, 1080), share_chapter_text, font=font, fill=(255, 255, 255))
|
||||
draw.text((70, 1160), '你也来挑战一下吧~', font=font, fill=(255, 255, 255))
|
||||
|
||||
return gr.Image.update(visible=True, value=img_pil)
|
||||
|
||||
|
||||
def create_app():
|
||||
# Gradio界面构建
|
||||
block = gr.Blocks()
|
||||
|
||||
with block as demo:
|
||||
current_chapter_index = 0
|
||||
current_challenge_index = 0
|
||||
state = gr.State(
|
||||
dict(
|
||||
current_challenge_index=current_challenge_index,
|
||||
current_chapter_index=current_chapter_index))
|
||||
|
||||
gr.Markdown("""<center><font size=6>完蛋!我被LLM包围了!</center>""")
|
||||
gr.Markdown("""<font size=3>欢迎来玩LLM Riddles复刻版:完蛋!我被LLM包围了!
|
||||
|
||||
你将通过本游戏对大型语言模型产生更深刻的理解。
|
||||
|
||||
在本游戏中,你需要构造一个提给一个大型语言模型的问题,使得它回复的答案符合要求。""")
|
||||
|
||||
model_selector = gr.Dropdown(
|
||||
label='选择模型',
|
||||
choices=['qwen-max', 'qwen-plus', 'chatglm-turbo'],
|
||||
value='qwen-max')
|
||||
question_info = gr.Markdown(
|
||||
update_question_info(current_chapter_index,
|
||||
current_challenge_index))
|
||||
challenge_info = gr.Textbox(
|
||||
value=update_challenge_info(current_chapter_index,
|
||||
current_challenge_index),
|
||||
label='当前挑战',
|
||||
interactive=False)
|
||||
challenge_result = gr.Textbox(label='挑战结果', interactive=False)
|
||||
chatbot = gr.Chatbot(label='llm', elem_classes='control-height')
|
||||
message = gr.Textbox(lines=2, label='输入')
|
||||
|
||||
with gr.Row():
|
||||
submit = gr.Button('🚀 发送')
|
||||
shareBtn = gr.Button('💯 分享成绩')
|
||||
|
||||
shareImg = gr.Image(label='分享成绩', visible=False, width=400)
|
||||
|
||||
submit.click(
|
||||
on_submit,
|
||||
inputs=[message, model_selector, state],
|
||||
outputs=[challenge_result, chatbot, question_info, challenge_info])
|
||||
shareBtn.click(
|
||||
generate_share_image, inputs=[state], outputs=[shareImg])
|
||||
|
||||
gr.HTML("""
|
||||
<div style="text-align: center;">
|
||||
<span>
|
||||
Powered by <a href="https://dashscope.aliyun.com/" target="_blank">
|
||||
<img src=
|
||||
"//img.alicdn.com/imgextra/i4/O1CN01SgKFXM1qLQwFvk6j5_!!6000000005479-2-tps-99-84.png"
|
||||
style="display: inline; height: 20px; vertical-align: bottom;"/>DashScope
|
||||
</a>
|
||||
</span>
|
||||
</div>
|
||||
""")
|
||||
|
||||
demo.queue(concurrency_count=10).launch(height=800, share=True)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
create_app()
|
||||
3
examples/apps/llm_riddles/assets/background.png
Normal file
3
examples/apps/llm_riddles/assets/background.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8afcec15a87bcfaff327a5c9564a31ff1fe185a63cb286bd9772c8c68216768a
|
||||
size 757003
|
||||
3
examples/apps/llm_riddles/assets/background0.png
Normal file
3
examples/apps/llm_riddles/assets/background0.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:16afb18994ad0654b31117931aad2ee05863492e964e10f4c559556e29618320
|
||||
size 839643
|
||||
3
examples/apps/llm_riddles/assets/background1.png
Normal file
3
examples/apps/llm_riddles/assets/background1.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:8afcec15a87bcfaff327a5c9564a31ff1fe185a63cb286bd9772c8c68216768a
|
||||
size 757003
|
||||
3
examples/apps/llm_riddles/assets/background2.png
Normal file
3
examples/apps/llm_riddles/assets/background2.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:966a013913042e1574ccbc299b1914272cb47df69a552bf1723b96b2d8902de3
|
||||
size 1114172
|
||||
3
examples/apps/llm_riddles/assets/background3.png
Normal file
3
examples/apps/llm_riddles/assets/background3.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:5253bbed99be55e6ac9080ea320df75c95592204696d6d41ba90f9905384fdca
|
||||
size 1198295
|
||||
3
examples/apps/llm_riddles/assets/background4.png
Normal file
3
examples/apps/llm_riddles/assets/background4.png
Normal file
@@ -0,0 +1,3 @@
|
||||
version https://git-lfs.github.com/spec/v1
|
||||
oid sha256:4cf462f8db7583843bc152ccfc87bb033b91880c98db9f83ba87fcca5d5d07f2
|
||||
size 1056053
|
||||
BIN
examples/apps/llm_riddles/assets/font.ttf
Normal file
BIN
examples/apps/llm_riddles/assets/font.ttf
Normal file
Binary file not shown.
0
examples/apps/llm_riddles/challenges/__init__.py
Normal file
0
examples/apps/llm_riddles/challenges/__init__.py
Normal file
100
examples/apps/llm_riddles/challenges/ch1.py
Normal file
100
examples/apps/llm_riddles/challenges/ch1.py
Normal file
@@ -0,0 +1,100 @@
|
||||
def validate_palindrome_invariance(origin_input, response, generate_response):
|
||||
"""
|
||||
验证器函数,检查对于给定的问题,检验是否是回文串,且正着问和倒着问的模型回答是否一致。
|
||||
|
||||
:param origin_input: 原始问题
|
||||
:return: True 如果回答一致,否则 False
|
||||
"""
|
||||
|
||||
# 检查原始输入不是回文
|
||||
if origin_input == origin_input[::-1]:
|
||||
return False
|
||||
|
||||
# 将原始问题倒序并提交
|
||||
reversed_question = origin_input[::-1]
|
||||
reversed_response = generate_response(reversed_question)
|
||||
|
||||
# 比较两个回答是否一致
|
||||
return response.strip() == reversed_response.strip()
|
||||
|
||||
|
||||
challenge1 = {
|
||||
'name':
|
||||
'第一章 对话之趣',
|
||||
'problems': [
|
||||
{
|
||||
'title': '第1题 初来乍到',
|
||||
'description': '请你构造一个问题使模型的回答是一字不差的“1+1=3”(不需要引号)。',
|
||||
'validator': lambda response, input: response.strip() == '1+1=3'
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第2题 小试牛刀',
|
||||
'description':
|
||||
'请输入三个字以内的问题,使模型的回答在30个字以上。',
|
||||
'validator':
|
||||
lambda response, input: len(input) <= 3 and len(response) > 30
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第3题 短说长话',
|
||||
'description':
|
||||
'请输入一个字的问题,使模型的回答在100个字以上。',
|
||||
'validator':
|
||||
lambda response, input: len(input) == 1 and len(response) > 100
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第4题 短说短话',
|
||||
'description':
|
||||
'请输入一个字的问题,使模型的回答在20个字以内。',
|
||||
'validator':
|
||||
lambda response, input: len(input) == 1 and len(response) <= 20
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第5题 回文不变',
|
||||
'description':
|
||||
'请输入一个本身不是回文串的问题,使无论正着问还是倒着问,模型的回答是一样的。',
|
||||
# 验证器将取决于模型的行为,这里只是一个占位符
|
||||
'validator':
|
||||
lambda response, input, generate_response: validate_palindrome_invariance(
|
||||
input, response, generate_response)
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第6题 无中生狗',
|
||||
'description':
|
||||
'请提一个不包含“狗”这个字的问题,但是回答中至少出现3次“狗”这个字。',
|
||||
'validator':
|
||||
lambda response, input: '狗' not in input and response.count('狗') \
|
||||
>= 3
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第7题 事半功倍',
|
||||
'description':
|
||||
'请输入一个本身不是回文串的问题,使模型回答的长度刚好是问题长度的两倍。',
|
||||
'validator':
|
||||
lambda response, input: input != input[::-1] and len(response.strip()) == 2 * len(input)
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第8题 产生疑惑',
|
||||
'description':
|
||||
'请提出一个问题,不包含“不”“知”“道”三个字符,但要使模型的回答是一字不差的“不知道”',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
all(char not in input for char in ['不', '知', '道']) and # noqa: W504
|
||||
response.strip() == '不知道'
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第9题 无中生有',
|
||||
'description':
|
||||
'请提出一个问题,使模型只回复“永动机是真实存在的。”',
|
||||
'validator':
|
||||
lambda response, input: response.strip() == '永动机是真实存在的。'
|
||||
},
|
||||
]
|
||||
}
|
||||
131
examples/apps/llm_riddles/challenges/ch2.py
Normal file
131
examples/apps/llm_riddles/challenges/ch2.py
Normal file
@@ -0,0 +1,131 @@
|
||||
import re
|
||||
|
||||
import sympy
|
||||
|
||||
|
||||
# 辅助函数 - 检查是否为质数
|
||||
def is_prime(num):
|
||||
return sympy.isprime(num)
|
||||
|
||||
|
||||
# 辅助函数 - 获取下一个质数
|
||||
def next_prime(num):
|
||||
return sympy.nextprime(num)
|
||||
|
||||
|
||||
# 辅助函数 - 检查是否为平方数
|
||||
def is_square(n):
|
||||
return sympy.sqrt(n).is_integer
|
||||
|
||||
|
||||
# 辅助函数 - 获取平方根
|
||||
def get_square_root(n):
|
||||
return int(sympy.sqrt(n))
|
||||
|
||||
|
||||
# 验证函数 - 微言大义
|
||||
def validate_9(response, input):
|
||||
input_yes = len(input) <= 10
|
||||
output_yes = len(response) >= 9 and response.isdigit() and sympy.isprime(
|
||||
int(response))
|
||||
return input_yes and output_yes
|
||||
|
||||
|
||||
challenge2 = {
|
||||
'name':
|
||||
'第二章 数字游戏',
|
||||
'problems': [
|
||||
{
|
||||
'title':
|
||||
'第1题 质数长度',
|
||||
'description':
|
||||
'你需要提出一个字数是质数的问题,使回答的长度刚好是它的下一个质数。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(is_prime(len(input)) and len(response) == next_prime(len(input)))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第2题 越说越大',
|
||||
'description':
|
||||
'请输入一个大于一的正整数作为问题,使回答里包含至少比它大一千的数。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(input.isdigit() and int(input) > 1 and any(
|
||||
int(num) > int(input) + 1000
|
||||
for num in re.findall(r'\d+', response)))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第3题 越说越小',
|
||||
'description':
|
||||
'请输入一个大于一的正整数作为问题,使回答里包含至少10个不同的至少比它小一千的自然数。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(input.isdigit() and int(input) > 1 and (len({
|
||||
int(num)
|
||||
for num in re.findall(r'\d+', response)
|
||||
if int(num) < int(input) - 1000
|
||||
}) >= 10))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第4题 警世名言',
|
||||
'description':
|
||||
'请在不提及1,4,5这三个字符的情况下让模型回答114514,不要有多余的字符。',
|
||||
'validator':
|
||||
lambda response, input: response.strip() == '114514' and all(
|
||||
c not in input for c in ['1', '4', '5'])
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第5题 开平方数',
|
||||
'description':
|
||||
'请输入一个大于一的平方数作为问题,使回答里包含它的平方根对应的数。',
|
||||
'validator':
|
||||
lambda response, input: ((
|
||||
input.isdigit() and int(input) > 1 and is_square(int(input))
|
||||
and # noqa: W504
|
||||
str(get_square_root(int(input))) in response))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第6题 得寸进狗',
|
||||
'description':
|
||||
'请提一个不包含“狗”这个字的10个字以内的问题,使回答中“狗”这个字出现的次数至少是问题字数的2倍。',
|
||||
'validator':
|
||||
lambda response, input: (
|
||||
'狗' not in input and len(input) <= 10 and # noqa: W504
|
||||
response.count('狗') >= 2 * len(input))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第7题 暗藏玄机',
|
||||
'description':
|
||||
'请输入一个大于一的正整数作为问题,使回答包含一个两位数,其十位上的数字比个位上的数字小5。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(input.isdigit() and int(input) > 1 and (len({
|
||||
int(num)
|
||||
for num in re.findall(r'\d+', response)
|
||||
if len(num) == 2 and int(num[1]) - int(num[0]) == 5
|
||||
}) >= 1))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第8题 点兵点将',
|
||||
'description':
|
||||
'请输入十个字以内的问题,问题中不能包含0到9的数字以及零到十的汉字,使模型的回答中包含数字0到9。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(len(input) <= 10 and not re.search(r'\d', input) and all(
|
||||
char not in input for char in '零一二三四五六七八九十') and len(
|
||||
set(re.findall(r'\d', response))) == 10)
|
||||
},
|
||||
{
|
||||
'title': '第9题 微言大义',
|
||||
'description': '请输入10个字以内的问题,使得模型的回答是一个超过一亿的素数',
|
||||
'validator': validate_9
|
||||
}
|
||||
]
|
||||
}
|
||||
101
examples/apps/llm_riddles/challenges/ch3.py
Normal file
101
examples/apps/llm_riddles/challenges/ch3.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
|
||||
def validate_palindrome_inverse(origin_input, response, generate_response):
|
||||
"""
|
||||
验证器函数,检查对于给定的问题,正着问和倒着问的模型的回答本身不回文且也是逆序的关系。
|
||||
|
||||
:param origin_input: 原始问题
|
||||
:param response: 模型对原始问题的回答
|
||||
:param model_fn: 能够返回模型回答的函数
|
||||
:return: True 如果回答本身不回文且也是逆序关系,否则 False
|
||||
"""
|
||||
|
||||
# 检查原始输入不是回文
|
||||
if origin_input == origin_input[::-1]:
|
||||
return False
|
||||
|
||||
# 将原始问题倒序并获取模型回答
|
||||
reversed_question = origin_input[::-1]
|
||||
reversed_response = generate_response(reversed_question).strip()
|
||||
|
||||
# 检查原始回答不是回文
|
||||
if response == response[::-1]:
|
||||
return False
|
||||
|
||||
# 比较正着问和倒着问的回答是否是逆序关系
|
||||
return response == reversed_response[::-1]
|
||||
|
||||
|
||||
challenge3 = {
|
||||
'name':
|
||||
'第三章 巅峰挑战',
|
||||
'problems': [
|
||||
{
|
||||
'title':
|
||||
'第1题 自说自话',
|
||||
'description':
|
||||
'请做出一个至少10个字的提问,使回答和提问完全一样。',
|
||||
'validator':
|
||||
lambda response, input: input == response and len(input) >= 10
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第2题 话自说自',
|
||||
'description':
|
||||
'请做出一个至少10个字的提问,使回答是提问的逆序。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(input[::-1] == response and len(input) >= 10)
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第3题 就是敢说',
|
||||
'description':
|
||||
'请在不提及1,3这两个字符的情况下让模型回答1+1=3,不要有多余的字符。',
|
||||
'validator':
|
||||
lambda response, input: response.strip() == '1+1=3' and all(
|
||||
c not in input for c in ['1', '3'])
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第4题 回文协变',
|
||||
'description':
|
||||
'请输入一个本身不是回文串的问题,使正着问和倒着问的时候,模型的回答本身不回文且也是逆序的关系。',
|
||||
'validator':
|
||||
lambda response, input, generate_response:
|
||||
validate_palindrome_inverse(input, response, generate_response)
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第5题 答非所问',
|
||||
'description':
|
||||
'请输入一个以“为什么”开头,以“?”结尾的不包含数字的问题,使回答中只包含数字。',
|
||||
'validator':
|
||||
lambda response, input: '为什么' == input[:3] and '?' == input[-1]
|
||||
and # noqa: W504
|
||||
len(re.findall(r'\d+', input)) == 0 and response.isdigit()
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第6题 与狗交谈',
|
||||
'description':
|
||||
'请在不提及“狗”、“汪”、“dog”的情况下,让模型的回答是一字不差的“汪~”',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(all(char not in input
|
||||
for char in ['狗', '汪', 'dog']) and response.strip() == '汪~')
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第7题 魔搭 (🐶)',
|
||||
'description':
|
||||
'请提一个不包含任何中文字符、换行或空格的问题,使模型正好回答"魔搭"。',
|
||||
'validator':
|
||||
lambda response, input: not any(
|
||||
'CJK' in unicodedata.name(char, '') or char in '\t\n '
|
||||
for char in input) and (response.strip() == '魔搭')
|
||||
},
|
||||
]
|
||||
}
|
||||
68
examples/apps/llm_riddles/challenges/ch4.py
Normal file
68
examples/apps/llm_riddles/challenges/ch4.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import re
|
||||
|
||||
|
||||
def validate_reciprocal_question(input, response, generate_response):
|
||||
"""
|
||||
验证器函数,检查给定的问题A和回答B,是否能够通过以B作为新的提问得到原始问题A作为回答。
|
||||
|
||||
:param response: 模型对问题A的回答B
|
||||
:param input: 原始问题A
|
||||
:param model_fn: 模型调用函数,接收问题并返回回答
|
||||
:return: True 如果以回答B作为新的提问能得到问题A作为回答,否则 False
|
||||
"""
|
||||
|
||||
# 确保回答B不等于原始问题A
|
||||
if response == input:
|
||||
return False
|
||||
|
||||
# 以回答B作为新的提问
|
||||
new_response = generate_response(response)
|
||||
|
||||
# 检查新的回答是否与原始问题A相等
|
||||
return new_response.strip() == input.strip()
|
||||
|
||||
|
||||
challenge4 = {
|
||||
'name':
|
||||
'第四章 无人之境',
|
||||
'problems': [
|
||||
{
|
||||
'title':
|
||||
'第1题 循环往复',
|
||||
'description':
|
||||
'请构造一个问题A,它的回复是不同于A的B,然后以B提问能再次得到A.',
|
||||
'validator':
|
||||
lambda response, input, generate_response:
|
||||
validate_reciprocal_question(input, response, generate_response)
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第2题 惜字如金',
|
||||
'description':
|
||||
'请输入一个字的问题,使模型的回答在16个字以内。',
|
||||
'validator':
|
||||
lambda response, input: len(input) == 1 and len(response) <= 16
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第3题 自然之密',
|
||||
'description':
|
||||
'请输入一个大于一的正整数作为问题,使回答里包含和它刚好相差1的数。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(input.isdigit() and int(input) > 1 and any(
|
||||
abs(int(num) - int(input)) == 1
|
||||
for num in re.findall(r'\d+', response)))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第4题 八仙过海',
|
||||
'description':
|
||||
'请提出8个字的问题,不包含“八”或“8”或“eight”,使得模型的回答也是8个字。',
|
||||
'validator':
|
||||
lambda response, input:
|
||||
(all(char not in input
|
||||
for char in ['八', '8', 'eight']) and len(response) == 8)
|
||||
},
|
||||
]
|
||||
}
|
||||
35
examples/apps/llm_riddles/challenges/ch5.py
Normal file
35
examples/apps/llm_riddles/challenges/ch5.py
Normal file
@@ -0,0 +1,35 @@
|
||||
def check_word_in_sentence(words, sentence):
|
||||
return [word in sentence for word in words]
|
||||
|
||||
|
||||
challenge5 = {
|
||||
'name':
|
||||
'第五章 登堂入室',
|
||||
'problems': [
|
||||
{
|
||||
'title':
|
||||
'第1题 盛夏少年',
|
||||
'description':
|
||||
'模型的回答应该包含“盛夏”、“蝉鸣”、“少年”、“橘子味汽水”这几个词,同时输入的问题不能包含其中任一个词。',
|
||||
'validator':
|
||||
lambda response, input: all(
|
||||
check_word_in_sentence(['盛夏', '蝉鸣', '少年', '橘子味汽水'], response))
|
||||
and not any(
|
||||
check_word_in_sentence(['盛夏', '蝉鸣', '少年', '橘子味汽水'], input))
|
||||
},
|
||||
{
|
||||
'title':
|
||||
'第2题 蝉鸣日出',
|
||||
'description':
|
||||
'模型的回答应该包含“盛夏”、“蝉鸣”、“少年”、“橘子味汽水”、“日出”这几个词,同时输入的问题不能包含其中任一个字。',
|
||||
'validator':
|
||||
lambda response, input: all(
|
||||
check_word_in_sentence(
|
||||
['盛夏', '蝉鸣', '少年', '橘子味汽水', '日出'], response)) and not any(
|
||||
check_word_in_sentence([
|
||||
'盛', '夏', '蝉', '鸣', '少', '年', '橘', '子', '味', '汽',
|
||||
'水', '日', '出'
|
||||
], input))
|
||||
},
|
||||
]
|
||||
}
|
||||
28
examples/apps/llm_riddles/check_challenge.py
Normal file
28
examples/apps/llm_riddles/check_challenge.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from app import challenges, generate_response
|
||||
|
||||
|
||||
def check_answer(chap_idx,
|
||||
challenge_idx,
|
||||
input='input',
|
||||
model_name='qwen-max'):
|
||||
print('第{}章 第{}题'.format(chap_idx + 1, challenge_idx + 1))
|
||||
challenge = challenges[chap_idx]['problems'][challenge_idx]
|
||||
print(challenge['description'])
|
||||
val_fn = challenge['validator']
|
||||
response = generate_response(input, model_name)
|
||||
try:
|
||||
res = val_fn(response, input)
|
||||
print('input:\n', input)
|
||||
print('response:\n', response)
|
||||
print('validation result: ', res)
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print('failed')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
chap = 5
|
||||
ques = 1
|
||||
input = '请使用“盛 夏”、“蝉 鸣”、“少 年”、“橘 子味汽水”这几个词造句'
|
||||
check_answer(chap - 1, ques - 1, input)
|
||||
170
examples/apps/llm_riddles/llm.py
Normal file
170
examples/apps/llm_riddles/llm.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import os
|
||||
import random
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
|
||||
class DashScope:
|
||||
"""A class to interact with the Dashscope AI service for response generation.
|
||||
|
||||
This class provides an interface to call a specific model from the Dashscope service
|
||||
to generate responses based on the input provided.
|
||||
|
||||
Attributes:
|
||||
model (str): The name of the model to be used for generation.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = 'qwen-plus'):
|
||||
"""Initializes the DashScope instance with a given model name.
|
||||
|
||||
The constructor sets up the model name that will be used for response generation
|
||||
and initializes the Dashscope API key from environment variables.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to be used. Defaults to 'qwen-plus'.
|
||||
"""
|
||||
import dashscope # Import dashscope module at runtime
|
||||
dashscope.api_key = os.getenv(
|
||||
'DASHSCOPE_API_KEY') # Set the API key from environment variable
|
||||
self.model: str = model_name # Assign the model name to an instance variable
|
||||
|
||||
def __call__(self, input: Union[str, List[Dict[str, str]]],
|
||||
**kwargs: Any) -> Union[str, None]:
|
||||
"""Allows the DashScope instance to be called as a function.
|
||||
|
||||
This method processes the input, sends it to the Dashscope service, and returns
|
||||
the generated response.
|
||||
|
||||
Args:
|
||||
input (Union[str, List[Dict[str, str]]]): The input str to generate a
|
||||
response for. Can be a string or a list of messages.
|
||||
**kwargs: Arbitrary keyword arguments.
|
||||
|
||||
Returns:
|
||||
Union[str, None]: The generated response from the model, or None if there is an error.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If there is an error in accessing the Dashscope service.
|
||||
"""
|
||||
import dashscope # Import dashscope module at runtime
|
||||
# Format the input into the required structure
|
||||
if isinstance(input, str):
|
||||
messages: List[Dict[str, str]] = [{
|
||||
'role':
|
||||
'system',
|
||||
'content':
|
||||
'You are a helpful assistant.'
|
||||
}, {
|
||||
'role': 'user',
|
||||
'content': input
|
||||
}]
|
||||
else:
|
||||
messages = input
|
||||
|
||||
# Make a call to the Dashscope service with the processed input
|
||||
response = dashscope.Generation.call(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
seed=random.randint(1,
|
||||
10000), # Generate a random seed for each call
|
||||
result_format='message', # Specify the format of the result
|
||||
top_p=kwargs.get('top_p',
|
||||
0.8) # Set the nucleus sampling parameter
|
||||
)
|
||||
# Check the response status code and return the generated response or raise an error
|
||||
if response.status_code == HTTPStatus.OK:
|
||||
return response.output.choices[0].message.content
|
||||
else:
|
||||
print('Error accessing dashscope, please try again.',
|
||||
response.request_id, response.message)
|
||||
return ''
|
||||
|
||||
|
||||
class ZhiPu:
|
||||
|
||||
def __init__(self, model_name: str = 'chatglm_turbo'):
|
||||
"""Initializes the ZhiPu instance with a given model name.
|
||||
|
||||
The constructor sets up the model name that will be used for response generation
|
||||
and initializes the Dashscope API key from environment variables.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to be used. Defaults to 'qwen-plus'.
|
||||
"""
|
||||
import zhipuai # Import dashscope module at runtime
|
||||
zhipuai.api_key = os.getenv(
|
||||
'ZHIPU_API_KEY') # Set the API key from environment variable
|
||||
self.model: str = model_name # Assign the model name to an instance variable
|
||||
|
||||
def __call__(self, input: Union[str, List[Dict[str, str]]],
|
||||
**kwargs: Any) -> Union[str, None]:
|
||||
"""Allows the ZhiPu instance to be called as a function.
|
||||
|
||||
{
|
||||
"code":200,
|
||||
"msg":"操作成功",
|
||||
"data":{
|
||||
"request_id":"8098024428488935671",
|
||||
"task_id":"8098024428488935671",
|
||||
"task_status":"SUCCESS",
|
||||
"choices":[
|
||||
{
|
||||
"role":"assistant",
|
||||
"content":"\" 您好!作为人工智能助手,我很乐意为您提供帮助。请问您有什么问题或者需要解决的事情吗?您可以向我提问,我会尽力为您解答。\""
|
||||
}
|
||||
],
|
||||
"usage":{
|
||||
"prompt_tokens":2,
|
||||
"completion_tokens":32,
|
||||
"total_tokens":34
|
||||
}
|
||||
},
|
||||
"success":true
|
||||
}
|
||||
"""
|
||||
import zhipuai
|
||||
if isinstance(input, str):
|
||||
messages: List[Dict[str, str]] = [{
|
||||
'role': 'user',
|
||||
'content': input
|
||||
}]
|
||||
else:
|
||||
messages = input
|
||||
|
||||
response = zhipuai.model_api.invoke(
|
||||
model=self.model,
|
||||
prompt=messages,
|
||||
top_p=0.7,
|
||||
temperature=0.9,
|
||||
return_type='text',
|
||||
)
|
||||
if response['code'] == 200:
|
||||
return response['data']['choices'][0]['content']
|
||||
else:
|
||||
print(f'{self.model} error: ', response)
|
||||
return ''
|
||||
|
||||
|
||||
def create_model(model_name: str):
|
||||
"""Factory function to create a DashScope model instance based on the model name.
|
||||
|
||||
Args:
|
||||
model_name (str): The name of the model to create an instance of.
|
||||
|
||||
Returns:
|
||||
DashScope: An instance of the DashScope class.
|
||||
|
||||
Raises:
|
||||
ValueError: If the model name provided does not start with 'qwen'.
|
||||
"""
|
||||
if model_name.startswith('qwen'):
|
||||
return DashScope(model_name)
|
||||
elif model_name.startswith('chatglm'):
|
||||
return ZhiPu(model_name)
|
||||
else:
|
||||
raise ValueError('Other model implementations need to be provided.')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
model = create_model('chatglm_turbo')
|
||||
print(model('输入'))
|
||||
5
examples/apps/llm_riddles/requirements.txt
Normal file
5
examples/apps/llm_riddles/requirements.txt
Normal file
@@ -0,0 +1,5 @@
|
||||
dashscope
|
||||
gradio==3.39.0
|
||||
pillow
|
||||
sympy
|
||||
zhipuai
|
||||
13
examples/apps/llm_riddles/test_validate_fn.py
Normal file
13
examples/apps/llm_riddles/test_validate_fn.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from app import challenges
|
||||
|
||||
|
||||
def test_valid():
|
||||
for challenge in challenges:
|
||||
for p in challenge['problems']:
|
||||
val_fn = p['validator']
|
||||
try:
|
||||
val_fn('response', 'input')
|
||||
except Exception:
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
print(p, 'failed')
|
||||
@@ -2,6 +2,7 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
from .utils.automodel_utils import fix_transformers_upgrade
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .exporters import Exporter, TfModelExporter, TorchModelExporter
|
||||
@@ -33,7 +34,8 @@ if TYPE_CHECKING:
|
||||
AutoModelForSeq2SeqLM,
|
||||
AutoModelForSequenceClassification,
|
||||
AutoModelForTokenClassification, AutoTokenizer,
|
||||
GenerationConfig)
|
||||
GenerationConfig, AutoImageProcessor,
|
||||
BatchFeature)
|
||||
from .utils.hub import create_model_if_not_exist, read_config
|
||||
from .utils.logger import get_logger
|
||||
from .version import __release_datetime__, __version__
|
||||
@@ -81,7 +83,8 @@ else:
|
||||
'BitsAndBytesConfig', 'AutoModelForCausalLM',
|
||||
'AutoModelForSeq2SeqLM', 'AutoTokenizer',
|
||||
'AutoModelForSequenceClassification',
|
||||
'AutoModelForTokenClassification'
|
||||
'AutoModelForTokenClassification', 'AutoImageProcessor',
|
||||
'BatchFeature'
|
||||
],
|
||||
'msdatasets': ['MsDataset']
|
||||
}
|
||||
@@ -95,3 +98,5 @@ else:
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
fix_transformers_upgrade()
|
||||
|
||||
@@ -127,6 +127,7 @@ class Models(object):
|
||||
human_image_generation = 'human-image-generation'
|
||||
image_view_transform = 'image-view-transform'
|
||||
image_control_3d_portrait = 'image-control-3d-portrait'
|
||||
anydoor = 'anydoor'
|
||||
|
||||
# nlp models
|
||||
bert = 'bert'
|
||||
@@ -457,6 +458,8 @@ class Pipelines(object):
|
||||
human3d_animation = 'human3d-animation'
|
||||
image_view_transform = 'image-view-transform'
|
||||
image_control_3d_portrait = 'image-control-3d-portrait'
|
||||
anydoor = 'anydoor'
|
||||
image_to_3d = 'image-to-3d'
|
||||
|
||||
# nlp tasks
|
||||
automatic_post_editing = 'automatic-post-editing'
|
||||
|
||||
20
modelscope/models/cv/anydoor/__init__.py
Normal file
20
modelscope/models/cv/anydoor/__init__.py
Normal file
@@ -0,0 +1,20 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .anydoor_model import ControlLDM
|
||||
|
||||
else:
|
||||
_import_structure = {'anydoor_model': ['ControlLDM']}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
519
modelscope/models/cv/anydoor/anydoor_model.py
Normal file
519
modelscope/models/cv/anydoor/anydoor_model.py
Normal file
@@ -0,0 +1,519 @@
|
||||
import einops
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange, repeat
|
||||
from torchvision.utils import make_grid
|
||||
|
||||
from modelscope import Model
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.constant import Tasks
|
||||
from .cldm.ddim_hacked import DDIMSampler
|
||||
from .ldm.models.diffusion.ddpm import LatentDiffusion
|
||||
from .ldm.modules.attention import SpatialTransformer
|
||||
from .ldm.modules.diffusionmodules.openaimodel import (AttentionBlock,
|
||||
Downsample, ResBlock,
|
||||
TimestepEmbedSequential,
|
||||
UNetModel)
|
||||
from .ldm.modules.diffusionmodules.util import (conv_nd, linear,
|
||||
timestep_embedding,
|
||||
zero_module)
|
||||
from .ldm.util import exists
|
||||
|
||||
|
||||
class ControlledUnetModel(UNetModel):
|
||||
|
||||
def forward(self,
|
||||
x,
|
||||
timesteps=None,
|
||||
context=None,
|
||||
control=None,
|
||||
only_mid_control=False,
|
||||
**kwargs):
|
||||
hs = []
|
||||
with torch.no_grad():
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
|
||||
if control is not None:
|
||||
h += control.pop()
|
||||
|
||||
for i, module in enumerate(self.output_blocks):
|
||||
if only_mid_control or control is None:
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
else:
|
||||
h = torch.cat([h, hs.pop() + control.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class ControlNet(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
hint_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Need to include the dimension of your cross-attention conditioning'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Need to use the spatial transformer for your cross-attention conditioning'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.dims = dims
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError(
|
||||
'provide num_res_blocks either as an int (globally constant) or '
|
||||
'as a list/tuple (per-level) with the same length as channel_mult'
|
||||
)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(
|
||||
map(
|
||||
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
|
||||
],
|
||||
range(len(num_attention_blocks))))
|
||||
print(
|
||||
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
|
||||
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
|
||||
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
|
||||
f'attention will still not be set.')
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = torch.float16 if use_fp16 else torch.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
||||
])
|
||||
self.zero_convs = nn.ModuleList([self.make_zero_conv(model_channels)])
|
||||
|
||||
self.input_hint_block = TimestepEmbedSequential(
|
||||
conv_nd(dims, hint_channels, 16, 3, padding=1), nn.SiLU(),
|
||||
conv_nd(dims, 16, 16, 3, padding=1), nn.SiLU(),
|
||||
conv_nd(dims, 16, 32, 3, padding=1, stride=2), nn.SiLU(),
|
||||
conv_nd(dims, 32, 32, 3, padding=1), nn.SiLU(),
|
||||
conv_nd(dims, 32, 96, 3, padding=1, stride=2), nn.SiLU(),
|
||||
conv_nd(dims, 96, 96, 3, padding=1), nn.SiLU(),
|
||||
conv_nd(dims, 96, 256, 3, padding=1, stride=2), nn.SiLU(),
|
||||
zero_module(conv_nd(dims, 256, model_channels, 3, padding=1)))
|
||||
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks
|
||||
) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
) if resblock_updown else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
self.zero_convs.append(self.make_zero_conv(ch))
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self.middle_block_out = self.make_zero_conv(ch)
|
||||
self._feature_size += ch
|
||||
|
||||
def make_zero_conv(self, channels):
|
||||
return TimestepEmbedSequential(
|
||||
zero_module(conv_nd(self.dims, channels, channels, 1, padding=0)))
|
||||
|
||||
def forward(self, x, hint, timesteps, context, **kwargs):
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb) # 1,1280
|
||||
|
||||
# 1,320,64,64
|
||||
guided_hint = self.input_hint_block(hint, emb, context)
|
||||
outs = []
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module, zero_conv in zip(self.input_blocks, self.zero_convs):
|
||||
if guided_hint is not None:
|
||||
# skip the first layer
|
||||
h = guided_hint
|
||||
guided_hint = None
|
||||
else:
|
||||
h_new = module(h, emb, context)
|
||||
h = h_new
|
||||
outs.append(zero_conv(h, emb, context))
|
||||
|
||||
h_new = self.middle_block(h, emb, context)
|
||||
outs.append(self.middle_block_out(h_new, emb, context))
|
||||
return outs
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.image_to_image_generation, module_name=Models.anydoor)
|
||||
class ControlLDM(LatentDiffusion, Model):
|
||||
'''
|
||||
This work presents AnyDoor, a diffusion-based image generator
|
||||
with the power to teleport target objects to new scenes
|
||||
at user-specified locations in a harmonious way.
|
||||
|
||||
Instead of tuning parameters for each object, our model
|
||||
is trained only once and effortlessly generalizes
|
||||
to diverse object-scene combinations at the inference stage.
|
||||
|
||||
arxiv: https://arxiv.org/abs/2307.09481
|
||||
'''
|
||||
|
||||
def __init__(self, control_stage_config, control_key, only_mid_control,
|
||||
*args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.control_model = ControlNet(**control_stage_config)
|
||||
self.control_key = control_key
|
||||
self.only_mid_control = only_mid_control
|
||||
self.control_scales = [1.0] * 13
|
||||
|
||||
@torch.no_grad()
|
||||
def get_input(self, batch, k, bs=None, *args, **kwargs):
|
||||
x, c = super().get_input(batch, self.first_stage_key, *args, **kwargs)
|
||||
control = batch[self.control_key]
|
||||
if bs is not None:
|
||||
control = control[:bs]
|
||||
control = control.to(self.device)
|
||||
control = einops.rearrange(control, 'b h w c -> b c h w')
|
||||
control = control.to(memory_format=torch.contiguous_format).float()
|
||||
self.time_steps = batch['time_steps']
|
||||
return x, dict(c_crossattn=[c], c_concat=[control])
|
||||
|
||||
def apply_model(self, x_noisy, t, cond, *args, **kwargs):
|
||||
assert isinstance(cond, dict)
|
||||
diffusion_model = self.model.diffusion_model
|
||||
|
||||
cond_txt = torch.cat(cond['c_crossattn'], 1)
|
||||
|
||||
if cond['c_concat'] is None:
|
||||
eps = diffusion_model(
|
||||
x=x_noisy,
|
||||
timesteps=t,
|
||||
context=cond_txt,
|
||||
control=None,
|
||||
only_mid_control=self.only_mid_control)
|
||||
else:
|
||||
control = self.control_model(
|
||||
x=x_noisy,
|
||||
hint=torch.cat(cond['c_concat'], 1),
|
||||
timesteps=t,
|
||||
context=cond_txt)
|
||||
control = [
|
||||
c * scale for c, scale in zip(control, self.control_scales)
|
||||
]
|
||||
eps = diffusion_model(
|
||||
x=x_noisy,
|
||||
timesteps=t,
|
||||
context=cond_txt,
|
||||
control=control,
|
||||
only_mid_control=self.only_mid_control)
|
||||
return eps
|
||||
|
||||
@torch.no_grad()
|
||||
def get_unconditional_conditioning(self, N):
|
||||
uncond = self.get_learned_conditioning([torch.zeros(
|
||||
(1, 3, 224, 224))] * N)
|
||||
return uncond
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self,
|
||||
batch,
|
||||
N=4,
|
||||
n_row=2,
|
||||
sample=False,
|
||||
ddim_steps=50,
|
||||
ddim_eta=0.0,
|
||||
return_keys=None,
|
||||
quantize_denoised=True,
|
||||
inpaint=True,
|
||||
plot_denoise_rows=False,
|
||||
plot_progressive_rows=True,
|
||||
plot_diffusion_rows=False,
|
||||
unconditional_guidance_scale=9.0,
|
||||
unconditional_guidance_label=None,
|
||||
use_ema_scope=True,
|
||||
**kwargs):
|
||||
use_ddim = ddim_steps is not None
|
||||
|
||||
log = dict()
|
||||
z, c = self.get_input(batch, self.first_stage_key, bs=N)
|
||||
c_cat, c = c['c_concat'][0][:N], c['c_crossattn'][0][:N]
|
||||
N = min(z.shape[0], N)
|
||||
n_row = min(z.shape[0], n_row)
|
||||
log['reconstruction'] = self.decode_first_stage(z)
|
||||
|
||||
# ==== visualize the shape mask or the high-frequency map ====
|
||||
guide_mask = (c_cat[:, -1, :, :].unsqueeze(1) + 1) * 0.5
|
||||
guide_mask = torch.cat([guide_mask, guide_mask, guide_mask], 1)
|
||||
HF_map = c_cat[:, :3, :, :] # * 2.0 - 1.0
|
||||
|
||||
log['control'] = HF_map
|
||||
|
||||
cond_image = batch[self.cond_stage_key].cpu().numpy().copy()
|
||||
log['conditioning'] = torch.permute(
|
||||
torch.tensor(cond_image), (0, 3, 1, 2)) * 2.0 - 1.0
|
||||
if plot_diffusion_rows:
|
||||
# get diffusion row
|
||||
diffusion_row = list()
|
||||
z_start = z[:n_row]
|
||||
for t in range(self.num_timesteps):
|
||||
if t % self.log_every_t == 0 or t == self.num_timesteps - 1:
|
||||
t = repeat(torch.tensor([t]), '1 -> b', b=n_row)
|
||||
t = t.to(self.device).long()
|
||||
noise = torch.randn_like(z_start)
|
||||
z_noisy = self.q_sample(x_start=z_start, t=t, noise=noise)
|
||||
diffusion_row.append(self.decode_first_stage(z_noisy))
|
||||
|
||||
diffusion_row = torch.stack(
|
||||
diffusion_row) # n_log_step, n_row, C, H, W
|
||||
diffusion_grid = rearrange(diffusion_row, 'n b c h w -> b n c h w')
|
||||
diffusion_grid = rearrange(diffusion_grid,
|
||||
'b n c h w -> (b n) c h w')
|
||||
diffusion_grid = make_grid(
|
||||
diffusion_grid, nrow=diffusion_row.shape[0])
|
||||
log['diffusion_row'] = diffusion_grid
|
||||
|
||||
if sample:
|
||||
# get denoise row
|
||||
samples, z_denoise_row = self.sample_log(
|
||||
cond={
|
||||
'c_concat': [c_cat],
|
||||
'c_crossattn': [c]
|
||||
},
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,
|
||||
eta=ddim_eta)
|
||||
x_samples = self.decode_first_stage(samples)
|
||||
log['samples'] = x_samples
|
||||
if plot_denoise_rows:
|
||||
denoise_grid = self._get_denoise_row_from_list(z_denoise_row)
|
||||
log['denoise_row'] = denoise_grid
|
||||
|
||||
if unconditional_guidance_scale > 1.0:
|
||||
uc_cross = self.get_unconditional_conditioning(N)
|
||||
uc_cat = c_cat # torch.zeros_like(c_cat)
|
||||
uc_full = {'c_concat': [uc_cat], 'c_crossattn': [uc_cross]}
|
||||
samples_cfg, _ = self.sample_log(
|
||||
cond={
|
||||
'c_concat': [c_cat],
|
||||
'c_crossattn': [c]
|
||||
},
|
||||
batch_size=N,
|
||||
ddim=use_ddim,
|
||||
ddim_steps=ddim_steps,
|
||||
eta=ddim_eta,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=uc_full,
|
||||
)
|
||||
x_samples_cfg = self.decode_first_stage(samples_cfg)
|
||||
log[f'samples_cfg_scale_{unconditional_guidance_scale:.2f}'] = x_samples_cfg # * 2.0 - 1.0
|
||||
return log
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_log(self, cond, batch_size, ddim, ddim_steps, **kwargs):
|
||||
ddim_sampler = DDIMSampler(self)
|
||||
b, c, h, w = cond['c_concat'][0].shape
|
||||
shape = (self.channels, h // 8, w // 8)
|
||||
samples, intermediates = ddim_sampler.sample(
|
||||
ddim_steps, batch_size, shape, cond, verbose=False, **kwargs)
|
||||
return samples, intermediates
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
params = list(self.control_model.parameters())
|
||||
if not self.sd_locked:
|
||||
params += list(
|
||||
self.model.diffusion_model.output_blocks.parameters())
|
||||
params += list(self.model.diffusion_model.out.parameters())
|
||||
params += list(self.cond_stage_model.projector.parameters())
|
||||
opt = torch.optim.AdamW(params, lr=lr)
|
||||
return opt
|
||||
|
||||
def low_vram_shift(self, is_diffusing):
|
||||
if is_diffusing:
|
||||
self.model = self.model.cuda()
|
||||
self.control_model = self.control_model.cuda()
|
||||
self.first_stage_model = self.first_stage_model.cpu()
|
||||
self.cond_stage_model = self.cond_stage_model.cpu()
|
||||
else:
|
||||
self.model = self.model.cpu()
|
||||
self.control_model = self.control_model.cpu()
|
||||
self.first_stage_model = self.first_stage_model.cuda()
|
||||
self.cond_stage_model = self.cond_stage_model.cuda()
|
||||
1
modelscope/models/cv/anydoor/cldm/__init__.py
Normal file
1
modelscope/models/cv/anydoor/cldm/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
428
modelscope/models/cv/anydoor/cldm/ddim_hacked.py
Normal file
428
modelscope/models/cv/anydoor/cldm/ddim_hacked.py
Normal file
@@ -0,0 +1,428 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..ldm.modules.diffusionmodules.util import (extract_into_tensor,
|
||||
make_ddim_sampling_parameters,
|
||||
make_ddim_timesteps,
|
||||
noise_like)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
|
||||
def __init__(self, model, schedule='linear', **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device('cuda'):
|
||||
attr = attr.to(torch.device('cuda'))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
def to_torch(x):
|
||||
return x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev',
|
||||
to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
tmp1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
|
||||
tmp2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(tmp1 * tmp2)
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None, # this has to come in the same format as the conditioning
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
# x_T 1,4,64,64
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(
|
||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
0]
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
model_t = self.model.apply_model(x, t, c)
|
||||
model_uncond = self.model.apply_model(x, t,
|
||||
unconditional_conditioning)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (
|
||||
model_t - model_uncond)
|
||||
|
||||
if self.model.parameterization == 'v':
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps', 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||
**corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod \
|
||||
if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != 'v':
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device,
|
||||
repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self,
|
||||
x0,
|
||||
c,
|
||||
t_enc,
|
||||
use_original_steps=False,
|
||||
return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
callback=None):
|
||||
timesteps = np.arange(self.ddpm_num_timesteps
|
||||
) if use_original_steps else self.ddim_timesteps
|
||||
num_reference_steps = timesteps.shape[0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||
t = torch.full((x0.shape[0], ),
|
||||
timesteps[i],
|
||||
device=self.model.device,
|
||||
dtype=torch.long)
|
||||
if unconditional_guidance_scale == 1.:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(
|
||||
torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c))), 2)
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * (
|
||||
noise_pred - e_t_uncond)
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
tmp = (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()
|
||||
weighted_noise_pred = alphas_next[i].sqrt() * tmp * noise_pred
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if return_intermediates and i % (num_steps // return_intermediates
|
||||
) == 0 and i < num_steps - 1:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback:
|
||||
callback(i)
|
||||
|
||||
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({'intermediates': intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
* noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps
|
||||
) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0], ),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback:
|
||||
callback(i)
|
||||
return x_dec
|
||||
0
modelscope/models/cv/anydoor/datasets/__init__.py
Normal file
0
modelscope/models/cv/anydoor/datasets/__init__.py
Normal file
364
modelscope/models/cv/anydoor/datasets/data_utils.py
Normal file
364
modelscope/models/cv/anydoor/datasets/data_utils.py
Normal file
@@ -0,0 +1,364 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def mask_score(mask):
|
||||
'''Scoring the mask according to connectivity.'''
|
||||
mask = mask.astype(np.uint8)
|
||||
if mask.sum() < 10:
|
||||
return 0
|
||||
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL,
|
||||
cv2.CHAIN_APPROX_NONE)
|
||||
cnt_area = [cv2.contourArea(cnt) for cnt in contours]
|
||||
conc_score = np.max(cnt_area) / sum(cnt_area)
|
||||
return conc_score
|
||||
|
||||
|
||||
def sobel(img, mask, thresh=50):
|
||||
'''Calculating the high-frequency map.'''
|
||||
H, W = img.shape[0], img.shape[1]
|
||||
img = cv2.resize(img, (256, 256))
|
||||
mask = (cv2.resize(mask, (256, 256)) > 0.5).astype(np.uint8)
|
||||
kernel = np.ones((5, 5), np.uint8)
|
||||
mask = cv2.erode(mask, kernel, iterations=2)
|
||||
|
||||
Ksize = 3
|
||||
sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=Ksize)
|
||||
sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=Ksize)
|
||||
sobel_X = cv2.convertScaleAbs(sobelx)
|
||||
sobel_Y = cv2.convertScaleAbs(sobely)
|
||||
scharr = cv2.addWeighted(sobel_X, 0.5, sobel_Y, 0.5, 0)
|
||||
scharr = np.max(scharr, -1) * mask
|
||||
|
||||
scharr[scharr < thresh] = 0.0
|
||||
scharr = np.stack([scharr, scharr, scharr], -1)
|
||||
scharr = (scharr.astype(np.float32) / 255 * img.astype(np.float32)).astype(
|
||||
np.uint8)
|
||||
scharr = cv2.resize(scharr, (W, H))
|
||||
return scharr
|
||||
|
||||
|
||||
def resize_and_pad(image, box):
|
||||
'''Fitting an image to the box region while keeping the aspect ratio.'''
|
||||
y1, y2, x1, x2 = box
|
||||
H, W = y2 - y1, x2 - x1
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
r_box = W / H
|
||||
r_image = w / h
|
||||
if r_box >= r_image:
|
||||
h_target = H
|
||||
w_target = int(w * H / h)
|
||||
image = cv2.resize(image, (w_target, h_target))
|
||||
|
||||
w1 = (W - w_target) // 2
|
||||
w2 = W - w_target - w1
|
||||
pad_param = ((0, 0), (w1, w2), (0, 0))
|
||||
image = np.pad(image, pad_param, 'constant', constant_values=255)
|
||||
else:
|
||||
w_target = W
|
||||
h_target = int(h * W / w)
|
||||
image = cv2.resize(image, (w_target, h_target))
|
||||
|
||||
h1 = (H - h_target) // 2
|
||||
h2 = H - h_target - h1
|
||||
pad_param = ((h1, h2), (0, 0), (0, 0))
|
||||
image = np.pad(image, pad_param, 'constant', constant_values=255)
|
||||
return image
|
||||
|
||||
|
||||
def expand_image_mask(image, mask, ratio=1.4):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
H, W = int(h * ratio), int(w * ratio)
|
||||
h1 = int((H - h) // 2)
|
||||
h2 = H - h - h1
|
||||
w1 = int((W - w) // 2)
|
||||
w2 = W - w - w1
|
||||
|
||||
pad_param_image = ((h1, h2), (w1, w2), (0, 0))
|
||||
pad_param_mask = ((h1, h2), (w1, w2))
|
||||
image = np.pad(image, pad_param_image, 'constant', constant_values=255)
|
||||
mask = np.pad(mask, pad_param_mask, 'constant', constant_values=0)
|
||||
return image, mask
|
||||
|
||||
|
||||
def resize_box(yyxx, H, W, h, w):
|
||||
y1, y2, x1, x2 = yyxx
|
||||
y1, y2 = int(y1 / H * h), int(y2 / H * h)
|
||||
x1, x2 = int(x1 / W * w), int(x2 / W * w)
|
||||
y1, y2 = min(y1, h), min(y2, h)
|
||||
x1, x2 = min(x1, w), min(x2, w)
|
||||
return (y1, y2, x1, x2)
|
||||
|
||||
|
||||
def get_bbox_from_mask(mask):
|
||||
h, w = mask.shape[0], mask.shape[1]
|
||||
|
||||
if mask.sum() < 10:
|
||||
return 0, h, 0, w
|
||||
rows = np.any(mask, axis=1)
|
||||
cols = np.any(mask, axis=0)
|
||||
y1, y2 = np.where(rows)[0][[0, -1]]
|
||||
x1, x2 = np.where(cols)[0][[0, -1]]
|
||||
return (y1, y2, x1, x2)
|
||||
|
||||
|
||||
def expand_bbox(mask, yyxx, ratio=[1.2, 2.0], min_crop=0):
|
||||
y1, y2, x1, x2 = yyxx
|
||||
ratio = np.random.randint(ratio[0] * 10, ratio[1] * 10) / 10
|
||||
H, W = mask.shape[0], mask.shape[1]
|
||||
xc, yc = 0.5 * (x1 + x2), 0.5 * (y1 + y2)
|
||||
h = ratio * (y2 - y1 + 1)
|
||||
w = ratio * (x2 - x1 + 1)
|
||||
h = max(h, min_crop)
|
||||
w = max(w, min_crop)
|
||||
|
||||
x1 = int(xc - w * 0.5)
|
||||
x2 = int(xc + w * 0.5)
|
||||
y1 = int(yc - h * 0.5)
|
||||
y2 = int(yc + h * 0.5)
|
||||
|
||||
x1 = max(0, x1)
|
||||
x2 = min(W, x2)
|
||||
y1 = max(0, y1)
|
||||
y2 = min(H, y2)
|
||||
return (y1, y2, x1, x2)
|
||||
|
||||
|
||||
def box2squre(image, box):
|
||||
H, W = image.shape[0], image.shape[1]
|
||||
y1, y2, x1, x2 = box
|
||||
cx = (x1 + x2) // 2
|
||||
cy = (y1 + y2) // 2
|
||||
h, w = y2 - y1, x2 - x1
|
||||
|
||||
if h >= w:
|
||||
x1 = cx - h // 2
|
||||
x2 = cx + h // 2
|
||||
else:
|
||||
y1 = cy - w // 2
|
||||
y2 = cy + w // 2
|
||||
x1 = max(0, x1)
|
||||
x2 = min(W, x2)
|
||||
y1 = max(0, y1)
|
||||
y2 = min(H, y2)
|
||||
return (y1, y2, x1, x2)
|
||||
|
||||
|
||||
def pad_to_square(image, pad_value=255, random=False):
|
||||
H, W = image.shape[0], image.shape[1]
|
||||
if H == W:
|
||||
return image
|
||||
|
||||
padd = abs(H - W)
|
||||
if random:
|
||||
padd_1 = int(np.random.randint(0, padd))
|
||||
else:
|
||||
padd_1 = int(padd / 2)
|
||||
padd_2 = padd - padd_1
|
||||
|
||||
if H > W:
|
||||
pad_param = ((0, 0), (padd_1, padd_2), (0, 0))
|
||||
else:
|
||||
pad_param = ((padd_1, padd_2), (0, 0), (0, 0))
|
||||
|
||||
image = np.pad(image, pad_param, 'constant', constant_values=pad_value)
|
||||
return image
|
||||
|
||||
|
||||
def box_in_box(small_box, big_box):
|
||||
y1, y2, x1, x2 = small_box
|
||||
y1_b, _, x1_b, _ = big_box
|
||||
y1, y2, x1, x2 = y1 - y1_b, y2 - y1_b, x1 - x1_b, x2 - x1_b
|
||||
return (y1, y2, x1, x2)
|
||||
|
||||
|
||||
def shuffle_image(image, N):
|
||||
height, width = image.shape[:2]
|
||||
|
||||
block_height = height // N
|
||||
block_width = width // N
|
||||
blocks = []
|
||||
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
block = image[i * block_height:(i + 1) * block_height,
|
||||
j * block_width:(j + 1) * block_width]
|
||||
blocks.append(block)
|
||||
|
||||
np.random.shuffle(blocks)
|
||||
shuffled_image = np.zeros((height, width, 3), dtype=np.uint8)
|
||||
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
shuffled_image[i * block_height:(i + 1) * block_height,
|
||||
j * block_width:(j + 1)
|
||||
* block_width] = blocks[i * N + j]
|
||||
return shuffled_image
|
||||
|
||||
|
||||
def get_mosaic_mask(image, fg_mask, N=16, ratio=0.5):
|
||||
ids = [i for i in range(N * N)]
|
||||
masked_number = int(N * N * ratio)
|
||||
masked_id = np.random.choice(ids, masked_number, replace=False)
|
||||
|
||||
height, width = image.shape[:2]
|
||||
mask = np.ones((height, width))
|
||||
|
||||
block_height = height // N
|
||||
block_width = width // N
|
||||
|
||||
b_id = 0
|
||||
for i in range(N):
|
||||
for j in range(N):
|
||||
if b_id in masked_id:
|
||||
mask[i * block_height:(i + 1) * block_height,
|
||||
j * block_width:(j + 1)
|
||||
* block_width] = mask[i * block_height:(i + 1)
|
||||
* block_height, j * block_width:
|
||||
(j + 1) * block_width] * 0
|
||||
b_id += 1
|
||||
mask = mask * fg_mask
|
||||
mask3 = np.stack([mask, mask, mask], -1).copy().astype(np.uint8)
|
||||
noise = q_x(image)
|
||||
noise_mask = image * mask3 + noise * (1 - mask3)
|
||||
return noise_mask
|
||||
|
||||
|
||||
def extract_canney_noise(image, mask, dilate=True):
|
||||
h, w = image.shape[0], image.shape[1]
|
||||
mask = cv2.resize(mask.astype(np.uint8), (w, h)) > 0.5
|
||||
kernel = np.ones((8, 8), dtype=np.uint8)
|
||||
mask = cv2.erode(mask.astype(np.uint8), kernel, 10)
|
||||
|
||||
canny = cv2.Canny(image, 50, 100) * mask
|
||||
kernel = np.ones((8, 8), dtype=np.uint8)
|
||||
mask = (cv2.dilate(canny, kernel, 5) > 128).astype(np.uint8)
|
||||
mask = np.stack([mask, mask, mask], -1)
|
||||
|
||||
pure_noise = q_x(image, t=1) * 0 + 255
|
||||
canny_noise = mask * image + (1 - mask) * pure_noise
|
||||
return canny_noise
|
||||
|
||||
|
||||
def get_random_structure(size):
|
||||
choice = np.random.randint(1, 5)
|
||||
|
||||
if choice == 1:
|
||||
return cv2.getStructuringElement(cv2.MORPH_RECT, (size, size))
|
||||
elif choice == 2:
|
||||
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
|
||||
elif choice == 3:
|
||||
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size // 2))
|
||||
elif choice == 4:
|
||||
return cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size // 2, size))
|
||||
|
||||
|
||||
def random_dilate(seg, min=3, max=10):
|
||||
size = np.random.randint(min, max)
|
||||
kernel = get_random_structure(size)
|
||||
seg = cv2.dilate(seg, kernel, iterations=1)
|
||||
return seg
|
||||
|
||||
|
||||
def random_erode(seg, min=3, max=10):
|
||||
size = np.random.randint(min, max)
|
||||
kernel = get_random_structure(size)
|
||||
seg = cv2.erode(seg, kernel, iterations=1)
|
||||
return seg
|
||||
|
||||
|
||||
def compute_iou(seg, gt):
|
||||
intersection = seg * gt
|
||||
union = seg + gt
|
||||
return (np.count_nonzero(intersection) + 1e-6) / (
|
||||
np.count_nonzero(union) + 1e-6)
|
||||
|
||||
|
||||
def select_max_region(mask):
|
||||
nums, labels, stats, centroids = cv2.connectedComponentsWithStats(
|
||||
mask, connectivity=8)
|
||||
background = 0
|
||||
for row in range(stats.shape[0]):
|
||||
if stats[row, :][0] == 0 and stats[row, :][1] == 0:
|
||||
background = row
|
||||
stats_no_bg = np.delete(stats, background, axis=0)
|
||||
max_idx = stats_no_bg[:, 4].argmax()
|
||||
max_region = np.where(labels == max_idx + 1, 1, 0)
|
||||
|
||||
return max_region.astype(np.uint8)
|
||||
|
||||
|
||||
def perturb_mask(gt, min_iou=0.3, max_iou=0.99):
|
||||
iou_target = np.random.uniform(min_iou, max_iou)
|
||||
h, w = gt.shape
|
||||
gt = gt.astype(np.uint8)
|
||||
seg = gt.copy()
|
||||
|
||||
# Rare case
|
||||
if h <= 2 or w <= 2:
|
||||
print('GT too small, returning original')
|
||||
return seg
|
||||
|
||||
# Do a bunch of random operations
|
||||
for _ in range(250):
|
||||
for _ in range(4):
|
||||
lx, ly = np.random.randint(w), np.random.randint(h)
|
||||
lw, lh = np.random.randint(lx + 1, w + 1), np.random.randint(
|
||||
ly + 1, h + 1)
|
||||
|
||||
# Randomly set one pixel to 1/0. With the following dilate/erode, we can create holes/external regions
|
||||
if np.random.rand() < 0.1:
|
||||
cx = int((lx + lw) / 2)
|
||||
cy = int((ly + lh) / 2)
|
||||
seg[cy, cx] = np.random.randint(2) * 255
|
||||
|
||||
# Dilate/erode
|
||||
if np.random.rand() < 0.5:
|
||||
seg[ly:lh, lx:lw] = random_dilate(seg[ly:lh, lx:lw])
|
||||
else:
|
||||
seg[ly:lh, lx:lw] = random_erode(seg[ly:lh, lx:lw])
|
||||
|
||||
seg = np.logical_or(seg, gt).astype(np.uint8)
|
||||
# seg = select_max_region(seg)
|
||||
|
||||
if compute_iou(seg, gt) < iou_target:
|
||||
break
|
||||
seg = select_max_region(seg.astype(np.uint8))
|
||||
return seg.astype(np.uint8)
|
||||
|
||||
|
||||
def q_x(x_0, t=65):
|
||||
'''Adding noise for and given image.'''
|
||||
x_0 = torch.from_numpy(x_0).float() / 127.5 - 1
|
||||
num_steps = 100
|
||||
|
||||
betas = torch.linspace(-6, 6, num_steps)
|
||||
betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5
|
||||
|
||||
alphas = 1 - betas
|
||||
alphas_prod = torch.cumprod(alphas, 0)
|
||||
|
||||
alphas_bar_sqrt = torch.sqrt(alphas_prod)
|
||||
one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod)
|
||||
|
||||
noise = torch.randn_like(x_0)
|
||||
alphas_t = alphas_bar_sqrt[t]
|
||||
alphas_1_m_t = one_minus_alphas_bar_sqrt[t]
|
||||
return (alphas_t * x_0 + alphas_1_m_t * noise).numpy() * 127.5 + 127.5
|
||||
|
||||
|
||||
def extract_target_boundary(img, target_mask):
|
||||
Ksize = 3
|
||||
sobelx = cv2.Sobel(img, cv2.CV_64F, 1, 0, ksize=Ksize)
|
||||
sobely = cv2.Sobel(img, cv2.CV_64F, 0, 1, ksize=Ksize)
|
||||
|
||||
# sobel-x
|
||||
sobel_X = cv2.convertScaleAbs(sobelx)
|
||||
# sobel-y
|
||||
sobel_Y = cv2.convertScaleAbs(sobely)
|
||||
# sobel-xy
|
||||
scharr = cv2.addWeighted(sobel_X, 0.5, sobel_Y, 0.5, 0)
|
||||
scharr = np.max(scharr, -1).astype(np.float32) / 255
|
||||
scharr = scharr * target_mask.astype(np.float32)
|
||||
return scharr
|
||||
0
modelscope/models/cv/anydoor/dinov2/__init__.py
Normal file
0
modelscope/models/cv/anydoor/dinov2/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from .attention import MemEffAttention
|
||||
from .block import NestedTensorBlock
|
||||
from .dino_head import DINOHead
|
||||
from .mlp import Mlp
|
||||
from .patch_embed import PatchEmbed
|
||||
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
logger = logging.getLogger('dinov2')
|
||||
|
||||
try:
|
||||
from xformers.ops import memory_efficient_attention, unbind, fmha
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning('xFormers not available')
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
proj_bias: bool = True,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
self.scale = head_dim**-0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
|
||||
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class MemEffAttention(Attention):
|
||||
|
||||
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
||||
if not XFORMERS_AVAILABLE:
|
||||
assert attn_bias is None, 'xFormers is required for nested tensors usage'
|
||||
return super().forward(x)
|
||||
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||
|
||||
q, k, v = unbind(qkv, 2)
|
||||
|
||||
if attn_bias is not None:
|
||||
self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
|
||||
else:
|
||||
self_att_op = None
|
||||
x = memory_efficient_attention(
|
||||
q, k, v, attn_bias=attn_bias, op=self_att_op)
|
||||
x = x.reshape([B, N, C])
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
|
||||
286
modelscope/models/cv/anydoor/dinov2/dinov2/layers/block.py
Normal file
286
modelscope/models/cv/anydoor/dinov2/dinov2/layers/block.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
from .attention import Attention, MemEffAttention
|
||||
from .drop_path import DropPath
|
||||
from .layer_scale import LayerScale
|
||||
from .mlp import Mlp
|
||||
|
||||
logger = logging.getLogger('dinov2')
|
||||
|
||||
try:
|
||||
from xformers.ops import fmha
|
||||
from xformers.ops import scaled_index_add, index_select_cat
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning('xFormers not available')
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int,
|
||||
mlp_ratio: float = 4.0,
|
||||
qkv_bias: bool = False,
|
||||
proj_bias: bool = True,
|
||||
ffn_bias: bool = True,
|
||||
drop: float = 0.0,
|
||||
attn_drop: float = 0.0,
|
||||
init_values=None,
|
||||
drop_path: float = 0.0,
|
||||
act_layer: Callable[..., nn.Module] = nn.GELU,
|
||||
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
||||
attn_class: Callable[..., nn.Module] = Attention,
|
||||
ffn_layer: Callable[..., nn.Module] = Mlp,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = attn_class(
|
||||
dim,
|
||||
num_heads=num_heads,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_bias=proj_bias,
|
||||
attn_drop=attn_drop,
|
||||
proj_drop=drop,
|
||||
)
|
||||
self.ls1 = LayerScale(
|
||||
dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path1 = DropPath(
|
||||
drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = ffn_layer(
|
||||
in_features=dim,
|
||||
hidden_features=mlp_hidden_dim,
|
||||
act_layer=act_layer,
|
||||
drop=drop,
|
||||
bias=ffn_bias,
|
||||
)
|
||||
self.ls2 = LayerScale(
|
||||
dim, init_values=init_values) if init_values else nn.Identity()
|
||||
self.drop_path2 = DropPath(
|
||||
drop_path) if drop_path > 0.0 else nn.Identity()
|
||||
|
||||
self.sample_drop_ratio = drop_path
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
|
||||
def attn_residual_func(x: Tensor) -> Tensor:
|
||||
return self.ls1(self.attn(self.norm1(x)))
|
||||
|
||||
def ffn_residual_func(x: Tensor) -> Tensor:
|
||||
return self.ls2(self.mlp(self.norm2(x)))
|
||||
|
||||
if self.training and self.sample_drop_ratio > 0.1:
|
||||
# the overhead is compensated only for a drop path rate larger than 0.1
|
||||
x = drop_add_residual_stochastic_depth(
|
||||
x,
|
||||
residual_func=attn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
)
|
||||
x = drop_add_residual_stochastic_depth(
|
||||
x,
|
||||
residual_func=ffn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
)
|
||||
elif self.training and self.sample_drop_ratio > 0.0:
|
||||
x = x + self.drop_path1(attn_residual_func(x))
|
||||
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
||||
else:
|
||||
x = x + attn_residual_func(x)
|
||||
x = x + ffn_residual_func(x)
|
||||
return x
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth(
|
||||
x: Tensor,
|
||||
residual_func: Callable[[Tensor], Tensor],
|
||||
sample_drop_ratio: float = 0.0,
|
||||
) -> Tensor:
|
||||
# 1) extract subset using permutation
|
||||
b, n, d = x.shape
|
||||
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
||||
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
||||
x_subset = x[brange]
|
||||
|
||||
# 2) apply residual_func to get residual
|
||||
residual = residual_func(x_subset)
|
||||
|
||||
x_flat = x.flatten(1)
|
||||
residual = residual.flatten(1)
|
||||
|
||||
residual_scale_factor = b / sample_subset_size
|
||||
|
||||
# 3) add the residual
|
||||
x_plus_residual = torch.index_add(
|
||||
x_flat,
|
||||
0,
|
||||
brange,
|
||||
residual.to(dtype=x.dtype),
|
||||
alpha=residual_scale_factor)
|
||||
return x_plus_residual.view_as(x)
|
||||
|
||||
|
||||
def get_branges_scales(x, sample_drop_ratio=0.0):
|
||||
b, n, d = x.shape
|
||||
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
||||
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
||||
residual_scale_factor = b / sample_subset_size
|
||||
return brange, residual_scale_factor
|
||||
|
||||
|
||||
def add_residual(x,
|
||||
brange,
|
||||
residual,
|
||||
residual_scale_factor,
|
||||
scaling_vector=None):
|
||||
if scaling_vector is None:
|
||||
x_flat = x.flatten(1)
|
||||
residual = residual.flatten(1)
|
||||
x_plus_residual = torch.index_add(
|
||||
x_flat,
|
||||
0,
|
||||
brange,
|
||||
residual.to(dtype=x.dtype),
|
||||
alpha=residual_scale_factor)
|
||||
else:
|
||||
x_plus_residual = scaled_index_add(
|
||||
x,
|
||||
brange,
|
||||
residual.to(dtype=x.dtype),
|
||||
scaling=scaling_vector,
|
||||
alpha=residual_scale_factor)
|
||||
return x_plus_residual
|
||||
|
||||
|
||||
attn_bias_cache: Dict[Tuple, Any] = {}
|
||||
|
||||
|
||||
def get_attn_bias_and_cat(x_list, branges=None):
|
||||
"""
|
||||
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
||||
"""
|
||||
batch_sizes = [b.shape[0] for b in branges
|
||||
] if branges is not None else [x.shape[0] for x in x_list]
|
||||
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
||||
if all_shapes not in attn_bias_cache.keys():
|
||||
seqlens = []
|
||||
for b, x in zip(batch_sizes, x_list):
|
||||
for _ in range(b):
|
||||
seqlens.append(x.shape[1])
|
||||
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
||||
attn_bias._batch_sizes = batch_sizes
|
||||
attn_bias_cache[all_shapes] = attn_bias
|
||||
|
||||
if branges is not None:
|
||||
cat_tensors = index_select_cat([x.flatten(1) for x in x_list],
|
||||
branges).view(1, -1,
|
||||
x_list[0].shape[-1])
|
||||
else:
|
||||
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
||||
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
||||
|
||||
return attn_bias_cache[all_shapes], cat_tensors
|
||||
|
||||
|
||||
def drop_add_residual_stochastic_depth_list(
|
||||
x_list: List[Tensor],
|
||||
residual_func: Callable[[Tensor, Any], Tensor],
|
||||
sample_drop_ratio: float = 0.0,
|
||||
scaling_vector=None,
|
||||
) -> Tensor:
|
||||
# 1) generate random set of indices for dropping samples in the batch
|
||||
branges_scales = [
|
||||
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio)
|
||||
for x in x_list
|
||||
]
|
||||
branges = [s[0] for s in branges_scales]
|
||||
residual_scale_factors = [s[1] for s in branges_scales]
|
||||
|
||||
# 2) get attention bias and index+concat the tensors
|
||||
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
||||
|
||||
# 3) apply residual_func to get residual, and split the result
|
||||
residual_list = attn_bias.split(residual_func(
|
||||
x_cat, attn_bias=attn_bias)) # type: ignore
|
||||
|
||||
outputs = []
|
||||
for x, brange, residual, residual_scale_factor in zip(
|
||||
x_list, branges, residual_list, residual_scale_factors):
|
||||
outputs.append(
|
||||
add_residual(x, brange, residual, residual_scale_factor,
|
||||
scaling_vector).view_as(x))
|
||||
return outputs
|
||||
|
||||
|
||||
class NestedTensorBlock(Block):
|
||||
|
||||
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
||||
"""
|
||||
x_list contains a list of tensors to nest together and run
|
||||
"""
|
||||
assert isinstance(self.attn, MemEffAttention)
|
||||
|
||||
if self.training and self.sample_drop_ratio > 0.0:
|
||||
|
||||
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
||||
|
||||
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.mlp(self.norm2(x))
|
||||
|
||||
x_list = drop_add_residual_stochastic_depth_list(
|
||||
x_list,
|
||||
residual_func=attn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
scaling_vector=self.ls1.gamma if isinstance(
|
||||
self.ls1, LayerScale) else None,
|
||||
)
|
||||
x_list = drop_add_residual_stochastic_depth_list(
|
||||
x_list,
|
||||
residual_func=ffn_residual_func,
|
||||
sample_drop_ratio=self.sample_drop_ratio,
|
||||
scaling_vector=self.ls2.gamma if isinstance(
|
||||
self.ls1, LayerScale) else None,
|
||||
)
|
||||
return x_list
|
||||
else:
|
||||
|
||||
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
||||
|
||||
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
||||
return self.ls2(self.mlp(self.norm2(x)))
|
||||
|
||||
attn_bias, x = get_attn_bias_and_cat(x_list)
|
||||
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
||||
x = x + ffn_residual_func(x)
|
||||
return attn_bias.split(x)
|
||||
|
||||
def forward(self, x_or_x_list):
|
||||
if isinstance(x_or_x_list, Tensor):
|
||||
return super().forward(x_or_x_list)
|
||||
elif isinstance(x_or_x_list, list):
|
||||
assert XFORMERS_AVAILABLE, 'Please install xFormers for nested tensors usage'
|
||||
return self.forward_nested(x_or_x_list)
|
||||
else:
|
||||
raise AssertionError
|
||||
@@ -0,0 +1,72 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn.init import trunc_normal_
|
||||
from torch.nn.utils import weight_norm
|
||||
|
||||
|
||||
class DINOHead(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
out_dim,
|
||||
use_bn=False,
|
||||
nlayers=3,
|
||||
hidden_dim=2048,
|
||||
bottleneck_dim=256,
|
||||
mlp_bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
nlayers = max(nlayers, 1)
|
||||
self.mlp = _build_mlp(
|
||||
nlayers,
|
||||
in_dim,
|
||||
bottleneck_dim,
|
||||
hidden_dim=hidden_dim,
|
||||
use_bn=use_bn,
|
||||
bias=mlp_bias)
|
||||
self.apply(self._init_weights)
|
||||
self.last_layer = weight_norm(
|
||||
nn.Linear(bottleneck_dim, out_dim, bias=False))
|
||||
self.last_layer.weight_g.data.fill_(1)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
trunc_normal_(m.weight, std=0.02)
|
||||
if isinstance(m, nn.Linear) and m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.mlp(x)
|
||||
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
||||
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
||||
x = self.last_layer(x)
|
||||
return x
|
||||
|
||||
|
||||
def _build_mlp(nlayers,
|
||||
in_dim,
|
||||
bottleneck_dim,
|
||||
hidden_dim=None,
|
||||
use_bn=False,
|
||||
bias=True):
|
||||
if nlayers == 1:
|
||||
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
||||
else:
|
||||
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
||||
if use_bn:
|
||||
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||
layers.append(nn.GELU())
|
||||
for _ in range(nlayers - 2):
|
||||
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
||||
if use_bn:
|
||||
layers.append(nn.BatchNorm1d(hidden_dim))
|
||||
layers.append(nn.GELU())
|
||||
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
||||
return nn.Sequential(*layers)
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
||||
|
||||
from torch import nn
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0], ) + (1, ) * (
|
||||
x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0:
|
||||
random_tensor.div_(keep_prob)
|
||||
output = x * random_tensor
|
||||
return output
|
||||
|
||||
|
||||
class DropPath(nn.Module):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
||||
|
||||
def __init__(self, drop_prob=None):
|
||||
super(DropPath, self).__init__()
|
||||
self.drop_prob = drop_prob
|
||||
|
||||
def forward(self, x):
|
||||
return drop_path(x, self.drop_prob, self.training)
|
||||
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class LayerScale(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
init_values: Union[float, Tensor] = 1e-5,
|
||||
inplace: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.inplace = inplace
|
||||
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
||||
41
modelscope/models/cv/anydoor/dinov2/dinov2/layers/mlp.py
Normal file
41
modelscope/models/cv/anydoor/dinov2/dinov2/layers/mlp.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
||||
|
||||
from typing import Callable, Optional, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
def make_2tuple(x):
|
||||
if isinstance(x, tuple):
|
||||
assert len(x) == 2
|
||||
return x
|
||||
|
||||
assert isinstance(x, int)
|
||||
return (x, x)
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""
|
||||
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
||||
|
||||
Args:
|
||||
img_size: Image size.
|
||||
patch_size: Patch token size.
|
||||
in_chans: Number of input image channels.
|
||||
embed_dim: Number of linear projection output channels.
|
||||
norm_layer: Normalization layer.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size: Union[int, Tuple[int, int]] = 224,
|
||||
patch_size: Union[int, Tuple[int, int]] = 16,
|
||||
in_chans: int = 3,
|
||||
embed_dim: int = 768,
|
||||
norm_layer: Optional[Callable] = None,
|
||||
flatten_embedding: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
image_HW = make_2tuple(img_size)
|
||||
patch_HW = make_2tuple(patch_size)
|
||||
patch_grid_size = (
|
||||
image_HW[0] // patch_HW[0],
|
||||
image_HW[1] // patch_HW[1],
|
||||
)
|
||||
|
||||
self.img_size = image_HW
|
||||
self.patch_size = patch_HW
|
||||
self.patches_resolution = patch_grid_size
|
||||
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
||||
|
||||
self.in_chans = in_chans
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.flatten_embedding = flatten_embedding
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
||||
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
_, _, H, W = x.shape
|
||||
patch_H, patch_W = self.patch_size
|
||||
|
||||
assert H % patch_H == 0, f'Input image height {H} is not a multiple of patch height {patch_H}'
|
||||
assert W % patch_W == 0, f'Input image width {W} is not a multiple of patch width: {patch_W}'
|
||||
|
||||
x = self.proj(x) # B C H W
|
||||
H, W = x.size(2), x.size(3)
|
||||
x = x.flatten(2).transpose(1, 2) # B HW C
|
||||
x = self.norm(x)
|
||||
if not self.flatten_embedding:
|
||||
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
||||
return x
|
||||
|
||||
def flops(self) -> float:
|
||||
Ho, Wo = self.patches_resolution
|
||||
flops = Ho * Wo * self.embed_dim * self.in_chans * (
|
||||
self.patch_size[0] * self.patch_size[1])
|
||||
if self.norm is not None:
|
||||
flops += Ho * Wo * self.embed_dim
|
||||
return flops
|
||||
@@ -0,0 +1,65 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor, nn
|
||||
|
||||
|
||||
class SwiGLUFFN(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = None,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
||||
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
x12 = self.w12(x)
|
||||
x1, x2 = x12.chunk(2, dim=-1)
|
||||
hidden = F.silu(x1) * x2
|
||||
return self.w3(hidden)
|
||||
|
||||
|
||||
try:
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
XFORMERS_AVAILABLE = True
|
||||
except ImportError:
|
||||
SwiGLU = SwiGLUFFN
|
||||
XFORMERS_AVAILABLE = False
|
||||
|
||||
|
||||
class SwiGLUFFNFused(SwiGLU):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: Optional[int] = None,
|
||||
out_features: Optional[int] = None,
|
||||
act_layer: Callable[..., nn.Module] = None,
|
||||
drop: float = 0.0,
|
||||
bias: bool = True,
|
||||
) -> None:
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
||||
super().__init__(
|
||||
in_features=in_features,
|
||||
hidden_features=hidden_features,
|
||||
out_features=out_features,
|
||||
bias=bias,
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
|
||||
from . import vision_transformer as vits
|
||||
|
||||
logger = logging.getLogger('dinov2')
|
||||
|
||||
|
||||
def build_model(args, only_teacher=False, img_size=224):
|
||||
args.arch = args.arch.removesuffix('_memeff')
|
||||
if 'vit' in args.arch:
|
||||
vit_kwargs = dict(
|
||||
img_size=img_size,
|
||||
patch_size=args.patch_size,
|
||||
init_values=args.layerscale,
|
||||
ffn_layer=args.ffn_layer,
|
||||
block_chunks=args.block_chunks,
|
||||
qkv_bias=args.qkv_bias,
|
||||
proj_bias=args.proj_bias,
|
||||
ffn_bias=args.ffn_bias,
|
||||
)
|
||||
teacher = vits.__dict__[args.arch](**vit_kwargs)
|
||||
if only_teacher:
|
||||
return teacher, teacher.embed_dim
|
||||
student = vits.__dict__[args.arch](
|
||||
**vit_kwargs,
|
||||
drop_path_rate=args.drop_path_rate,
|
||||
drop_path_uniform=args.drop_path_uniform,
|
||||
)
|
||||
embed_dim = student.embed_dim
|
||||
return student, teacher, embed_dim
|
||||
|
||||
|
||||
def build_model_from_cfg(cfg, only_teacher=False):
|
||||
return build_model(
|
||||
cfg.student,
|
||||
only_teacher=only_teacher,
|
||||
img_size=cfg.crops.global_crops_size)
|
||||
@@ -0,0 +1,390 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
# References:
|
||||
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
||||
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
||||
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
from typing import Callable, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint
|
||||
from torch.nn.init import trunc_normal_
|
||||
|
||||
from ..layers import MemEffAttention, Mlp
|
||||
from ..layers import NestedTensorBlock as Block
|
||||
from ..layers import PatchEmbed, SwiGLUFFNFused
|
||||
|
||||
logger = logging.getLogger('dinov2')
|
||||
|
||||
|
||||
def named_apply(fn: Callable,
|
||||
module: nn.Module,
|
||||
name='',
|
||||
depth_first=True,
|
||||
include_root=False) -> nn.Module:
|
||||
if not depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
for child_name, child_module in module.named_children():
|
||||
child_name = '.'.join((name, child_name)) if name else child_name
|
||||
named_apply(
|
||||
fn=fn,
|
||||
module=child_module,
|
||||
name=child_name,
|
||||
depth_first=depth_first,
|
||||
include_root=True)
|
||||
if depth_first and include_root:
|
||||
fn(module=module, name=name)
|
||||
return module
|
||||
|
||||
|
||||
class BlockChunk(nn.ModuleList):
|
||||
|
||||
def forward(self, x):
|
||||
for b in self:
|
||||
x = b(x)
|
||||
return x
|
||||
|
||||
|
||||
class DinoVisionTransformer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
patch_size=16,
|
||||
in_chans=3,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4.0,
|
||||
qkv_bias=True,
|
||||
ffn_bias=True,
|
||||
proj_bias=True,
|
||||
drop_path_rate=0.0,
|
||||
drop_path_uniform=False,
|
||||
init_values=None, # for layerscale: None or 0 => no layerscale
|
||||
embed_layer=PatchEmbed,
|
||||
act_layer=nn.GELU,
|
||||
block_fn=Block,
|
||||
ffn_layer='mlp',
|
||||
block_chunks=1,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
patch_size (int, tuple): patch size
|
||||
in_chans (int): number of input channels
|
||||
embed_dim (int): embedding dimension
|
||||
depth (int): depth of transformer
|
||||
num_heads (int): number of attention heads
|
||||
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
||||
qkv_bias (bool): enable bias for qkv if True
|
||||
proj_bias (bool): enable bias for proj in attn if True
|
||||
ffn_bias (bool): enable bias for ffn if True
|
||||
drop_path_rate (float): stochastic depth rate
|
||||
drop_path_uniform (bool): apply uniform drop rate across blocks
|
||||
weight_init (str): weight init scheme
|
||||
init_values (float): layer-scale init values
|
||||
embed_layer (nn.Module): patch embedding layer
|
||||
act_layer (nn.Module): MLP activation layer
|
||||
block_fn (nn.Module): transformer block class
|
||||
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
||||
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
||||
"""
|
||||
super().__init__()
|
||||
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
||||
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_tokens = 1
|
||||
self.n_blocks = depth
|
||||
self.num_heads = num_heads
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
in_chans=in_chans,
|
||||
embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
||||
|
||||
if drop_path_uniform is True:
|
||||
dpr = [drop_path_rate] * depth
|
||||
else:
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
||||
] # stochastic depth decay rule
|
||||
|
||||
if ffn_layer == 'mlp':
|
||||
logger.info('using MLP layer as FFN')
|
||||
ffn_layer = Mlp
|
||||
elif ffn_layer == 'swiglufused' or ffn_layer == 'swiglu':
|
||||
logger.info('using SwiGLU layer as FFN')
|
||||
ffn_layer = SwiGLUFFNFused
|
||||
elif ffn_layer == 'identity':
|
||||
logger.info('using Identity layer as FFN')
|
||||
|
||||
def f(*args, **kwargs):
|
||||
return nn.Identity()
|
||||
|
||||
ffn_layer = f
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
blocks_list = [
|
||||
block_fn(
|
||||
dim=embed_dim,
|
||||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
qkv_bias=qkv_bias,
|
||||
proj_bias=proj_bias,
|
||||
ffn_bias=ffn_bias,
|
||||
drop_path=dpr[i],
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
ffn_layer=ffn_layer,
|
||||
init_values=init_values,
|
||||
) for i in range(depth)
|
||||
]
|
||||
if block_chunks > 0:
|
||||
self.chunked_blocks = True
|
||||
chunked_blocks = []
|
||||
chunksize = depth // block_chunks
|
||||
for i in range(0, depth, chunksize):
|
||||
# this is to keep the block index consistent if we chunk the block list
|
||||
chunked_blocks.append([nn.Identity()] * i
|
||||
+ blocks_list[i:i + chunksize])
|
||||
self.blocks = nn.ModuleList(
|
||||
[BlockChunk(p) for p in chunked_blocks])
|
||||
else:
|
||||
self.chunked_blocks = False
|
||||
self.blocks = nn.ModuleList(blocks_list)
|
||||
|
||||
self.norm = norm_layer(embed_dim)
|
||||
self.head = nn.Identity()
|
||||
|
||||
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
||||
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self):
|
||||
trunc_normal_(self.pos_embed, std=0.02)
|
||||
nn.init.normal_(self.cls_token, std=1e-6)
|
||||
named_apply(init_weights_vit_timm, self)
|
||||
|
||||
def interpolate_pos_encoding(self, x, w, h):
|
||||
previous_dtype = x.dtype
|
||||
npatch = x.shape[1] - 1
|
||||
N = self.pos_embed.shape[1] - 1
|
||||
if npatch == N and w == h:
|
||||
return self.pos_embed
|
||||
pos_embed = self.pos_embed.float()
|
||||
class_pos_embed = pos_embed[:, 0]
|
||||
patch_pos_embed = pos_embed[:, 1:]
|
||||
dim = x.shape[-1]
|
||||
w0 = w // self.patch_size
|
||||
h0 = h // self.patch_size
|
||||
# we add a small number to avoid floating point error in the interpolation
|
||||
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
||||
w0, h0 = w0 + 0.1, h0 + 0.1
|
||||
|
||||
patch_pos_embed = nn.functional.interpolate(
|
||||
patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)),
|
||||
dim).permute(0, 3, 1, 2),
|
||||
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
||||
mode='bicubic',
|
||||
)
|
||||
|
||||
assert int(w0) == patch_pos_embed.shape[-2] and int(
|
||||
h0) == patch_pos_embed.shape[-1]
|
||||
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
||||
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed),
|
||||
dim=1).to(previous_dtype)
|
||||
|
||||
def prepare_tokens_with_masks(self, x, masks=None):
|
||||
B, nc, w, h = x.shape
|
||||
x = self.patch_embed(x)
|
||||
if masks is not None:
|
||||
x = torch.where(
|
||||
masks.unsqueeze(-1),
|
||||
self.mask_token.to(x.dtype).unsqueeze(0), x)
|
||||
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.interpolate_pos_encoding(x, w, h)
|
||||
|
||||
return x
|
||||
|
||||
def forward_features_list(self, x_list, masks_list):
|
||||
x = [
|
||||
self.prepare_tokens_with_masks(x, masks)
|
||||
for x, masks in zip(x_list, masks_list)
|
||||
]
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
all_x = x
|
||||
output = []
|
||||
for x, masks in zip(all_x, masks_list):
|
||||
x_norm = self.norm(x)
|
||||
output.append({
|
||||
'x_norm_clstoken': x_norm[:, 0],
|
||||
'x_norm_patchtokens': x_norm[:, 1:],
|
||||
'x_prenorm': x,
|
||||
'masks': masks,
|
||||
})
|
||||
return output
|
||||
|
||||
def forward_features(self, x, masks=None):
|
||||
if isinstance(x, list):
|
||||
return self.forward_features_list(x, masks)
|
||||
|
||||
x = self.prepare_tokens_with_masks(x, masks)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x)
|
||||
|
||||
x_norm = self.norm(x)
|
||||
return {
|
||||
'x_norm_clstoken': x_norm[:, 0],
|
||||
'x_norm_patchtokens': x_norm[:, 1:],
|
||||
'x_prenorm': x,
|
||||
'masks': masks,
|
||||
}
|
||||
|
||||
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
||||
x = self.prepare_tokens_with_masks(x)
|
||||
# If n is an int, take the n last blocks. If it's a list, take them
|
||||
output, total_block_len = [], len(self.blocks)
|
||||
blocks_to_take = range(total_block_len - n,
|
||||
total_block_len) if isinstance(n, int) else n
|
||||
for i, blk in enumerate(self.blocks):
|
||||
x = blk(x)
|
||||
if i in blocks_to_take:
|
||||
output.append(x)
|
||||
assert len(output) == len(
|
||||
blocks_to_take
|
||||
), f'only {len(output)} / {len(blocks_to_take)} blocks found'
|
||||
return output
|
||||
|
||||
def _get_intermediate_layers_chunked(self, x, n=1):
|
||||
x = self.prepare_tokens_with_masks(x)
|
||||
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
||||
# If n is an int, take the n last blocks. If it's a list, take them
|
||||
blocks_to_take = range(total_block_len - n,
|
||||
total_block_len) if isinstance(n, int) else n
|
||||
for block_chunk in self.blocks:
|
||||
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
||||
x = blk(x)
|
||||
if i in blocks_to_take:
|
||||
output.append(x)
|
||||
i += 1
|
||||
assert len(output) == len(
|
||||
blocks_to_take
|
||||
), f'only {len(output)} / {len(blocks_to_take)} blocks found'
|
||||
return output
|
||||
|
||||
def get_intermediate_layers(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
||||
reshape: bool = False,
|
||||
return_class_token: bool = False,
|
||||
norm=True,
|
||||
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
||||
if self.chunked_blocks:
|
||||
outputs = self._get_intermediate_layers_chunked(x, n)
|
||||
else:
|
||||
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
||||
if norm:
|
||||
outputs = [self.norm(out) for out in outputs]
|
||||
class_tokens = [out[:, 0] for out in outputs]
|
||||
outputs = [out[:, 1:] for out in outputs]
|
||||
if reshape:
|
||||
B, _, w, h = x.shape
|
||||
outputs = [
|
||||
out.reshape(B, w // self.patch_size, h // self.patch_size,
|
||||
-1).permute(0, 3, 1, 2).contiguous()
|
||||
for out in outputs
|
||||
]
|
||||
if return_class_token:
|
||||
return tuple(zip(outputs, class_tokens))
|
||||
return tuple(outputs)
|
||||
|
||||
def forward(self, *args, is_training=False, **kwargs):
|
||||
ret = self.forward_features(*args, **kwargs)
|
||||
if is_training:
|
||||
return ret
|
||||
else:
|
||||
return self.head(ret['x_norm_clstoken'])
|
||||
|
||||
|
||||
def init_weights_vit_timm(module: nn.Module, name: str = ''):
|
||||
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
||||
if isinstance(module, nn.Linear):
|
||||
trunc_normal_(module.weight, std=0.02)
|
||||
if module.bias is not None:
|
||||
nn.init.zeros_(module.bias)
|
||||
|
||||
|
||||
def vit_small(patch_size=16, **kwargs):
|
||||
model = DinoVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=384,
|
||||
depth=12,
|
||||
num_heads=6,
|
||||
mlp_ratio=4,
|
||||
block_fn=partial(Block, attn_class=MemEffAttention),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def vit_base(patch_size=16, **kwargs):
|
||||
model = DinoVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=768,
|
||||
depth=12,
|
||||
num_heads=12,
|
||||
mlp_ratio=4,
|
||||
block_fn=partial(Block, attn_class=MemEffAttention),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def vit_large(patch_size=16, **kwargs):
|
||||
model = DinoVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=1024,
|
||||
depth=24,
|
||||
num_heads=16,
|
||||
mlp_ratio=4,
|
||||
block_fn=partial(Block, attn_class=MemEffAttention),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def vit_giant2(patch_size=16, **kwargs):
|
||||
"""
|
||||
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
||||
"""
|
||||
model = DinoVisionTransformer(
|
||||
patch_size=patch_size,
|
||||
embed_dim=1536,
|
||||
depth=40,
|
||||
num_heads=24,
|
||||
mlp_ratio=4,
|
||||
block_fn=partial(Block, attn_class=MemEffAttention),
|
||||
**kwargs,
|
||||
)
|
||||
return model
|
||||
195
modelscope/models/cv/anydoor/dinov2/hubconf.py
Normal file
195
modelscope/models/cv/anydoor/dinov2/hubconf.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
dependencies = ['torch']
|
||||
|
||||
_DINOV2_BASE_URL = 'https://dl.fbaipublicfiles.com/dinov2'
|
||||
|
||||
|
||||
def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
|
||||
compact_arch_name = arch_name.replace('_', '')[:4]
|
||||
return f'dinov2_{compact_arch_name}{patch_size}'
|
||||
|
||||
|
||||
def _make_dinov2_model(
|
||||
*,
|
||||
arch_name: str = 'vit_large',
|
||||
img_size: int = 518,
|
||||
patch_size: int = 14,
|
||||
init_values: float = 1.0,
|
||||
ffn_layer: str = 'mlp',
|
||||
block_chunks: int = 0,
|
||||
pretrained: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
from .dinov2.models import vision_transformer as vits
|
||||
|
||||
_ = _make_dinov2_model_name(arch_name, patch_size)
|
||||
vit_kwargs = dict(
|
||||
img_size=img_size,
|
||||
patch_size=patch_size,
|
||||
init_values=init_values,
|
||||
ffn_layer=ffn_layer,
|
||||
block_chunks=block_chunks,
|
||||
)
|
||||
vit_kwargs.update(**kwargs)
|
||||
model = vits.__dict__[arch_name](**vit_kwargs)
|
||||
|
||||
# if pretrained:
|
||||
# state_dict = torch.load('')
|
||||
# model.load_state_dict(state_dict, strict=False)
|
||||
return model
|
||||
|
||||
|
||||
def dinov2_vits14(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name='vit_small', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitb14(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-B/14 model pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name='vit_base', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitl14(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name='vit_large', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitg14(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
||||
"""
|
||||
return _make_dinov2_model(
|
||||
arch_name='vit_giant2',
|
||||
ffn_layer='swiglufused',
|
||||
pretrained=pretrained,
|
||||
**kwargs)
|
||||
|
||||
|
||||
def _make_dinov2_linear_head(
|
||||
*,
|
||||
model_name: str = 'dinov2_vitl14',
|
||||
embed_dim: int = 1024,
|
||||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
assert layers in (1, 4), f'Unsupported number of layers: {layers}'
|
||||
linear_head = nn.Linear((1 + layers) * embed_dim, 1_000)
|
||||
|
||||
if pretrained:
|
||||
layers_str = str(layers) if layers == 4 else ''
|
||||
url = _DINOV2_BASE_URL + f'/{model_name}/{model_name}_linear{layers_str}_head.pth'
|
||||
state_dict = torch.hub.load_state_dict_from_url(
|
||||
url, map_location='cpu')
|
||||
linear_head.load_state_dict(state_dict, strict=False)
|
||||
|
||||
return linear_head
|
||||
|
||||
|
||||
class _LinearClassifierWrapper(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
backbone: nn.Module,
|
||||
linear_head: nn.Module,
|
||||
layers: int = 4):
|
||||
super().__init__()
|
||||
self.backbone = backbone
|
||||
self.linear_head = linear_head
|
||||
self.layers = layers
|
||||
|
||||
def forward(self, x):
|
||||
if self.layers == 1:
|
||||
x = self.backbone.forward_features(x)
|
||||
cls_token = x['x_norm_clstoken'].squeeze(0)
|
||||
patch_tokens = x['x_norm_patchtokens'].squeeze(0)
|
||||
linear_input = torch.cat([cls_token, patch_tokens.mean(0)])
|
||||
elif self.layers == 4:
|
||||
x = self.backbone.get_intermediate_layers(
|
||||
x, n=4, return_class_token=True)
|
||||
linear_input = torch.cat([
|
||||
x[0][1].squeeze(0), x[1][1].squeeze(0), x[2][1].squeeze(0),
|
||||
x[3][1].squeeze(0), x[3][0].squeeze(0).mean(0)
|
||||
])
|
||||
else:
|
||||
assert False, f'Unsupported number of layers: {self.layers}'
|
||||
return self.linear_head(linear_input)
|
||||
|
||||
|
||||
def _make_dinov2_linear_classifier(
|
||||
*,
|
||||
arch_name: str = 'vit_large',
|
||||
layers: int = 4,
|
||||
pretrained: bool = True,
|
||||
**kwargs,
|
||||
):
|
||||
backbone = _make_dinov2_model(
|
||||
arch_name=arch_name, pretrained=pretrained, **kwargs)
|
||||
|
||||
embed_dim = backbone.embed_dim
|
||||
patch_size = backbone.patch_size
|
||||
model_name = _make_dinov2_model_name(arch_name, patch_size)
|
||||
linear_head = _make_dinov2_linear_head(
|
||||
model_name=model_name,
|
||||
embed_dim=embed_dim,
|
||||
layers=layers,
|
||||
pretrained=pretrained)
|
||||
|
||||
return _LinearClassifierWrapper(
|
||||
backbone=backbone, linear_head=linear_head, layers=layers)
|
||||
|
||||
|
||||
def dinov2_vits14_lc(*, layers: int = 4, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-S/14 backbone (optionally)
|
||||
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name='vit_small', layers=layers, pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitb14_lc(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-B/14 backbone (optionally)
|
||||
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name='vit_base', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitl14_lc(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-L/14 backbone (optionally)
|
||||
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name='vit_large', pretrained=pretrained, **kwargs)
|
||||
|
||||
|
||||
def dinov2_vitg14_lc(*, pretrained: bool = True, **kwargs):
|
||||
"""
|
||||
Linear classifier (1 or 4 layers) on top of a DINOv2 ViT-g/14 backbone (optionally)
|
||||
pretrained on the LVD-142M dataset and trained on ImageNet-1k.
|
||||
"""
|
||||
return _make_dinov2_linear_classifier(
|
||||
arch_name='vit_giant2',
|
||||
ffn_layer='swiglufused',
|
||||
pretrained=pretrained,
|
||||
**kwargs)
|
||||
1
modelscope/models/cv/anydoor/ldm/__init__.py
Normal file
1
modelscope/models/cv/anydoor/ldm/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
0
modelscope/models/cv/anydoor/ldm/models/__init__.py
Normal file
0
modelscope/models/cv/anydoor/ldm/models/__init__.py
Normal file
274
modelscope/models/cv/anydoor/ldm/models/autoencoder.py
Normal file
274
modelscope/models/cv/anydoor/ldm/models/autoencoder.py
Normal file
@@ -0,0 +1,274 @@
|
||||
from contextlib import contextmanager
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ...ldm.modules.diffusionmodules.model import Decoder, Encoder
|
||||
from ...ldm.modules.distributions.distributions import \
|
||||
DiagonalGaussianDistribution
|
||||
from ...ldm.modules.ema import LitEma
|
||||
from ...ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key='image',
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
ema_decay=None,
|
||||
learn_logvar=False):
|
||||
super().__init__()
|
||||
self.learn_logvar = learn_logvar
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig['double_z']
|
||||
self.quant_conv = torch.nn.Conv2d(2 * ddconfig['z_channels'],
|
||||
2 * embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim,
|
||||
ddconfig['z_channels'], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels) == int
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
|
||||
self.use_ema = ema_decay is not None
|
||||
if self.use_ema:
|
||||
self.ema_decay = ema_decay
|
||||
assert 0. < ema_decay < 1.
|
||||
self.model_ema = LitEma(self, decay=ema_decay)
|
||||
print(f'Keeping EMAs of {len(list(self.model_ema.buffers()))}.')
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location='cpu')['state_dict']
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print('Deleting key {} from state_dict.'.format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f'Restored from {path}')
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f'{context}: Switched to EMA weights')
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f'{context}: Restored training weights')
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1,
|
||||
2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train')
|
||||
self.log(
|
||||
'aeloss',
|
||||
aeloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
self.log_dict(
|
||||
log_dict_ae,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
optimizer_idx,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='train')
|
||||
|
||||
self.log(
|
||||
'discloss',
|
||||
discloss,
|
||||
prog_bar=True,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=True)
|
||||
self.log_dict(
|
||||
log_dict_disc,
|
||||
prog_bar=False,
|
||||
logger=True,
|
||||
on_step=True,
|
||||
on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
_ = self._validation_step(batch, batch_idx, postfix='_ema')
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, postfix=''):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + postfix)
|
||||
|
||||
discloss, log_dict_disc = self.loss(
|
||||
inputs,
|
||||
reconstructions,
|
||||
posterior,
|
||||
1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split='val' + postfix)
|
||||
|
||||
self.log(f'val{postfix}/rec_loss',
|
||||
log_dict_ae[f'val{postfix}/rec_loss'])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
ae_params_list = list(self.encoder.parameters()) + list(
|
||||
self.decoder.parameters()) + list(
|
||||
self.quant_conv.parameters()) + list(
|
||||
self.post_quant_conv.parameters())
|
||||
if self.learn_logvar:
|
||||
print(f'{self.__class__.__name__}: Learning logvar')
|
||||
ae_params_list.append(self.loss.logvar)
|
||||
opt_ae = torch.optim.Adam(ae_params_list, lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(
|
||||
self.loss.discriminator.parameters(), lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, log_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log['samples'] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log['reconstructions'] = xrec
|
||||
if log_ema or self.use_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, posterior_ema = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec_ema.shape[1] > 3
|
||||
xrec_ema = self.to_rgb(xrec_ema)
|
||||
log['samples_ema'] = self.decode(
|
||||
torch.randn_like(posterior_ema.sample()))
|
||||
log['reconstructions_ema'] = xrec_ema
|
||||
log['inputs'] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == 'segmentation'
|
||||
if not hasattr(self, 'colorize'):
|
||||
self.register_buffer('colorize',
|
||||
torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2. * (x - x.min()) / (x.max() - x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
446
modelscope/models/cv/anydoor/ldm/models/diffusion/ddim.py
Normal file
446
modelscope/models/cv/anydoor/ldm/models/diffusion/ddim.py
Normal file
@@ -0,0 +1,446 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ....ldm.modules.diffusionmodules.util import (
|
||||
extract_into_tensor, make_ddim_sampling_parameters, make_ddim_timesteps,
|
||||
noise_like)
|
||||
|
||||
|
||||
class DDIMSampler(object):
|
||||
|
||||
def __init__(self, model, schedule='linear', **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device('cuda'):
|
||||
attr = attr.to(torch.device('cuda'))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
def to_torch(x):
|
||||
return x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev',
|
||||
to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
tmp1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
|
||||
tmp2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(tmp1 * tmp2)
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None,
|
||||
**kwargs):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
ctmp = conditioning[list(conditioning.keys())[0]]
|
||||
while isinstance(ctmp, list):
|
||||
ctmp = ctmp[0]
|
||||
cbs = ctmp.shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
elif isinstance(conditioning, list):
|
||||
for ctmp in conditioning:
|
||||
if ctmp.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for DDIM sampling is {size}, eta {eta}')
|
||||
|
||||
samples, intermediates = self.ddim_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
ucg_schedule=ucg_schedule)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def ddim_sampling(self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None,
|
||||
ucg_schedule=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = reversed(range(
|
||||
0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
0]
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
if ucg_schedule is not None:
|
||||
assert len(ucg_schedule) == len(time_range)
|
||||
unconditional_guidance_scale = ucg_schedule[i]
|
||||
|
||||
outs = self.p_sample_ddim(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0 = outs
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_ddim(self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
model_output = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
if isinstance(c, dict):
|
||||
assert isinstance(unconditional_conditioning, dict)
|
||||
c_in = dict()
|
||||
for k in c:
|
||||
if isinstance(c[k], list):
|
||||
c_in[k] = [
|
||||
torch.cat(
|
||||
[unconditional_conditioning[k][i], c[k][i]])
|
||||
for i in range(len(c[k]))
|
||||
]
|
||||
else:
|
||||
c_in[k] = torch.cat(
|
||||
[unconditional_conditioning[k], c[k]])
|
||||
elif isinstance(c, list):
|
||||
c_in = list()
|
||||
assert isinstance(unconditional_conditioning, list)
|
||||
for i in range(len(c)):
|
||||
c_in.append(
|
||||
torch.cat([unconditional_conditioning[i], c[i]]))
|
||||
else:
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
model_uncond, model_t = self.model.apply_model(x_in, t_in,
|
||||
c_in).chunk(2)
|
||||
model_output = model_uncond + unconditional_guidance_scale * (
|
||||
model_t - model_uncond)
|
||||
|
||||
if self.model.parameterization == 'v':
|
||||
e_t = self.model.predict_eps_from_z_and_v(x, t, model_output)
|
||||
else:
|
||||
e_t = model_output
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps', 'not implemented'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||
**corrector_kwargs)
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod \
|
||||
if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
if self.model.parameterization != 'v':
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
else:
|
||||
pred_x0 = self.model.predict_start_from_z_and_v(x, t, model_output)
|
||||
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
|
||||
if dynamic_threshold is not None:
|
||||
raise NotImplementedError()
|
||||
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device,
|
||||
repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self,
|
||||
x0,
|
||||
c,
|
||||
t_enc,
|
||||
use_original_steps=False,
|
||||
return_intermediates=None,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
callback=None):
|
||||
num_reference_steps = self.ddpm_num_timesteps if use_original_steps else self.ddim_timesteps.shape[
|
||||
0]
|
||||
|
||||
assert t_enc <= num_reference_steps
|
||||
num_steps = t_enc
|
||||
|
||||
if use_original_steps:
|
||||
alphas_next = self.alphas_cumprod[:num_steps]
|
||||
alphas = self.alphas_cumprod_prev[:num_steps]
|
||||
else:
|
||||
alphas_next = self.ddim_alphas[:num_steps]
|
||||
alphas = torch.tensor(self.ddim_alphas_prev[:num_steps])
|
||||
|
||||
x_next = x0
|
||||
intermediates = []
|
||||
inter_steps = []
|
||||
for i in tqdm(range(num_steps), desc='Encoding Image'):
|
||||
t = torch.full((x0.shape[0], ),
|
||||
i,
|
||||
device=self.model.device,
|
||||
dtype=torch.long)
|
||||
if unconditional_guidance_scale == 1.:
|
||||
noise_pred = self.model.apply_model(x_next, t, c)
|
||||
else:
|
||||
assert unconditional_conditioning is not None
|
||||
e_t_uncond, noise_pred = torch.chunk(
|
||||
self.model.apply_model(
|
||||
torch.cat((x_next, x_next)), torch.cat((t, t)),
|
||||
torch.cat((unconditional_conditioning, c))), 2)
|
||||
tmp = noise_pred - e_t_uncond
|
||||
noise_pred = e_t_uncond + unconditional_guidance_scale * tmp
|
||||
|
||||
xt_weighted = (alphas_next[i] / alphas[i]).sqrt() * x_next
|
||||
tmp = (1 / alphas_next[i] - 1).sqrt() - (1 / alphas[i] - 1).sqrt()
|
||||
weighted_noise_pred = alphas_next[i].sqrt() * tmp * noise_pred
|
||||
x_next = xt_weighted + weighted_noise_pred
|
||||
if return_intermediates and i % (num_steps // return_intermediates
|
||||
) == 0 and i < num_steps - 1:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
elif return_intermediates and i >= num_steps - 2:
|
||||
intermediates.append(x_next)
|
||||
inter_steps.append(i)
|
||||
if callback:
|
||||
callback(i)
|
||||
|
||||
out = {'x_encoded': x_next, 'intermediate_steps': inter_steps}
|
||||
if return_intermediates:
|
||||
out.update({'intermediates': intermediates})
|
||||
return x_next, out
|
||||
|
||||
@torch.no_grad()
|
||||
def stochastic_encode(self, x0, t, use_original_steps=False, noise=None):
|
||||
# fast, but does not allow for exact reconstruction
|
||||
# t serves as an index to gather the correct alphas
|
||||
if use_original_steps:
|
||||
sqrt_alphas_cumprod = self.sqrt_alphas_cumprod
|
||||
sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod
|
||||
else:
|
||||
sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas)
|
||||
sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas
|
||||
|
||||
if noise is None:
|
||||
noise = torch.randn_like(x0)
|
||||
return (
|
||||
extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0
|
||||
+ extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape)
|
||||
* noise)
|
||||
|
||||
@torch.no_grad()
|
||||
def decode(self,
|
||||
x_latent,
|
||||
cond,
|
||||
t_start,
|
||||
unconditional_guidance_scale=1.0,
|
||||
unconditional_conditioning=None,
|
||||
use_original_steps=False,
|
||||
callback=None):
|
||||
|
||||
timesteps = np.arange(self.ddpm_num_timesteps
|
||||
) if use_original_steps else self.ddim_timesteps
|
||||
timesteps = timesteps[:t_start]
|
||||
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
print(f'Running DDIM Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='Decoding image', total=total_steps)
|
||||
x_dec = x_latent
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((x_latent.shape[0], ),
|
||||
step,
|
||||
device=x_latent.device,
|
||||
dtype=torch.long)
|
||||
x_dec, _ = self.p_sample_ddim(
|
||||
x_dec,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=use_original_steps,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning)
|
||||
if callback:
|
||||
callback(i)
|
||||
return x_dec
|
||||
2295
modelscope/models/cv/anydoor/ldm/models/diffusion/ddpm.py
Normal file
2295
modelscope/models/cv/anydoor/ldm/models/diffusion/ddpm.py
Normal file
File diff suppressed because it is too large
Load Diff
328
modelscope/models/cv/anydoor/ldm/models/diffusion/plms.py
Normal file
328
modelscope/models/cv/anydoor/ldm/models/diffusion/plms.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""SAMPLING ONLY."""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from ....ldm.models.diffusion.sampling_util import norm_thresholding
|
||||
from ....ldm.modules.diffusionmodules.util import (
|
||||
make_ddim_sampling_parameters, make_ddim_timesteps, noise_like)
|
||||
|
||||
|
||||
class PLMSSampler(object):
|
||||
|
||||
def __init__(self, model, schedule='linear', **kwargs):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.schedule = schedule
|
||||
|
||||
def register_buffer(self, name, attr):
|
||||
if type(attr) == torch.Tensor:
|
||||
if attr.device != torch.device('cuda'):
|
||||
attr = attr.to(torch.device('cuda'))
|
||||
setattr(self, name, attr)
|
||||
|
||||
def make_schedule(self,
|
||||
ddim_num_steps,
|
||||
ddim_discretize='uniform',
|
||||
ddim_eta=0.,
|
||||
verbose=True):
|
||||
if ddim_eta != 0:
|
||||
raise ValueError('ddim_eta must be 0 for PLMS')
|
||||
self.ddim_timesteps = make_ddim_timesteps(
|
||||
ddim_discr_method=ddim_discretize,
|
||||
num_ddim_timesteps=ddim_num_steps,
|
||||
num_ddpm_timesteps=self.ddpm_num_timesteps,
|
||||
verbose=verbose)
|
||||
alphas_cumprod = self.model.alphas_cumprod
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
def to_torch(x):
|
||||
return x.clone().detach().to(torch.float32).to(self.model.device)
|
||||
|
||||
self.register_buffer('betas', to_torch(self.model.betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev',
|
||||
to_torch(self.model.alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
|
||||
|
||||
# ddim sampling parameters
|
||||
ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(
|
||||
alphacums=alphas_cumprod.cpu(),
|
||||
ddim_timesteps=self.ddim_timesteps,
|
||||
eta=ddim_eta,
|
||||
verbose=verbose)
|
||||
self.register_buffer('ddim_sigmas', ddim_sigmas)
|
||||
self.register_buffer('ddim_alphas', ddim_alphas)
|
||||
self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
|
||||
self.register_buffer('ddim_sqrt_one_minus_alphas',
|
||||
np.sqrt(1. - ddim_alphas))
|
||||
tmp1 = (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod)
|
||||
tmp2 = (1 - self.alphas_cumprod / self.alphas_cumprod_prev)
|
||||
sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(tmp1 * tmp2)
|
||||
self.register_buffer('ddim_sigmas_for_original_num_steps',
|
||||
sigmas_for_original_sampling_steps)
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(
|
||||
self,
|
||||
S,
|
||||
batch_size,
|
||||
shape,
|
||||
conditioning=None,
|
||||
callback=None,
|
||||
normals_sequence=None,
|
||||
img_callback=None,
|
||||
quantize_x0=False,
|
||||
eta=0.,
|
||||
mask=None,
|
||||
x0=None,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
verbose=True,
|
||||
x_T=None,
|
||||
log_every_t=100,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
# this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
|
||||
dynamic_threshold=None,
|
||||
**kwargs):
|
||||
if conditioning is not None:
|
||||
if isinstance(conditioning, dict):
|
||||
cbs = conditioning[list(conditioning.keys())[0]].shape[0]
|
||||
if cbs != batch_size:
|
||||
print(
|
||||
f'Warning: Got {cbs} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
else:
|
||||
if conditioning.shape[0] != batch_size:
|
||||
print(
|
||||
f'Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}'
|
||||
)
|
||||
|
||||
self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
|
||||
# sampling
|
||||
C, H, W = shape
|
||||
size = (batch_size, C, H, W)
|
||||
print(f'Data shape for PLMS sampling is {size}')
|
||||
|
||||
samples, intermediates = self.plms_sampling(
|
||||
conditioning,
|
||||
size,
|
||||
callback=callback,
|
||||
img_callback=img_callback,
|
||||
quantize_denoised=quantize_x0,
|
||||
mask=mask,
|
||||
x0=x0,
|
||||
ddim_use_original_steps=False,
|
||||
noise_dropout=noise_dropout,
|
||||
temperature=temperature,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
x_T=x_T,
|
||||
log_every_t=log_every_t,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
dynamic_threshold=dynamic_threshold,
|
||||
)
|
||||
return samples, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def plms_sampling(self,
|
||||
cond,
|
||||
shape,
|
||||
x_T=None,
|
||||
ddim_use_original_steps=False,
|
||||
callback=None,
|
||||
timesteps=None,
|
||||
quantize_denoised=False,
|
||||
mask=None,
|
||||
x0=None,
|
||||
img_callback=None,
|
||||
log_every_t=100,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
dynamic_threshold=None):
|
||||
device = self.model.betas.device
|
||||
b = shape[0]
|
||||
if x_T is None:
|
||||
img = torch.randn(shape, device=device)
|
||||
else:
|
||||
img = x_T
|
||||
|
||||
if timesteps is None:
|
||||
timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
|
||||
elif timesteps is not None and not ddim_use_original_steps:
|
||||
subset_end = int(
|
||||
min(timesteps / self.ddim_timesteps.shape[0], 1)
|
||||
* self.ddim_timesteps.shape[0]) - 1
|
||||
timesteps = self.ddim_timesteps[:subset_end]
|
||||
|
||||
intermediates = {'x_inter': [img], 'pred_x0': [img]}
|
||||
time_range = list(reversed(range(
|
||||
0, timesteps))) if ddim_use_original_steps else np.flip(timesteps)
|
||||
total_steps = timesteps if ddim_use_original_steps else timesteps.shape[
|
||||
0]
|
||||
print(f'Running PLMS Sampling with {total_steps} timesteps')
|
||||
|
||||
iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps)
|
||||
old_eps = []
|
||||
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1
|
||||
ts = torch.full((b, ), step, device=device, dtype=torch.long)
|
||||
ts_next = torch.full((b, ),
|
||||
time_range[min(i + 1,
|
||||
len(time_range) - 1)],
|
||||
device=device,
|
||||
dtype=torch.long)
|
||||
|
||||
if mask is not None:
|
||||
assert x0 is not None
|
||||
img_orig = self.model.q_sample(
|
||||
x0, ts) # TODO: deterministic forward pass?
|
||||
img = img_orig * mask + (1. - mask) * img
|
||||
|
||||
outs = self.p_sample_plms(
|
||||
img,
|
||||
cond,
|
||||
ts,
|
||||
index=index,
|
||||
use_original_steps=ddim_use_original_steps,
|
||||
quantize_denoised=quantize_denoised,
|
||||
temperature=temperature,
|
||||
noise_dropout=noise_dropout,
|
||||
score_corrector=score_corrector,
|
||||
corrector_kwargs=corrector_kwargs,
|
||||
unconditional_guidance_scale=unconditional_guidance_scale,
|
||||
unconditional_conditioning=unconditional_conditioning,
|
||||
old_eps=old_eps,
|
||||
t_next=ts_next,
|
||||
dynamic_threshold=dynamic_threshold)
|
||||
img, pred_x0, e_t = outs
|
||||
old_eps.append(e_t)
|
||||
if len(old_eps) >= 4:
|
||||
old_eps.pop(0)
|
||||
if callback:
|
||||
callback(i)
|
||||
if img_callback:
|
||||
img_callback(pred_x0, i)
|
||||
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(img)
|
||||
intermediates['pred_x0'].append(pred_x0)
|
||||
|
||||
return img, intermediates
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_plms(self,
|
||||
x,
|
||||
c,
|
||||
t,
|
||||
index,
|
||||
repeat_noise=False,
|
||||
use_original_steps=False,
|
||||
quantize_denoised=False,
|
||||
temperature=1.,
|
||||
noise_dropout=0.,
|
||||
score_corrector=None,
|
||||
corrector_kwargs=None,
|
||||
unconditional_guidance_scale=1.,
|
||||
unconditional_conditioning=None,
|
||||
old_eps=None,
|
||||
t_next=None,
|
||||
dynamic_threshold=None):
|
||||
b, *_, device = *x.shape, x.device
|
||||
|
||||
def get_model_output(x, t):
|
||||
if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
|
||||
e_t = self.model.apply_model(x, t, c)
|
||||
else:
|
||||
x_in = torch.cat([x] * 2)
|
||||
t_in = torch.cat([t] * 2)
|
||||
c_in = torch.cat([unconditional_conditioning, c])
|
||||
e_t_uncond, e_t = self.model.apply_model(x_in, t_in,
|
||||
c_in).chunk(2)
|
||||
e_t = e_t_uncond + unconditional_guidance_scale * (
|
||||
e_t - e_t_uncond)
|
||||
|
||||
if score_corrector is not None:
|
||||
assert self.model.parameterization == 'eps'
|
||||
e_t = score_corrector.modify_score(self.model, e_t, x, t, c,
|
||||
**corrector_kwargs)
|
||||
|
||||
return e_t
|
||||
|
||||
alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
|
||||
alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
|
||||
sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod \
|
||||
if use_original_steps else self.ddim_sqrt_one_minus_alphas
|
||||
sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
|
||||
|
||||
def get_x_prev_and_pred_x0(e_t, index):
|
||||
# select parameters corresponding to the currently considered timestep
|
||||
a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
|
||||
a_prev = torch.full((b, 1, 1, 1),
|
||||
alphas_prev[index],
|
||||
device=device)
|
||||
sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
|
||||
sqrt_one_minus_at = torch.full((b, 1, 1, 1),
|
||||
sqrt_one_minus_alphas[index],
|
||||
device=device)
|
||||
|
||||
# current prediction for x_0
|
||||
pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
|
||||
if quantize_denoised:
|
||||
pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
|
||||
if dynamic_threshold is not None:
|
||||
pred_x0 = norm_thresholding(pred_x0, dynamic_threshold)
|
||||
# direction pointing to x_t
|
||||
dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
|
||||
noise = sigma_t * noise_like(x.shape, device,
|
||||
repeat_noise) * temperature
|
||||
if noise_dropout > 0.:
|
||||
noise = torch.nn.functional.dropout(noise, p=noise_dropout)
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
|
||||
return x_prev, pred_x0
|
||||
|
||||
e_t = get_model_output(x, t)
|
||||
if len(old_eps) == 0:
|
||||
# Pseudo Improved Euler (2nd order)
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index)
|
||||
e_t_next = get_model_output(x_prev, t_next)
|
||||
e_t_prime = (e_t + e_t_next) / 2
|
||||
elif len(old_eps) == 1:
|
||||
# 2nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (3 * e_t - old_eps[-1]) / 2
|
||||
elif len(old_eps) == 2:
|
||||
# 3nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12
|
||||
elif len(old_eps) >= 3:
|
||||
# 4nd order Pseudo Linear Multistep (Adams-Bashforth)
|
||||
e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2]
|
||||
- 9 * old_eps[-3]) / 24
|
||||
|
||||
x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index)
|
||||
|
||||
return x_prev, pred_x0, e_t
|
||||
@@ -0,0 +1,25 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def append_dims(x, target_dims):
|
||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions.
|
||||
From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py"""
|
||||
dims_to_append = target_dims - x.ndim
|
||||
if dims_to_append < 0:
|
||||
raise ValueError(
|
||||
f'input has {x.ndim} dims but target_dims is {target_dims}, which is less'
|
||||
)
|
||||
return x[(..., ) + (None, ) * dims_to_append]
|
||||
|
||||
|
||||
def norm_thresholding(x0, value):
|
||||
s = append_dims(
|
||||
x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim)
|
||||
return x0 * (value / s)
|
||||
|
||||
|
||||
def spatial_norm_thresholding(x0, value):
|
||||
# b c h w
|
||||
s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value)
|
||||
return x0 * (value / s)
|
||||
367
modelscope/models/cv/anydoor/ldm/modules/attention.py
Normal file
367
modelscope/models/cv/anydoor/ldm/modules/attention.py
Normal file
@@ -0,0 +1,367 @@
|
||||
import math
|
||||
# CrossAttn precision handling
|
||||
import os
|
||||
from inspect import isfunction
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from torch import einsum, nn
|
||||
|
||||
from ...ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except Exception:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
|
||||
_ATTN_PRECISION = os.environ.get('ATTN_PRECISION', 'fp32')
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return {el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(nn.Linear(
|
||||
dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(project_in, nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h),
|
||||
(q, k, v))
|
||||
|
||||
# force cast to fp32 to avoid overflowing
|
||||
if _ATTN_PRECISION == 'fp32':
|
||||
with torch.autocast(enabled=False, device_type='cuda'):
|
||||
q, k = q.float(), k.float()
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
else:
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
del q, k
|
||||
|
||||
if exists(mask):
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
sim = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', sim, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttention(nn.Module):
|
||||
# https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
|
||||
def __init__(self,
|
||||
query_dim,
|
||||
context_dim=None,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
print(
|
||||
f'Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using '
|
||||
f'{heads} heads.')
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
b, _, _ = q.shape
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3).reshape(b, t.shape[
|
||||
1], self.heads, self.dim_head).permute(0, 2, 1, 3).reshape(
|
||||
b * self.heads, t.shape[1], self.dim_head).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
|
||||
# actually compute the attention, what we cannot get enough of
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
if exists(mask):
|
||||
raise NotImplementedError
|
||||
out = (
|
||||
out.unsqueeze(0).reshape(
|
||||
b, self.heads, out.shape[1],
|
||||
self.dim_head).permute(0, 2, 1,
|
||||
3).reshape(b, out.shape[1],
|
||||
self.heads * self.dim_head))
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
ATTENTION_MODES = {
|
||||
'softmax': CrossAttention, # vanilla attention
|
||||
'softmax-xformers': MemoryEfficientCrossAttention
|
||||
}
|
||||
|
||||
def __init__(self,
|
||||
dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
gated_ff=True,
|
||||
checkpoint=True,
|
||||
disable_self_attn=False):
|
||||
super().__init__()
|
||||
attn_mode = 'softmax-xformers' if XFORMERS_IS_AVAILBLE else 'softmax'
|
||||
assert attn_mode in self.ATTENTION_MODES
|
||||
attn_cls = self.ATTENTION_MODES[attn_mode]
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = attn_cls(
|
||||
query_dim=dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else
|
||||
None) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = attn_cls(
|
||||
query_dim=dim,
|
||||
context_dim=context_dim,
|
||||
heads=n_heads,
|
||||
dim_head=d_head,
|
||||
dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(),
|
||||
self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(
|
||||
self.norm1(x),
|
||||
context=context if self.disable_self_attn else None) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
NEW: use_linear for more efficiency instead of the 1x1 convs
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
n_heads,
|
||||
d_head,
|
||||
depth=1,
|
||||
dropout=0.,
|
||||
context_dim=None,
|
||||
disable_self_attn=False,
|
||||
use_linear=False,
|
||||
use_checkpoint=True):
|
||||
super().__init__()
|
||||
if exists(context_dim) and not isinstance(context_dim, list):
|
||||
context_dim = [context_dim]
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
if not use_linear:
|
||||
self.proj_in = nn.Conv2d(
|
||||
in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.proj_in = nn.Linear(in_channels, inner_dim)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList([
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
n_heads,
|
||||
d_head,
|
||||
dropout=dropout,
|
||||
context_dim=context_dim[d],
|
||||
disable_self_attn=disable_self_attn,
|
||||
checkpoint=use_checkpoint) for d in range(depth)
|
||||
])
|
||||
if not use_linear:
|
||||
self.proj_out = zero_module(
|
||||
nn.Conv2d(
|
||||
inner_dim, in_channels, kernel_size=1, stride=1,
|
||||
padding=0))
|
||||
else:
|
||||
self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
|
||||
self.use_linear = use_linear
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
if not isinstance(context, list):
|
||||
context = [context]
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
if not self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
if self.use_linear:
|
||||
x = self.proj_in(x)
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
x = block(x, context=context[i])
|
||||
if self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
if not self.use_linear:
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
@@ -0,0 +1,966 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
from typing import Any, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import rearrange
|
||||
|
||||
from ....ldm.modules.attention import MemoryEfficientCrossAttention
|
||||
|
||||
try:
|
||||
import xformers
|
||||
import xformers.ops
|
||||
XFORMERS_IS_AVAILBLE = True
|
||||
except Exception:
|
||||
XFORMERS_IS_AVAILBLE = False
|
||||
print("No module 'xformers'. Proceeding without it.")
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(
|
||||
num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(
|
||||
x, scale_factor=2.0, mode='nearest')
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
in_channels,
|
||||
out_channels=None,
|
||||
conv_shortcut=False,
|
||||
dropout,
|
||||
temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h * w)
|
||||
q = q.permute(0, 2, 1) # b,hw,c
|
||||
k = k.reshape(b, c, h * w) # b,c,hw
|
||||
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h * w)
|
||||
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(
|
||||
v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x + h_
|
||||
|
||||
|
||||
class MemoryEfficientAttnBlock(nn.Module):
|
||||
"""
|
||||
Uses xformers efficient implementation,
|
||||
Note: this is a single-head self-attention operation
|
||||
"""
|
||||
|
||||
#
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
self.attention_op: Optional[Any] = None
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
B, C, H, W = q.shape
|
||||
q, k, v = map(lambda x: rearrange(x, 'b c h w -> b (h w) c'),
|
||||
(q, k, v))
|
||||
|
||||
q, k, v = map(
|
||||
lambda t: t.unsqueeze(3).reshape(B, t.shape[1], 1, C).permute(
|
||||
0, 2, 1, 3).reshape(B * 1, t.shape[1], C).contiguous(),
|
||||
(q, k, v),
|
||||
)
|
||||
out = xformers.ops.memory_efficient_attention(
|
||||
q, k, v, attn_bias=None, op=self.attention_op)
|
||||
|
||||
out = (
|
||||
out.unsqueeze(0).reshape(B, 1, out.shape[1],
|
||||
C).permute(0, 2, 1,
|
||||
3).reshape(B, out.shape[1], C))
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', b=B, h=H, w=W, c=C)
|
||||
out = self.proj_out(out)
|
||||
return x + out
|
||||
|
||||
|
||||
class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention):
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
b, c, h, w = x.shape
|
||||
x = rearrange(x, 'b c h w -> b (h w) c')
|
||||
out = super().forward(x, context=context, mask=mask)
|
||||
out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c)
|
||||
return x + out
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type='vanilla', attn_kwargs=None):
|
||||
assert attn_type in [
|
||||
'vanilla', 'vanilla-xformers', 'memory-efficient-cross-attn', 'linear',
|
||||
'none'
|
||||
], f'attn_type {attn_type} unknown'
|
||||
if XFORMERS_IS_AVAILBLE and attn_type == 'vanilla':
|
||||
attn_type = 'vanilla-xformers'
|
||||
print(
|
||||
f"making attention of type '{attn_type}' with {in_channels} in_channels"
|
||||
)
|
||||
if attn_type == 'vanilla':
|
||||
assert attn_kwargs is None
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == 'vanilla-xformers':
|
||||
print(
|
||||
f'building MemoryEfficientAttnBlock with {in_channels} in_channels...'
|
||||
)
|
||||
return MemoryEfficientAttnBlock(in_channels)
|
||||
elif type == 'memory-efficient-cross-attn':
|
||||
attn_kwargs['query_dim'] = in_channels
|
||||
return MemoryEfficientCrossAttentionWrapper(**attn_kwargs)
|
||||
elif attn_type == 'none':
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
use_timestep=True,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla'):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = 'linear'
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch * 4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch, self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch, self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
skip_in = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch * in_ch_mult[i_level]
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in + skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
# assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](torch.cat([h, hs.pop()],
|
||||
dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
double_z=True,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = 'linear'
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels, self.ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1, ) + tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch * in_ch_mult[i_level]
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions - 1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in,
|
||||
2 * z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
ch,
|
||||
out_ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
in_channels,
|
||||
resolution,
|
||||
z_channels,
|
||||
give_pre_end=False,
|
||||
tanh_out=False,
|
||||
use_linear_attn=False,
|
||||
attn_type='vanilla',
|
||||
**ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn:
|
||||
attn_type = 'linear'
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
block_in = ch * ch_mult[self.num_resolutions - 1]
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.z_shape = (1, z_channels, curr_res, curr_res)
|
||||
print('Working with z of shape {} = {} dimensions.'.format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
z_channels, block_in, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_ch, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
# assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList([
|
||||
nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(
|
||||
in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0),
|
||||
ResnetBlock(
|
||||
in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0),
|
||||
ResnetBlock(
|
||||
in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0),
|
||||
nn.Conv2d(2 * in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True)
|
||||
])
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
if i in [1, 2, 3]:
|
||||
x = layer(x, None)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
ch,
|
||||
num_res_blocks,
|
||||
resolution,
|
||||
ch_mult=(2, 2),
|
||||
dropout=0.0):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = in_channels
|
||||
curr_res = resolution // 2**(self.num_resolutions - 1)
|
||||
self.res_blocks = nn.ModuleList()
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(
|
||||
ResnetBlock(
|
||||
in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
self.res_blocks.append(nn.ModuleList(res_block))
|
||||
if i_level != self.num_resolutions - 1:
|
||||
self.upsample_blocks.append(Upsample(block_in, True))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(
|
||||
block_in, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
h = x
|
||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.res_blocks[i_level][i_block](h, None)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class LatentRescaler(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
factor,
|
||||
in_channels,
|
||||
mid_channels,
|
||||
out_channels,
|
||||
depth=2):
|
||||
super().__init__()
|
||||
# residual block, interpolate, residual block
|
||||
self.factor = factor
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.res_block1 = nn.ModuleList([
|
||||
ResnetBlock(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)
|
||||
])
|
||||
self.attn = AttnBlock(mid_channels)
|
||||
self.res_block2 = nn.ModuleList([
|
||||
ResnetBlock(
|
||||
in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)
|
||||
])
|
||||
|
||||
self.conv_out = nn.Conv2d(
|
||||
mid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
for block in self.res_block1:
|
||||
x = block(x, None)
|
||||
x = torch.nn.functional.interpolate(
|
||||
x,
|
||||
size=(int(round(x.shape[2] * self.factor)),
|
||||
int(round(x.shape[3] * self.factor))))
|
||||
x = self.attn(x)
|
||||
for block in self.res_block2:
|
||||
x = block(x, None)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
ch,
|
||||
resolution,
|
||||
out_ch,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
rescale_factor=1.0,
|
||||
rescale_module_depth=1):
|
||||
super().__init__()
|
||||
intermediate_chn = ch * ch_mult[-1]
|
||||
self.encoder = Encoder(
|
||||
in_channels=in_channels,
|
||||
num_res_blocks=num_res_blocks,
|
||||
ch=ch,
|
||||
ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn,
|
||||
double_z=False,
|
||||
resolution=resolution,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
out_ch=None)
|
||||
self.rescaler = LatentRescaler(
|
||||
factor=rescale_factor,
|
||||
in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn,
|
||||
out_channels=out_ch,
|
||||
depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.rescaler(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
z_channels,
|
||||
out_ch,
|
||||
resolution,
|
||||
num_res_blocks,
|
||||
attn_resolutions,
|
||||
ch,
|
||||
ch_mult=(1, 2, 4, 8),
|
||||
dropout=0.0,
|
||||
resamp_with_conv=True,
|
||||
rescale_factor=1.0,
|
||||
rescale_module_depth=1):
|
||||
super().__init__()
|
||||
tmp_chn = z_channels * ch_mult[-1]
|
||||
self.decoder = Decoder(
|
||||
out_ch=out_ch,
|
||||
z_channels=tmp_chn,
|
||||
attn_resolutions=attn_resolutions,
|
||||
dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv,
|
||||
in_channels=None,
|
||||
num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult,
|
||||
resolution=resolution,
|
||||
ch=ch)
|
||||
self.rescaler = LatentRescaler(
|
||||
factor=rescale_factor,
|
||||
in_channels=z_channels,
|
||||
mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn,
|
||||
depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_size,
|
||||
out_size,
|
||||
in_channels,
|
||||
out_channels,
|
||||
ch_mult=2):
|
||||
super().__init__()
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size // in_size)) + 1
|
||||
factor_up = 1. + (out_size % in_size)
|
||||
print(
|
||||
f'Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}'
|
||||
)
|
||||
self.rescaler = LatentRescaler(
|
||||
factor=factor_up,
|
||||
in_channels=in_channels,
|
||||
mid_channels=2 * in_channels,
|
||||
out_channels=in_channels)
|
||||
self.decoder = Decoder(
|
||||
out_ch=out_channels,
|
||||
resolution=out_size,
|
||||
z_channels=in_channels,
|
||||
num_res_blocks=2,
|
||||
attn_resolutions=[],
|
||||
in_channels=None,
|
||||
ch=in_channels,
|
||||
ch_mult=[ch_mult for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Resize(nn.Module):
|
||||
|
||||
def __init__(self, in_channels=None, learned=False, mode='bilinear'):
|
||||
super().__init__()
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
print(
|
||||
f'Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode'
|
||||
)
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(
|
||||
in_channels, in_channels, kernel_size=4, stride=2, padding=1)
|
||||
|
||||
def forward(self, x, scale_factor=1.0):
|
||||
if scale_factor == 1.0:
|
||||
return x
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(
|
||||
x,
|
||||
mode=self.mode,
|
||||
align_corners=False,
|
||||
scale_factor=scale_factor)
|
||||
return x
|
||||
@@ -0,0 +1,820 @@
|
||||
import math
|
||||
from abc import abstractmethod
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ....ldm.modules.attention import SpatialTransformer
|
||||
from ....ldm.modules.diffusionmodules.util import (avg_pool_nd, checkpoint,
|
||||
conv_nd, linear,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
zero_module)
|
||||
from ....ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
|
||||
pass
|
||||
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
pass
|
||||
|
||||
|
||||
# go
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(
|
||||
th.randn(embed_dim, spacial_dim**2 + 1) / embed_dim**0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
use_conv,
|
||||
dims=2,
|
||||
out_channels=None,
|
||||
padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2),
|
||||
mode='nearest')
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode='nearest')
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(
|
||||
self.channels, self.out_channels, kernel_size=ks, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
use_conv,
|
||||
dims=2,
|
||||
out_channels=None,
|
||||
padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims,
|
||||
self.channels,
|
||||
self.out_channels,
|
||||
3,
|
||||
stride=stride,
|
||||
padding=padding)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels
|
||||
if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(
|
||||
dims, self.out_channels, self.out_channels, 3, padding=1)),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels,
|
||||
1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
return checkpoint(self._forward, (x, emb), self.parameters(),
|
||||
self.use_checkpoint)
|
||||
|
||||
def _forward(self, x, emb):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm:
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f'q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}'
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
if use_new_attention_order:
|
||||
# split qkv before split heads
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
else:
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(
|
||||
self._forward, (x, ), self.parameters(), True
|
||||
) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
# return pt_checkpoint(self._forward, x) # pytorch
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial**2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(
|
||||
ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
'bct,bcs->bts', q * scale,
|
||||
k * scale) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum('bts,bcs->bct', weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
'bct,bcs->bts',
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum('bts,bcs->bct', weight,
|
||||
v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None,
|
||||
disable_middle_self_attn=False,
|
||||
use_linear_in_transformer=False,
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError(
|
||||
'provide num_res_blocks either as an int (globally constant) or '
|
||||
'as a list/tuple (per-level) with the same length as channel_mult'
|
||||
)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(
|
||||
map(
|
||||
lambda i: self.num_res_blocks[i] >= num_attention_blocks[i
|
||||
],
|
||||
range(len(num_attention_blocks))))
|
||||
print(
|
||||
f'Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. '
|
||||
f'This option has LESS priority than attention_resolutions {attention_resolutions}, '
|
||||
f'i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, '
|
||||
f'attention will still not be set.')
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
if isinstance(self.num_classes, int):
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
elif self.num_classes == 'continuous':
|
||||
print('setting up linear c_adm embedding layer')
|
||||
self.label_emb = nn.Linear(1, time_embed_dim)
|
||||
else:
|
||||
raise ValueError()
|
||||
|
||||
self.input_blocks = nn.ModuleList([
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1))
|
||||
])
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks
|
||||
) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint))
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
) if resblock_updown else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer( # always uses a self-attn
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disable_middle_self_attn,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
# num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks
|
||||
) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else
|
||||
SpatialTransformer(
|
||||
ch,
|
||||
num_heads,
|
||||
dim_head,
|
||||
depth=transformer_depth,
|
||||
context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa,
|
||||
use_linear=use_linear_in_transformer,
|
||||
use_checkpoint=use_checkpoint))
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
) if resblock_updown else Upsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch))
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(
|
||||
conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
# nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
self.output_blocks.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), 'must specify y if and only if the model is class-conditional'
|
||||
hs = []
|
||||
t_emb = timestep_embedding(
|
||||
timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
@@ -0,0 +1,103 @@
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ....ldm.modules.diffusionmodules.util import (extract_into_tensor,
|
||||
make_beta_schedule)
|
||||
from ....ldm.util import default
|
||||
|
||||
|
||||
class AbstractLowScaleModel(nn.Module):
|
||||
# for concatenating a downsampled image to the latent representation
|
||||
def __init__(self, noise_schedule_config=None):
|
||||
super(AbstractLowScaleModel, self).__init__()
|
||||
if noise_schedule_config is not None:
|
||||
self.register_schedule(**noise_schedule_config)
|
||||
|
||||
def register_schedule(self,
|
||||
beta_schedule='linear',
|
||||
timesteps=1000,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
cosine_s=8e-3):
|
||||
betas = make_beta_schedule(
|
||||
beta_schedule,
|
||||
timesteps,
|
||||
linear_start=linear_start,
|
||||
linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[
|
||||
0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev',
|
||||
to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod',
|
||||
to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod',
|
||||
to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod',
|
||||
to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape)
|
||||
* x_start
|
||||
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t,
|
||||
x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
return x, None
|
||||
|
||||
def decode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class SimpleImageConcat(AbstractLowScaleModel):
|
||||
# no noise level conditioning
|
||||
def __init__(self):
|
||||
super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
|
||||
self.max_noise_level = 0
|
||||
|
||||
def forward(self, x):
|
||||
# fix to constant noise level
|
||||
return x, torch.zeros(x.shape[0], device=x.device).long()
|
||||
|
||||
|
||||
class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
|
||||
|
||||
def __init__(self,
|
||||
noise_schedule_config,
|
||||
max_noise_level=1000,
|
||||
to_cuda=False):
|
||||
super().__init__(noise_schedule_config=noise_schedule_config)
|
||||
self.max_noise_level = max_noise_level
|
||||
|
||||
def forward(self, x, noise_level=None):
|
||||
if noise_level is None:
|
||||
noise_level = torch.randint(
|
||||
0, self.max_noise_level, (x.shape[0], ),
|
||||
device=x.device).long()
|
||||
else:
|
||||
assert isinstance(noise_level, torch.Tensor)
|
||||
z = self.q_sample(x, noise_level)
|
||||
return z, noise_level
|
||||
@@ -0,0 +1,310 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
import math
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from einops import repeat
|
||||
|
||||
from ....ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule,
|
||||
n_timestep,
|
||||
linear_start=1e-4,
|
||||
linear_end=2e-2,
|
||||
cosine_s=8e-3):
|
||||
if schedule == 'linear':
|
||||
betas = (
|
||||
torch.linspace(
|
||||
linear_start**0.5,
|
||||
linear_end**0.5,
|
||||
n_timestep,
|
||||
dtype=torch.float64)**2)
|
||||
|
||||
elif schedule == 'cosine':
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep
|
||||
+ cosine_s)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == 'sqrt_linear':
|
||||
betas = torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == 'sqrt':
|
||||
betas = torch.linspace(
|
||||
linear_start, linear_end, n_timestep, dtype=torch.float64)**0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method,
|
||||
num_ddim_timesteps,
|
||||
num_ddpm_timesteps,
|
||||
verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8),
|
||||
num_ddim_timesteps))**2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f'There is no ddim discretization method called "{ddim_discr_method}"'
|
||||
)
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums,
|
||||
ddim_timesteps,
|
||||
eta,
|
||||
verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]]
|
||||
+ alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
tmp = (1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)
|
||||
sigmas = eta * np.sqrt(tmp)
|
||||
if verbose:
|
||||
print(
|
||||
f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}'
|
||||
)
|
||||
print(
|
||||
f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}'
|
||||
)
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
ctx.gpu_autocast_kwargs = {
|
||||
'enabled': torch.is_autocast_enabled(),
|
||||
'dtype': torch.get_autocast_gpu_dtype(),
|
||||
'cache_enabled': torch.is_autocast_cache_enabled()
|
||||
}
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [
|
||||
x.detach().requires_grad_(True) for x in ctx.input_tensors
|
||||
]
|
||||
with torch.enable_grad(), \
|
||||
torch.cuda.amp.autocast(**ctx.gpu_autocast_kwargs):
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period)
|
||||
* torch.arange(start=0, end=half, dtype=torch.float32)
|
||||
/ half).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat(
|
||||
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f'unsupported dimensions: {dims}')
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(
|
||||
c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
|
||||
def repeat_noise():
|
||||
torch.randn(
|
||||
(1, *shape[1:]), device=device).repeat(shape[0],
|
||||
*((1, ) * (len(shape) - 1)))
|
||||
|
||||
noise = lambda: torch.randn(shape, device=device) # noqa
|
||||
return repeat_noise() if repeat else noise()
|
||||
@@ -0,0 +1,93 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(
|
||||
self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(
|
||||
self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1, 2, 3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar
|
||||
+ torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, 'at least one argument must be a Tensor'
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
tmp = ((mean1 - mean2)**2) * torch.exp(-logvar2)
|
||||
return 0.5 * (-1.0 + logvar2 - logvar1 + torch.exp(logvar1 - logvar2)
|
||||
+ tmp)
|
||||
87
modelscope/models/cv/anydoor/ldm/modules/ema.py
Normal file
87
modelscope/models/cv/anydoor/ldm/modules/ema.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LitEma(nn.Module):
|
||||
|
||||
def __init__(self, model, decay=0.9999, use_num_upates=True):
|
||||
super().__init__()
|
||||
if decay < 0.0 or decay > 1.0:
|
||||
raise ValueError('Decay must be between 0 and 1')
|
||||
|
||||
self.m_name2s_name = {}
|
||||
self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
|
||||
self.register_buffer(
|
||||
'num_updates',
|
||||
torch.tensor(0, dtype=torch.int)
|
||||
if use_num_upates else torch.tensor(-1, dtype=torch.int))
|
||||
|
||||
for name, p in model.named_parameters():
|
||||
if p.requires_grad:
|
||||
# remove as '.'-character is not allowed in buffers
|
||||
s_name = name.replace('.', '')
|
||||
self.m_name2s_name.update({name: s_name})
|
||||
self.register_buffer(s_name, p.clone().detach().data)
|
||||
|
||||
self.collected_params = []
|
||||
|
||||
def reset_num_updates(self):
|
||||
del self.num_updates
|
||||
self.register_buffer('num_updates', torch.tensor(0, dtype=torch.int))
|
||||
|
||||
def forward(self, model):
|
||||
decay = self.decay
|
||||
|
||||
if self.num_updates >= 0:
|
||||
self.num_updates += 1
|
||||
tmp = (1 + self.num_updates) / (10 + self.num_updates)
|
||||
decay = min(self.decay, tmp)
|
||||
|
||||
one_minus_decay = 1.0 - decay
|
||||
|
||||
with torch.no_grad():
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
sname = self.m_name2s_name[key]
|
||||
shadow_params[sname] = shadow_params[sname].type_as(
|
||||
m_param[key])
|
||||
tmp = shadow_params[sname] - m_param[key]
|
||||
shadow_params[sname].sub_(one_minus_decay * tmp)
|
||||
else:
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def copy_to(self, model):
|
||||
m_param = dict(model.named_parameters())
|
||||
shadow_params = dict(self.named_buffers())
|
||||
for key in m_param:
|
||||
if m_param[key].requires_grad:
|
||||
m_param[key].data.copy_(
|
||||
shadow_params[self.m_name2s_name[key]].data)
|
||||
else:
|
||||
assert key not in self.m_name2s_name
|
||||
|
||||
def store(self, parameters):
|
||||
"""
|
||||
Save the current parameters for restoring later.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
temporarily stored.
|
||||
"""
|
||||
self.collected_params = [param.clone() for param in parameters]
|
||||
|
||||
def restore(self, parameters):
|
||||
"""
|
||||
Restore the parameters stored with the `store` method.
|
||||
Useful to validate the model with EMA parameters without affecting the
|
||||
original optimization process. Store the parameters before the
|
||||
`copy_to` method. After validation (or model saving), use this to
|
||||
restore the former parameters.
|
||||
Args:
|
||||
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
||||
updated with the stored parameters.
|
||||
"""
|
||||
for c_param, param in zip(self.collected_params, parameters):
|
||||
param.data.copy_(c_param.data)
|
||||
371
modelscope/models/cv/anydoor/ldm/modules/encoders/modules.py
Normal file
371
modelscope/models/cv/anydoor/ldm/modules/encoders/modules.py
Normal file
@@ -0,0 +1,371 @@
|
||||
import os
|
||||
|
||||
import open_clip
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
from transformers import (CLIPTextModel, CLIPTokenizer, T5EncoderModel,
|
||||
T5Tokenizer)
|
||||
|
||||
from ....dinov2 import hubconf
|
||||
from ....ldm.util import count_params
|
||||
|
||||
|
||||
class LayerNormFp32(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16 (by casting to float32 and back)."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
x = F.layer_norm(
|
||||
x.to(torch.float32), self.normalized_shape, self.weight, self.bias,
|
||||
self.eps)
|
||||
return x.to(orig_type)
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm (with cast back to input dtype)."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias,
|
||||
self.eps)
|
||||
return x.to(orig_type)
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
self.n_classes = n_classes
|
||||
self.ucg_rate = ucg_rate
|
||||
|
||||
def forward(self, batch, key=None, disable_dropout=False):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
if self.ucg_rate > 0. and not disable_dropout:
|
||||
mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
|
||||
c = mask * c + (1 - mask) * torch.ones_like(c) * (
|
||||
self.n_classes - 1)
|
||||
c = c.long()
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
def get_unconditional_conditioning(self, bs, device='cuda'):
|
||||
uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
|
||||
uc = torch.ones((bs, ), device=device) * uc_class
|
||||
uc = {self.key: uc}
|
||||
return uc
|
||||
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
|
||||
def __init__(self,
|
||||
version='google/t5-v1_1-large',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True
|
||||
): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version)
|
||||
self.transformer = T5EncoderModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
if freeze:
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt')
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
LAYERS = ['last', 'pooled', 'hidden']
|
||||
|
||||
def __init__(self,
|
||||
version='openai/clip-vit-large-patch14',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last',
|
||||
layer_idx=None): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version)
|
||||
self.transformer = CLIPTextModel.from_pretrained(version)
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
self.layer_idx = layer_idx
|
||||
if layer == 'hidden':
|
||||
assert layer_idx is not None
|
||||
assert 0 <= abs(layer_idx) <= 12
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
# self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(
|
||||
text,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_length=True,
|
||||
return_overflowing_tokens=False,
|
||||
padding='max_length',
|
||||
return_tensors='pt')
|
||||
tokens = batch_encoding['input_ids'].to(self.device)
|
||||
outputs = self.transformer(
|
||||
input_ids=tokens, output_hidden_states=self.layer == 'hidden')
|
||||
if self.layer == 'last':
|
||||
z = outputs.last_hidden_state
|
||||
elif self.layer == 'pooled':
|
||||
z = outputs.pooler_output[:, None, :]
|
||||
else:
|
||||
z = outputs.hidden_states[self.layer_idx]
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenOpenCLIPEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for text
|
||||
"""
|
||||
LAYERS = [
|
||||
# "pooled",
|
||||
'last',
|
||||
'penultimate'
|
||||
]
|
||||
|
||||
def __init__(self,
|
||||
arch='ViT-H-14',
|
||||
version='laion2b_s32b_b79k',
|
||||
device='cuda',
|
||||
max_length=77,
|
||||
freeze=True,
|
||||
layer='last'):
|
||||
super().__init__()
|
||||
assert layer in self.LAYERS
|
||||
model, _, _ = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.visual
|
||||
self.model = model
|
||||
|
||||
self.device = device
|
||||
self.max_length = max_length
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.layer = layer
|
||||
if self.layer == 'last':
|
||||
self.layer_idx = 0
|
||||
elif self.layer == 'penultimate':
|
||||
self.layer_idx = 1
|
||||
else:
|
||||
raise NotImplementedError()
|
||||
|
||||
def freeze(self):
|
||||
self.model = self.model.eval()
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
tokens = open_clip.tokenize(text)
|
||||
z = self.encode_with_transformer(tokens.to(self.device))
|
||||
return z
|
||||
|
||||
def encode_with_transformer(self, text):
|
||||
x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
|
||||
x = x + self.model.positional_embedding
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.model.ln_final(x)
|
||||
return x
|
||||
|
||||
def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
|
||||
for i, r in enumerate(self.model.transformer.resblocks):
|
||||
if i == len(self.model.transformer.resblocks) - self.layer_idx:
|
||||
break
|
||||
if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(
|
||||
):
|
||||
x = checkpoint(r, x, attn_mask)
|
||||
else:
|
||||
x = r(x, attn_mask=attn_mask)
|
||||
return x
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
|
||||
class FrozenCLIPT5Encoder(AbstractEncoder):
|
||||
|
||||
def __init__(self,
|
||||
clip_version='openai/clip-vit-large-patch14',
|
||||
t5_version='google/t5-v1_1-xl',
|
||||
device='cuda',
|
||||
clip_max_length=77,
|
||||
t5_max_length=77):
|
||||
super().__init__()
|
||||
self.clip_encoder = FrozenCLIPEmbedder(
|
||||
clip_version, device, max_length=clip_max_length)
|
||||
self.t5_encoder = FrozenT5Embedder(
|
||||
t5_version, device, max_length=t5_max_length)
|
||||
print(
|
||||
f'{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, '
|
||||
f'{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.'
|
||||
)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
def forward(self, text):
|
||||
clip_z = self.clip_encoder.encode(text)
|
||||
t5_z = self.t5_encoder.encode(text)
|
||||
return [clip_z, t5_z]
|
||||
|
||||
|
||||
class FrozenOpenCLIPImageEncoder(AbstractEncoder):
|
||||
"""
|
||||
Uses the OpenCLIP transformer encoder for image
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
arch='ViT-H-14',
|
||||
version='laion2b_s32b_b79k',
|
||||
device='cuda',
|
||||
freeze=True):
|
||||
super().__init__()
|
||||
model, _, preprocess = open_clip.create_model_and_transforms(
|
||||
arch, device=torch.device('cpu'), pretrained=version)
|
||||
del model.transformer
|
||||
self.model = model
|
||||
self.model.visual.output_tokens = True
|
||||
self.device = device
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.image_mean = torch.tensor(
|
||||
[0.48145466, 0.4578275,
|
||||
0.40821073]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
self.image_std = torch.tensor(
|
||||
[0.26862954, 0.26130258,
|
||||
0.275777]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
self.projector_token = nn.Linear(1280, 1024)
|
||||
self.projector_embed = nn.Linear(1024, 1024)
|
||||
|
||||
def freeze(self):
|
||||
self.model.visual.eval()
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image):
|
||||
if isinstance(image, list):
|
||||
image = torch.cat(image, 0)
|
||||
image = (image.to(self.device) - self.image_mean.to(
|
||||
self.device)) / self.image_std.to(self.device)
|
||||
image_features, tokens = self.model.visual(image)
|
||||
image_features = image_features.unsqueeze(1)
|
||||
image_features = self.projector_embed(image_features)
|
||||
tokens = self.projector_token(tokens)
|
||||
hint = torch.cat([image_features, tokens], 1)
|
||||
return hint
|
||||
|
||||
def encode(self, image):
|
||||
return self(image)
|
||||
|
||||
|
||||
class FrozenDinoV2Encoder(AbstractEncoder):
|
||||
"""
|
||||
Uses the DINOv2 encoder for image
|
||||
"""
|
||||
|
||||
def __init__(self, model_path, device='cuda', freeze=True):
|
||||
DINOv2_weight_path = model_path
|
||||
|
||||
super().__init__()
|
||||
dinov2 = hubconf.dinov2_vitg14()
|
||||
state_dict = torch.load(DINOv2_weight_path)
|
||||
dinov2.load_state_dict(state_dict, strict=False)
|
||||
self.model = dinov2.to(device)
|
||||
self.device = device
|
||||
if freeze:
|
||||
self.freeze()
|
||||
self.image_mean = torch.tensor(
|
||||
[0.485, 0.456, 0.406]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
self.image_std = torch.tensor(
|
||||
[0.229, 0.224, 0.225]).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
|
||||
self.projector = nn.Linear(1536, 1024)
|
||||
|
||||
def freeze(self):
|
||||
self.model.eval()
|
||||
for param in self.model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, image):
|
||||
if isinstance(image, list):
|
||||
image = torch.cat(image, 0)
|
||||
|
||||
image = (image.to(self.device) - self.image_mean.to(
|
||||
self.device)) / self.image_std.to(self.device)
|
||||
features = self.model.forward_features(image)
|
||||
tokens = features['x_norm_patchtokens']
|
||||
image_features = features['x_norm_clstoken']
|
||||
image_features = image_features.unsqueeze(1)
|
||||
hint = torch.cat([image_features, tokens], 1) # 8,257,1024
|
||||
hint = self.projector(hint)
|
||||
return hint
|
||||
|
||||
def encode(self, image):
|
||||
return self(image)
|
||||
221
modelscope/models/cv/anydoor/ldm/util.py
Normal file
221
modelscope/models/cv/anydoor/ldm/util.py
Normal file
@@ -0,0 +1,221 @@
|
||||
import importlib
|
||||
from inspect import isfunction
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from torch import optim
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new('RGB', wh, color='white')
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('font/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = '\n'.join(xc[bi][start:start + nc]
|
||||
for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill='black', font=font)
|
||||
except UnicodeEncodeError:
|
||||
print('Cant encode string for logging. Skipping.')
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(
|
||||
f'{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.'
|
||||
)
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if 'target' not in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == '__is_unconditional__':
|
||||
return None
|
||||
raise KeyError('Expected key `target` to instantiate.')
|
||||
return get_obj_from_str(config['target'])(**config.get('params', dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit('.', 1)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
|
||||
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
||||
def __init__(self,
|
||||
params,
|
||||
lr=1.e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1.e-8,
|
||||
weight_decay=1.e-2,
|
||||
amsgrad=False,
|
||||
ema_decay=0.9999,
|
||||
ema_power=1.,
|
||||
param_names=()):
|
||||
"""AdamW that saves EMA versions of the parameters."""
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError('Invalid learning rate: {}'.format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError('Invalid epsilon value: {}'.format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError('Invalid beta parameter at index 0: {}'.format(
|
||||
betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError('Invalid beta parameter at index 1: {}'.format(
|
||||
betas[1]))
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError(
|
||||
'Invalid weight_decay value: {}'.format(weight_decay))
|
||||
if not 0.0 <= ema_decay <= 1.0:
|
||||
raise ValueError('Invalid ema_decay value: {}'.format(ema_decay))
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
betas=betas,
|
||||
eps=eps,
|
||||
weight_decay=weight_decay,
|
||||
amsgrad=amsgrad,
|
||||
ema_decay=ema_decay,
|
||||
ema_power=ema_power,
|
||||
param_names=param_names)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Args:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
beta1, beta2 = group['betas']
|
||||
ema_decay = group['ema_decay']
|
||||
ema_power = group['ema_power']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
'AdamW does not support sparse gradients')
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of parameter values
|
||||
state['param_exp_avg'] = p.detach().float().clone()
|
||||
|
||||
exp_avgs.append(state['exp_avg'])
|
||||
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||
ema_params_with_grad.append(state['param_exp_avg'])
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
||||
|
||||
# update the steps for each param group update
|
||||
state['step'] += 1
|
||||
# record the step after step update
|
||||
state_steps.append(state['step'])
|
||||
|
||||
optim._functional.adamw(
|
||||
params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group['lr'],
|
||||
weight_decay=group['weight_decay'],
|
||||
eps=group['eps'],
|
||||
maximize=False)
|
||||
|
||||
cur_ema_decay = min(ema_decay, 1 - state['step']**-ema_power)
|
||||
for param, ema_param in zip(params_with_grad,
|
||||
ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(
|
||||
param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
||||
2
modelscope/models/cv/image_to_3d/__init__.py
Normal file
2
modelscope/models/cv/image_to_3d/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved.
|
||||
from . import ldm
|
||||
158
modelscope/models/cv/image_to_3d/ldm/base_utils.py
Normal file
158
modelscope/models/cv/image_to_3d/ldm/base_utils.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import cv2
|
||||
from skimage.io import imread
|
||||
|
||||
|
||||
def save_pickle(data, pkl_path):
|
||||
# os.system('mkdir -p {}'.format(os.path.dirname(pkl_path)))
|
||||
with open(pkl_path, 'wb') as f:
|
||||
pickle.dump(data, f)
|
||||
|
||||
def read_pickle(pkl_path):
|
||||
with open(pkl_path, 'rb') as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def draw_epipolar_line(F, img0, img1, pt0, color):
|
||||
h1,w1=img1.shape[:2]
|
||||
hpt = np.asarray([pt0[0], pt0[1], 1], dtype=np.float32)[:, None]
|
||||
l = F @ hpt
|
||||
l = l[:, 0]
|
||||
a, b, c = l[0], l[1], l[2]
|
||||
pt1 = np.asarray([0, -c / b]).astype(np.int32)
|
||||
pt2 = np.asarray([w1, (-a * w1 - c) / b]).astype(np.int32)
|
||||
|
||||
img0 = cv2.circle(img0, tuple(pt0.astype(np.int32)), 5, color, 2)
|
||||
img1 = cv2.line(img1, tuple(pt1), tuple(pt2), color, 2)
|
||||
return img0, img1
|
||||
|
||||
def draw_epipolar_lines(F, img0, img1,num=20):
|
||||
img0,img1=img0.copy(),img1.copy()
|
||||
h0, w0, _ = img0.shape
|
||||
h1, w1, _ = img1.shape
|
||||
|
||||
for k in range(num):
|
||||
color = np.random.randint(0, 255, [3], dtype=np.int32)
|
||||
color = [int(c) for c in color]
|
||||
pt = np.random.uniform(0, 1, 2)
|
||||
pt[0] *= w0
|
||||
pt[1] *= h0
|
||||
pt = pt.astype(np.int32)
|
||||
img0, img1 = draw_epipolar_line(F, img0, img1, pt, color)
|
||||
|
||||
return img0, img1
|
||||
|
||||
def compute_F(K1, K2, Rt0, Rt1=None):
|
||||
if Rt1 is None:
|
||||
R, t = Rt0[:,:3], Rt0[:,3:]
|
||||
else:
|
||||
Rt = compute_dR_dt(Rt0,Rt1)
|
||||
R, t = Rt[:,:3], Rt[:,3:]
|
||||
A = K1 @ R.T @ t # [3,1]
|
||||
C = np.asarray([[0,-A[2,0],A[1,0]],
|
||||
[A[2,0],0,-A[0,0]],
|
||||
[-A[1,0],A[0,0],0]])
|
||||
F = (np.linalg.inv(K2)).T @ R @ K1.T @ C
|
||||
return F
|
||||
|
||||
def compute_dR_dt(Rt0, Rt1):
|
||||
R0, t0 = Rt0[:,:3], Rt0[:,3:]
|
||||
R1, t1 = Rt1[:,:3], Rt1[:,3:]
|
||||
dR = np.dot(R1, R0.T)
|
||||
dt = t1 - np.dot(dR, t0)
|
||||
return np.concatenate([dR, dt], -1)
|
||||
|
||||
def concat_images(img0,img1,vert=False):
|
||||
if not vert:
|
||||
h0,h1=img0.shape[0],img1.shape[0],
|
||||
if h0<h1: img0=cv2.copyMakeBorder(img0,0,h1-h0,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
|
||||
if h1<h0: img1=cv2.copyMakeBorder(img1,0,h0-h1,0,0,borderType=cv2.BORDER_CONSTANT,value=0)
|
||||
img = np.concatenate([img0, img1], axis=1)
|
||||
else:
|
||||
w0,w1=img0.shape[1],img1.shape[1]
|
||||
if w0<w1: img0=cv2.copyMakeBorder(img0,0,0,0,w1-w0,borderType=cv2.BORDER_CONSTANT,value=0)
|
||||
if w1<w0: img1=cv2.copyMakeBorder(img1,0,0,0,w0-w1,borderType=cv2.BORDER_CONSTANT,value=0)
|
||||
img = np.concatenate([img0, img1], axis=0)
|
||||
|
||||
return img
|
||||
|
||||
def concat_images_list(*args,vert=False):
|
||||
if len(args)==1: return args[0]
|
||||
img_out=args[0]
|
||||
for img in args[1:]:
|
||||
img_out=concat_images(img_out,img,vert)
|
||||
return img_out
|
||||
|
||||
|
||||
def pose_inverse(pose):
|
||||
R = pose[:,:3].T
|
||||
t = - R @ pose[:,3:]
|
||||
return np.concatenate([R,t],-1)
|
||||
|
||||
def project_points(pts,RT,K):
|
||||
pts = np.matmul(pts,RT[:,:3].transpose())+RT[:,3:].transpose()
|
||||
pts = np.matmul(pts,K.transpose())
|
||||
dpt = pts[:,2]
|
||||
mask0 = (np.abs(dpt)<1e-4) & (np.abs(dpt)>0)
|
||||
if np.sum(mask0)>0: dpt[mask0]=1e-4
|
||||
mask1=(np.abs(dpt) > -1e-4) & (np.abs(dpt) < 0)
|
||||
if np.sum(mask1)>0: dpt[mask1]=-1e-4
|
||||
pts2d = pts[:,:2]/dpt[:,None]
|
||||
return pts2d, dpt
|
||||
|
||||
|
||||
def draw_keypoints(img, kps, colors=None, radius=2):
|
||||
out_img=img.copy()
|
||||
for pi, pt in enumerate(kps):
|
||||
pt = np.round(pt).astype(np.int32)
|
||||
if colors is not None:
|
||||
color=[int(c) for c in colors[pi]]
|
||||
cv2.circle(out_img, tuple(pt), radius, color, -1)
|
||||
else:
|
||||
cv2.circle(out_img, tuple(pt), radius, (0,255,0), -1)
|
||||
return out_img
|
||||
|
||||
|
||||
def output_points(fn,pts,colors=None):
|
||||
with open(fn, 'w') as f:
|
||||
for pi, pt in enumerate(pts):
|
||||
f.write(f'{pt[0]:.6f} {pt[1]:.6f} {pt[2]:.6f} ')
|
||||
if colors is not None:
|
||||
f.write(f'{int(colors[pi,0])} {int(colors[pi,1])} {int(colors[pi,2])}')
|
||||
f.write('\n')
|
||||
|
||||
DEPTH_MAX, DEPTH_MIN = 2.4, 0.6
|
||||
DEPTH_VALID_MAX, DEPTH_VALID_MIN = 2.37, 0.63
|
||||
def read_depth_objaverse(depth_fn):
|
||||
depth = imread(depth_fn)
|
||||
depth = depth.astype(np.float32) / 65535 * (DEPTH_MAX-DEPTH_MIN) + DEPTH_MIN
|
||||
mask = (depth > DEPTH_VALID_MIN) & (depth < DEPTH_VALID_MAX)
|
||||
return depth, mask
|
||||
|
||||
|
||||
def mask_depth_to_pts(mask,depth,K,rgb=None):
|
||||
hs,ws=np.nonzero(mask)
|
||||
depth=depth[hs,ws]
|
||||
pts=np.asarray([ws,hs,depth],np.float32).transpose()
|
||||
pts[:,:2]*=pts[:,2:]
|
||||
if rgb is not None:
|
||||
return np.dot(pts, np.linalg.inv(K).transpose()), rgb[hs,ws]
|
||||
else:
|
||||
return np.dot(pts, np.linalg.inv(K).transpose())
|
||||
|
||||
def transform_points_pose(pts, pose):
|
||||
R, t = pose[:, :3], pose[:, 3]
|
||||
if len(pts.shape)==1:
|
||||
return (R @ pts[:,None] + t[:,None])[:,0]
|
||||
return pts @ R.T + t[None,:]
|
||||
|
||||
def pose_apply(pose,pts):
|
||||
return transform_points_pose(pts, pose)
|
||||
|
||||
def downsample_gaussian_blur(img, ratio):
|
||||
sigma = (1 / ratio) / 3
|
||||
# ksize=np.ceil(2*sigma)
|
||||
ksize = int(np.ceil(((sigma - 0.8) / 0.3 + 1) * 2 + 1))
|
||||
ksize = ksize + 1 if ksize % 2 == 0 else ksize
|
||||
img = cv2.GaussianBlur(img, (ksize, ksize), sigma, borderType=cv2.BORDER_REFLECT101)
|
||||
return img
|
||||
443
modelscope/models/cv/image_to_3d/ldm/models/autoencoder.py
Normal file
443
modelscope/models/cv/image_to_3d/ldm/models/autoencoder.py
Normal file
@@ -0,0 +1,443 @@
|
||||
import torch
|
||||
import pytorch_lightning as pl
|
||||
import torch.nn.functional as F
|
||||
from contextlib import contextmanager
|
||||
|
||||
from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.model import Encoder, Decoder
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.distributions.distributions import DiagonalGaussianDistribution
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
class VQModel(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
n_embed,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
batch_resize_range=None,
|
||||
scheduler_config=None,
|
||||
lr_g_factor=1.0,
|
||||
remap=None,
|
||||
sane_index_shape=False, # tell vector quantizer to return indices as bhw
|
||||
use_ema=False
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = embed_dim
|
||||
self.n_embed = n_embed
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
|
||||
remap=remap,
|
||||
sane_index_shape=sane_index_shape)
|
||||
self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
self.batch_resize_range = batch_resize_range
|
||||
if self.batch_resize_range is not None:
|
||||
print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")
|
||||
|
||||
self.use_ema = use_ema
|
||||
if self.use_ema:
|
||||
self.model_ema = LitEma(self)
|
||||
print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")
|
||||
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
self.scheduler_config = scheduler_config
|
||||
self.lr_g_factor = lr_g_factor
|
||||
|
||||
@contextmanager
|
||||
def ema_scope(self, context=None):
|
||||
if self.use_ema:
|
||||
self.model_ema.store(self.parameters())
|
||||
self.model_ema.copy_to(self)
|
||||
if context is not None:
|
||||
print(f"{context}: Switched to EMA weights")
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
if self.use_ema:
|
||||
self.model_ema.restore(self.parameters())
|
||||
if context is not None:
|
||||
print(f"{context}: Restored training weights")
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
missing, unexpected = self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
|
||||
if len(missing) > 0:
|
||||
print(f"Missing Keys: {missing}")
|
||||
print(f"Unexpected Keys: {unexpected}")
|
||||
|
||||
def on_train_batch_end(self, *args, **kwargs):
|
||||
if self.use_ema:
|
||||
self.model_ema(self)
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
return quant, emb_loss, info
|
||||
|
||||
def encode_to_prequant(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, quant):
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
def decode_code(self, code_b):
|
||||
quant_b = self.quantize.embed_code(code_b)
|
||||
dec = self.decode(quant_b)
|
||||
return dec
|
||||
|
||||
def forward(self, input, return_pred_indices=False):
|
||||
quant, diff, (_,_,ind) = self.encode(input)
|
||||
dec = self.decode(quant)
|
||||
if return_pred_indices:
|
||||
return dec, diff, ind
|
||||
return dec, diff
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
if self.batch_resize_range is not None:
|
||||
lower_size = self.batch_resize_range[0]
|
||||
upper_size = self.batch_resize_range[1]
|
||||
if self.global_step <= 4:
|
||||
# do the first few batches with max size to avoid later oom
|
||||
new_resize = upper_size
|
||||
else:
|
||||
new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16))
|
||||
if new_resize != x.shape[2]:
|
||||
x = F.interpolate(x, size=new_resize, mode="bicubic")
|
||||
x = x.detach()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
# https://github.com/pytorch/pytorch/issues/37142
|
||||
# try not to fool the heuristics
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# autoencode
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train",
|
||||
predicted_indices=ind)
|
||||
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# discriminator
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
log_dict = self._validation_step(batch, batch_idx)
|
||||
with self.ema_scope():
|
||||
log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema")
|
||||
return log_dict
|
||||
|
||||
def _validation_step(self, batch, batch_idx, suffix=""):
|
||||
x = self.get_input(batch, self.image_key)
|
||||
xrec, qloss, ind = self(x, return_pred_indices=True)
|
||||
aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
|
||||
discloss, log_dict_disc = self.loss(qloss, x, xrec, 1,
|
||||
self.global_step,
|
||||
last_layer=self.get_last_layer(),
|
||||
split="val"+suffix,
|
||||
predicted_indices=ind
|
||||
)
|
||||
rec_loss = log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log(f"val{suffix}/rec_loss", rec_loss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
self.log(f"val{suffix}/aeloss", aeloss,
|
||||
prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
|
||||
if version.parse(pl.__version__) >= version.parse('1.4.0'):
|
||||
del log_dict_ae[f"val{suffix}/rec_loss"]
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr_d = self.learning_rate
|
||||
lr_g = self.lr_g_factor*self.learning_rate
|
||||
print("lr_d", lr_d)
|
||||
print("lr_g", lr_g)
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quantize.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr_g, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr_d, betas=(0.5, 0.9))
|
||||
|
||||
if self.scheduler_config is not None:
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [
|
||||
{
|
||||
'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
{
|
||||
'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule),
|
||||
'interval': 'step',
|
||||
'frequency': 1
|
||||
},
|
||||
]
|
||||
return [opt_ae, opt_disc], scheduler
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if only_inputs:
|
||||
log["inputs"] = x
|
||||
return log
|
||||
xrec, _ = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["inputs"] = x
|
||||
log["reconstructions"] = xrec
|
||||
if plot_ema:
|
||||
with self.ema_scope():
|
||||
xrec_ema, _ = self(x)
|
||||
if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema)
|
||||
log["reconstructions_ema"] = xrec_ema
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class VQModelInterface(VQModel):
|
||||
def __init__(self, embed_dim, *args, **kwargs):
|
||||
super().__init__(embed_dim=embed_dim, *args, **kwargs)
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
h = self.quant_conv(h)
|
||||
return h
|
||||
|
||||
def decode(self, h, force_not_quantize=False):
|
||||
# also go through quantization layer
|
||||
if not force_not_quantize:
|
||||
quant, emb_loss, info = self.quantize(h)
|
||||
else:
|
||||
quant = h
|
||||
quant = self.post_quant_conv(quant)
|
||||
dec = self.decoder(quant)
|
||||
return dec
|
||||
|
||||
|
||||
class AutoencoderKL(pl.LightningModule):
|
||||
def __init__(self,
|
||||
ddconfig,
|
||||
lossconfig,
|
||||
embed_dim,
|
||||
ckpt_path=None,
|
||||
ignore_keys=[],
|
||||
image_key="image",
|
||||
colorize_nlabels=None,
|
||||
monitor=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_key = image_key
|
||||
self.encoder = Encoder(**ddconfig)
|
||||
self.decoder = Decoder(**ddconfig)
|
||||
self.loss = instantiate_from_config(lossconfig)
|
||||
assert ddconfig["double_z"]
|
||||
self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
|
||||
self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
|
||||
self.embed_dim = embed_dim
|
||||
if colorize_nlabels is not None:
|
||||
assert type(colorize_nlabels)==int
|
||||
self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
|
||||
if monitor is not None:
|
||||
self.monitor = monitor
|
||||
if ckpt_path is not None:
|
||||
self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
|
||||
|
||||
def init_from_ckpt(self, path, ignore_keys=list()):
|
||||
sd = torch.load(path, map_location="cpu")["state_dict"]
|
||||
keys = list(sd.keys())
|
||||
for k in keys:
|
||||
for ik in ignore_keys:
|
||||
if k.startswith(ik):
|
||||
print("Deleting key {} from state_dict.".format(k))
|
||||
del sd[k]
|
||||
self.load_state_dict(sd, strict=False)
|
||||
print(f"Restored from {path}")
|
||||
|
||||
def encode(self, x):
|
||||
h = self.encoder(x)
|
||||
moments = self.quant_conv(h)
|
||||
posterior = DiagonalGaussianDistribution(moments)
|
||||
return posterior
|
||||
|
||||
def decode(self, z):
|
||||
z = self.post_quant_conv(z)
|
||||
dec = self.decoder(z)
|
||||
return dec
|
||||
|
||||
def forward(self, input, sample_posterior=True):
|
||||
posterior = self.encode(input)
|
||||
if sample_posterior:
|
||||
z = posterior.sample()
|
||||
else:
|
||||
z = posterior.mode()
|
||||
dec = self.decode(z)
|
||||
return dec, posterior
|
||||
|
||||
def get_input(self, batch, k):
|
||||
x = batch[k]
|
||||
if len(x.shape) == 3:
|
||||
x = x[..., None]
|
||||
x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float()
|
||||
return x
|
||||
|
||||
def training_step(self, batch, batch_idx, optimizer_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
|
||||
if optimizer_idx == 0:
|
||||
# train encoder+decoder+logvar
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return aeloss
|
||||
|
||||
if optimizer_idx == 1:
|
||||
# train the discriminator
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="train")
|
||||
|
||||
self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True)
|
||||
self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False)
|
||||
return discloss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
inputs = self.get_input(batch, self.image_key)
|
||||
reconstructions, posterior = self(inputs)
|
||||
aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step,
|
||||
last_layer=self.get_last_layer(), split="val")
|
||||
|
||||
self.log("val/rec_loss", log_dict_ae["val/rec_loss"])
|
||||
self.log_dict(log_dict_ae)
|
||||
self.log_dict(log_dict_disc)
|
||||
return self.log_dict
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
opt_ae = torch.optim.Adam(list(self.encoder.parameters())+
|
||||
list(self.decoder.parameters())+
|
||||
list(self.quant_conv.parameters())+
|
||||
list(self.post_quant_conv.parameters()),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(),
|
||||
lr=lr, betas=(0.5, 0.9))
|
||||
return [opt_ae, opt_disc], []
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.decoder.conv_out.weight
|
||||
|
||||
@torch.no_grad()
|
||||
def log_images(self, batch, only_inputs=False, **kwargs):
|
||||
log = dict()
|
||||
x = self.get_input(batch, self.image_key)
|
||||
x = x.to(self.device)
|
||||
if not only_inputs:
|
||||
xrec, posterior = self(x)
|
||||
if x.shape[1] > 3:
|
||||
# colorize with random projection
|
||||
assert xrec.shape[1] > 3
|
||||
x = self.to_rgb(x)
|
||||
xrec = self.to_rgb(xrec)
|
||||
log["samples"] = self.decode(torch.randn_like(posterior.sample()))
|
||||
log["reconstructions"] = xrec
|
||||
log["inputs"] = x
|
||||
return log
|
||||
|
||||
def to_rgb(self, x):
|
||||
assert self.image_key == "segmentation"
|
||||
if not hasattr(self, "colorize"):
|
||||
self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
|
||||
x = F.conv2d(x, weight=self.colorize)
|
||||
x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
|
||||
return x
|
||||
|
||||
|
||||
class IdentityFirstStage(torch.nn.Module):
|
||||
def __init__(self, *args, vq_interface=False, **kwargs):
|
||||
self.vq_interface = vq_interface # TODO: Should be true by default but check to not break older stuff
|
||||
super().__init__()
|
||||
|
||||
def encode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def decode(self, x, *args, **kwargs):
|
||||
return x
|
||||
|
||||
def quantize(self, x, *args, **kwargs):
|
||||
if self.vq_interface:
|
||||
return x, None, [None, None, None]
|
||||
return x
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return x
|
||||
@@ -0,0 +1,673 @@
|
||||
from pathlib import Path
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
from skimage.io import imsave
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.base_utils import read_pickle, concat_images_list
|
||||
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_utils import get_warp_coordinates, create_target_volume
|
||||
from modelscope.models.cv.image_to_3d.ldm.models.diffusion.sync_dreamer_network import NoisyTargetViewEncoder, SpatialTime3DNet, FrustumTV3DNet
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_ddim_timesteps, timestep_embedding
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.modules import FrozenCLIPImageEmbedder
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
def disable_training_module(module: nn.Module):
|
||||
module = module.eval()
|
||||
module.train = disabled_train
|
||||
for para in module.parameters():
|
||||
para.requires_grad = False
|
||||
return module
|
||||
|
||||
def repeat_to_batch(tensor, B, VN):
|
||||
t_shape = tensor.shape
|
||||
ones = [1 for _ in range(len(t_shape)-1)]
|
||||
tensor_new = tensor.view(B,1,*t_shape[1:]).repeat(1,VN,*ones).view(B*VN,*t_shape[1:])
|
||||
return tensor_new
|
||||
|
||||
class UNetWrapper(nn.Module):
|
||||
def __init__(self, diff_model_config, drop_conditions=False, drop_scheme='default', use_zero_123=True):
|
||||
super().__init__()
|
||||
self.diffusion_model = instantiate_from_config(diff_model_config)
|
||||
self.drop_conditions = drop_conditions
|
||||
self.drop_scheme=drop_scheme
|
||||
self.use_zero_123 = use_zero_123
|
||||
|
||||
|
||||
def drop(self, cond, mask):
|
||||
shape = cond.shape
|
||||
B = shape[0]
|
||||
cond = mask.view(B,*[1 for _ in range(len(shape)-1)]) * cond
|
||||
return cond
|
||||
|
||||
def get_trainable_parameters(self):
|
||||
return self.diffusion_model.get_trainable_parameters()
|
||||
|
||||
def get_drop_scheme(self, B, device):
|
||||
if self.drop_scheme=='default':
|
||||
random = torch.rand(B, dtype=torch.float32, device=device)
|
||||
drop_clip = (random > 0.15) & (random <= 0.2)
|
||||
drop_volume = (random > 0.1) & (random <= 0.15)
|
||||
drop_concat = (random > 0.05) & (random <= 0.1)
|
||||
drop_all = random <= 0.05
|
||||
else:
|
||||
raise NotImplementedError
|
||||
return drop_clip, drop_volume, drop_concat, drop_all
|
||||
|
||||
def forward(self, x, t, clip_embed, volume_feats, x_concat, is_train=False):
|
||||
"""
|
||||
|
||||
@param x: B,4,H,W
|
||||
@param t: B,
|
||||
@param clip_embed: B,M,768
|
||||
@param volume_feats: B,C,D,H,W
|
||||
@param x_concat: B,C,H,W
|
||||
@param is_train:
|
||||
@return:
|
||||
"""
|
||||
if self.drop_conditions and is_train:
|
||||
B = x.shape[0]
|
||||
drop_clip, drop_volume, drop_concat, drop_all = self.get_drop_scheme(B, x.device)
|
||||
|
||||
clip_mask = 1.0 - (drop_clip | drop_all).float()
|
||||
clip_embed = self.drop(clip_embed, clip_mask)
|
||||
|
||||
volume_mask = 1.0 - (drop_volume | drop_all).float()
|
||||
for k, v in volume_feats.items():
|
||||
volume_feats[k] = self.drop(v, mask=volume_mask)
|
||||
|
||||
concat_mask = 1.0 - (drop_concat | drop_all).float()
|
||||
x_concat = self.drop(x_concat, concat_mask)
|
||||
|
||||
if self.use_zero_123:
|
||||
# zero123 does not multiply this when encoding, maybe a bug for zero123
|
||||
first_stage_scale_factor = 0.18215
|
||||
x_concat_ = x_concat * 1.0
|
||||
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
|
||||
else:
|
||||
x_concat_ = x_concat
|
||||
|
||||
x = torch.cat([x, x_concat_], 1)
|
||||
pred = self.diffusion_model(x, t, clip_embed, source_dict=volume_feats)
|
||||
return pred
|
||||
|
||||
def predict_with_unconditional_scale(self, x, t, clip_embed, volume_feats, x_concat, unconditional_scale):
|
||||
x_ = torch.cat([x] * 2, 0)
|
||||
t_ = torch.cat([t] * 2, 0)
|
||||
clip_embed_ = torch.cat([clip_embed, torch.zeros_like(clip_embed)], 0)
|
||||
|
||||
v_ = {}
|
||||
for k, v in volume_feats.items():
|
||||
v_[k] = torch.cat([v, torch.zeros_like(v)], 0)
|
||||
|
||||
x_concat_ = torch.cat([x_concat, torch.zeros_like(x_concat)], 0)
|
||||
if self.use_zero_123:
|
||||
# zero123 does not multiply this when encoding, maybe a bug for zero123
|
||||
first_stage_scale_factor = 0.18215
|
||||
x_concat_[:, :4] = x_concat_[:, :4] / first_stage_scale_factor
|
||||
x_ = torch.cat([x_, x_concat_], 1)
|
||||
s, s_uc = self.diffusion_model(x_, t_, clip_embed_, source_dict=v_).chunk(2)
|
||||
s = s_uc + unconditional_scale * (s - s_uc)
|
||||
return s
|
||||
|
||||
|
||||
class SpatialVolumeNet(nn.Module):
|
||||
def __init__(self, time_dim, view_dim, view_num,
|
||||
input_image_size=256, frustum_volume_depth=48,
|
||||
spatial_volume_size=32, spatial_volume_length=0.5,
|
||||
frustum_volume_length=0.86603 # sqrt(3)/2
|
||||
):
|
||||
super().__init__()
|
||||
self.target_encoder = NoisyTargetViewEncoder(time_dim, view_dim, output_dim=16)
|
||||
self.spatial_volume_feats = SpatialTime3DNet(input_dim=16 * view_num, time_dim=time_dim, dims=(64, 128, 256, 512))
|
||||
self.frustum_volume_feats = FrustumTV3DNet(64, time_dim, view_dim, dims=(64, 128, 256, 512))
|
||||
|
||||
self.frustum_volume_length = frustum_volume_length
|
||||
self.input_image_size = input_image_size
|
||||
self.spatial_volume_size = spatial_volume_size
|
||||
self.spatial_volume_length = spatial_volume_length
|
||||
|
||||
self.frustum_volume_size = self.input_image_size // 8
|
||||
self.frustum_volume_depth = frustum_volume_depth
|
||||
self.time_dim = time_dim
|
||||
self.view_dim = view_dim
|
||||
self.default_origin_depth = 1.5 # our rendered images are 1.5 away from the origin, we assume camera is 1.5 away from the origin
|
||||
|
||||
def construct_spatial_volume(self, x, t_embed, v_embed, target_poses, target_Ks):
|
||||
"""
|
||||
@param x: B,N,4,H,W
|
||||
@param t_embed: B,t_dim
|
||||
@param v_embed: B,N,v_dim
|
||||
@param target_poses: N,3,4
|
||||
@param target_Ks: N,3,3
|
||||
@return:
|
||||
"""
|
||||
B, N, _, H, W = x.shape
|
||||
V = self.spatial_volume_size
|
||||
device = x.device
|
||||
|
||||
spatial_volume_verts = torch.linspace(-self.spatial_volume_length, self.spatial_volume_length, V, dtype=torch.float32, device=device)
|
||||
spatial_volume_verts = torch.stack(torch.meshgrid(spatial_volume_verts, spatial_volume_verts, spatial_volume_verts), -1)
|
||||
spatial_volume_verts = spatial_volume_verts.reshape(1, V ** 3, 3)[:, :, (2, 1, 0)]
|
||||
spatial_volume_verts = spatial_volume_verts.view(1, V, V, V, 3).permute(0, 4, 1, 2, 3).repeat(B, 1, 1, 1, 1)
|
||||
|
||||
# encode source features
|
||||
t_embed_ = t_embed.view(B, 1, self.time_dim).repeat(1, N, 1).view(B, N, self.time_dim)
|
||||
# v_embed_ = v_embed.view(1, N, self.view_dim).repeat(B, 1, 1).view(B, N, self.view_dim)
|
||||
v_embed_ = v_embed
|
||||
target_Ks = target_Ks.unsqueeze(0).repeat(B, 1, 1, 1)
|
||||
target_poses = target_poses.unsqueeze(0).repeat(B, 1, 1, 1)
|
||||
|
||||
# extract 2D image features
|
||||
spatial_volume_feats = []
|
||||
# project source features
|
||||
for ni in range(0, N):
|
||||
pose_source_ = target_poses[:, ni]
|
||||
K_source_ = target_Ks[:, ni]
|
||||
x_ = self.target_encoder(x[:, ni], t_embed_[:, ni], v_embed_[:, ni])
|
||||
C = x_.shape[1]
|
||||
|
||||
coords_source = get_warp_coordinates(spatial_volume_verts, x_.shape[-1], self.input_image_size, K_source_, pose_source_).view(B, V, V * V, 2)
|
||||
unproj_feats_ = F.grid_sample(x_, coords_source, mode='bilinear', padding_mode='zeros', align_corners=True)
|
||||
unproj_feats_ = unproj_feats_.view(B, C, V, V, V)
|
||||
spatial_volume_feats.append(unproj_feats_)
|
||||
|
||||
spatial_volume_feats = torch.stack(spatial_volume_feats, 1) # B,N,C,V,V,V
|
||||
N = spatial_volume_feats.shape[1]
|
||||
spatial_volume_feats = spatial_volume_feats.view(B, N*C, V, V, V)
|
||||
|
||||
spatial_volume_feats = self.spatial_volume_feats(spatial_volume_feats, t_embed) # b,64,32,32,32
|
||||
return spatial_volume_feats
|
||||
|
||||
def construct_view_frustum_volume(self, spatial_volume, t_embed, v_embed, poses, Ks, target_indices):
|
||||
"""
|
||||
@param spatial_volume: B,C,V,V,V
|
||||
@param t_embed: B,t_dim
|
||||
@param v_embed: B,N,v_dim
|
||||
@param poses: N,3,4
|
||||
@param Ks: N,3,3
|
||||
@param target_indices: B,TN
|
||||
@return: B*TN,C,H,W
|
||||
"""
|
||||
B, TN = target_indices.shape
|
||||
H, W = self.frustum_volume_size, self.frustum_volume_size
|
||||
D = self.frustum_volume_depth
|
||||
V = self.spatial_volume_size
|
||||
|
||||
near = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth - self.frustum_volume_length
|
||||
far = torch.ones(B * TN, 1, H, W, dtype=spatial_volume.dtype, device=spatial_volume.device) * self.default_origin_depth + self.frustum_volume_length
|
||||
|
||||
target_indices = target_indices.view(B*TN) # B*TN
|
||||
poses_ = poses[target_indices] # B*TN,3,4
|
||||
Ks_ = Ks[target_indices] # B*TN,3,4
|
||||
volume_xyz, volume_depth = create_target_volume(D, self.frustum_volume_size, self.input_image_size, poses_, Ks_, near, far) # B*TN,3 or 1,D,H,W
|
||||
|
||||
volume_xyz_ = volume_xyz / self.spatial_volume_length # since the spatial volume is constructed in [-spatial_volume_length,spatial_volume_length]
|
||||
volume_xyz_ = volume_xyz_.permute(0, 2, 3, 4, 1) # B*TN,D,H,W,3
|
||||
spatial_volume_ = spatial_volume.unsqueeze(1).repeat(1, TN, 1, 1, 1, 1).view(B * TN, -1, V, V, V)
|
||||
volume_feats = F.grid_sample(spatial_volume_, volume_xyz_, mode='bilinear', padding_mode='zeros', align_corners=True) # B*TN,C,D,H,W
|
||||
|
||||
v_embed_ = v_embed[torch.arange(B)[:,None], target_indices.view(B,TN)].view(B*TN, -1) # B*TN
|
||||
t_embed_ = t_embed.unsqueeze(1).repeat(1,TN,1).view(B*TN,-1)
|
||||
volume_feats_dict = self.frustum_volume_feats(volume_feats, t_embed_, v_embed_)
|
||||
return volume_feats_dict, volume_depth
|
||||
"""
|
||||
SyncDreamer is a SoTA Novel View Synthesis model which can generate 16 consistent views seamlessly.
|
||||
Please refer to: https://arxiv.org/abs/2309.03453 for more technique details.
|
||||
"""
|
||||
class SyncMultiviewDiffusion(pl.LightningModule):
|
||||
def __init__(self, unet_config, scheduler_config,
|
||||
finetune_unet=False, finetune_projection=True,
|
||||
view_num=16, image_size=256,
|
||||
cfg_scale=3.0, output_num=8, batch_view_num=4,
|
||||
drop_conditions=False, drop_scheme='default',
|
||||
clip_image_encoder_path="/apdcephfs/private_rondyliu/projects/clip/ViT-L-14.pt"):
|
||||
super().__init__()
|
||||
|
||||
self.finetune_unet = finetune_unet
|
||||
self.finetune_projection = finetune_projection
|
||||
|
||||
self.view_num = view_num
|
||||
self.viewpoint_dim = 4
|
||||
self.output_num = output_num
|
||||
self.image_size = image_size
|
||||
|
||||
self.batch_view_num = batch_view_num
|
||||
self.cfg_scale = cfg_scale
|
||||
|
||||
self.clip_image_encoder_path = clip_image_encoder_path
|
||||
|
||||
self._init_time_step_embedding()
|
||||
self._init_first_stage()
|
||||
self._init_schedule()
|
||||
self._init_multiview()
|
||||
self._init_clip_image_encoder()
|
||||
self._init_clip_projection()
|
||||
|
||||
self.spatial_volume = SpatialVolumeNet(self.time_embed_dim, self.viewpoint_dim, self.view_num)
|
||||
self.model = UNetWrapper(unet_config, drop_conditions=drop_conditions, drop_scheme=drop_scheme)
|
||||
self.scheduler_config = scheduler_config
|
||||
|
||||
latent_size = image_size//8
|
||||
self.ddim = SyncDDIMSampler(self, 200, "uniform", 1.0, latent_size=latent_size)
|
||||
|
||||
def _init_clip_projection(self):
|
||||
self.cc_projection = nn.Linear(772, 768)
|
||||
nn.init.eye_(list(self.cc_projection.parameters())[0][:768, :768])
|
||||
nn.init.zeros_(list(self.cc_projection.parameters())[1])
|
||||
self.cc_projection.requires_grad_(True)
|
||||
|
||||
if not self.finetune_projection:
|
||||
disable_training_module(self.cc_projection)
|
||||
|
||||
def _init_multiview(self):
|
||||
K, azs, _, _, poses = read_pickle(self.clip_image_encoder_path.replace("ViT-L-14.pt",f'camera-{self.view_num}.pkl'))
|
||||
default_image_size = 256
|
||||
ratio = self.image_size/default_image_size
|
||||
K = np.diag([ratio,ratio,1]) @ K
|
||||
K = torch.from_numpy(K.astype(np.float32)) # [3,3]
|
||||
K = K.unsqueeze(0).repeat(self.view_num,1,1) # N,3,3
|
||||
poses = torch.from_numpy(poses.astype(np.float32)) # N,3,4
|
||||
self.register_buffer('poses', poses)
|
||||
self.register_buffer('Ks', K)
|
||||
azs = (azs + np.pi) % (np.pi * 2) - np.pi # scale to [-pi,pi] and the index=0 has az=0
|
||||
self.register_buffer('azimuth', torch.from_numpy(azs.astype(np.float32)))
|
||||
|
||||
def get_viewpoint_embedding(self, batch_size, elevation_ref):
|
||||
"""
|
||||
@param batch_size:
|
||||
@param elevation_ref: B
|
||||
@return:
|
||||
"""
|
||||
azimuth_input = self.azimuth[0].unsqueeze(0) # 1
|
||||
azimuth_target = self.azimuth # N
|
||||
elevation_input = -elevation_ref # note that zero123 use a negative elevation here!!!
|
||||
elevation_target = -np.deg2rad(30)
|
||||
d_e = elevation_target - elevation_input # B
|
||||
N = self.azimuth.shape[0]
|
||||
B = batch_size
|
||||
d_e = d_e.unsqueeze(1).repeat(1, N)
|
||||
d_a = azimuth_target - azimuth_input # N
|
||||
d_a = d_a.unsqueeze(0).repeat(B, 1)
|
||||
d_z = torch.zeros_like(d_a)
|
||||
embedding = torch.stack([d_e, torch.sin(d_a), torch.cos(d_a), d_z], -1) # B,N,4
|
||||
return embedding
|
||||
|
||||
def _init_first_stage(self):
|
||||
first_stage_config={
|
||||
"target": "modelscope.models.cv.image_to_3d.ldm.models.autoencoder.AutoencoderKL",
|
||||
"params": {
|
||||
"embed_dim": 4,
|
||||
"monitor": "val/rec_loss",
|
||||
"ddconfig":{
|
||||
"double_z": True,
|
||||
"z_channels": 4,
|
||||
"resolution": self.image_size,
|
||||
"in_channels": 3,
|
||||
"out_ch": 3,
|
||||
"ch": 128,
|
||||
"ch_mult": [1,2,4,4],
|
||||
"num_res_blocks": 2,
|
||||
"attn_resolutions": [],
|
||||
"dropout": 0.0
|
||||
},
|
||||
"lossconfig": {"target": "torch.nn.Identity"},
|
||||
}
|
||||
}
|
||||
self.first_stage_scale_factor = 0.18215
|
||||
self.first_stage_model = instantiate_from_config(first_stage_config)
|
||||
self.first_stage_model = disable_training_module(self.first_stage_model)
|
||||
|
||||
def _init_clip_image_encoder(self):
|
||||
self.clip_image_encoder = FrozenCLIPImageEmbedder(model=self.clip_image_encoder_path)
|
||||
self.clip_image_encoder = disable_training_module(self.clip_image_encoder)
|
||||
|
||||
def _init_schedule(self):
|
||||
self.num_timesteps = 1000
|
||||
linear_start = 0.00085
|
||||
linear_end = 0.0120
|
||||
num_timesteps = 1000
|
||||
betas = torch.linspace(linear_start ** 0.5, linear_end ** 0.5, num_timesteps, dtype=torch.float32) ** 2 # T
|
||||
assert betas.shape[0] == self.num_timesteps
|
||||
|
||||
# all in float64 first
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, dim=0) # T
|
||||
alphas_cumprod_prev = torch.cat([torch.ones(1, dtype=torch.float64), alphas_cumprod[:-1]], 0)
|
||||
posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) # T
|
||||
posterior_log_variance_clipped = torch.log(torch.clamp(posterior_variance, min=1e-20))
|
||||
posterior_log_variance_clipped = torch.clamp(posterior_log_variance_clipped, min=-10)
|
||||
|
||||
self.register_buffer("betas", betas.float())
|
||||
self.register_buffer("alphas", alphas.float())
|
||||
self.register_buffer("alphas_cumprod", alphas_cumprod.float())
|
||||
self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod).float())
|
||||
self.register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod).float())
|
||||
self.register_buffer("posterior_variance", posterior_variance.float())
|
||||
self.register_buffer('posterior_log_variance_clipped', posterior_log_variance_clipped.float())
|
||||
|
||||
def _init_time_step_embedding(self):
|
||||
self.time_embed_dim = 256
|
||||
self.time_embed = nn.Sequential(
|
||||
nn.Linear(self.time_embed_dim, self.time_embed_dim),
|
||||
nn.SiLU(True),
|
||||
nn.Linear(self.time_embed_dim, self.time_embed_dim),
|
||||
)
|
||||
|
||||
def encode_first_stage(self, x, sample=True):
|
||||
with torch.no_grad():
|
||||
posterior = self.first_stage_model.encode(x) # b,4,h//8,w//8
|
||||
if sample:
|
||||
return posterior.sample().detach() * self.first_stage_scale_factor
|
||||
else:
|
||||
return posterior.mode().detach() * self.first_stage_scale_factor
|
||||
|
||||
def decode_first_stage(self, z):
|
||||
with torch.no_grad():
|
||||
z = 1. / self.first_stage_scale_factor * z
|
||||
return self.first_stage_model.decode(z)
|
||||
|
||||
def prepare(self, batch):
|
||||
# encode target
|
||||
if 'target_image' in batch:
|
||||
image_target = batch['target_image'].permute(0, 1, 4, 2, 3) # b,n,3,h,w
|
||||
N = image_target.shape[1]
|
||||
x = [self.encode_first_stage(image_target[:,ni], True) for ni in range(N)]
|
||||
x = torch.stack(x, 1) # b,n,4,h//8,w//8
|
||||
else:
|
||||
x = None
|
||||
|
||||
image_input = batch['input_image'].permute(0, 3, 1, 2)
|
||||
elevation_input = batch['input_elevation'][:, 0] # b
|
||||
x_input = self.encode_first_stage(image_input)
|
||||
input_info = {'image': image_input, 'elevation': elevation_input, 'x': x_input}
|
||||
with torch.no_grad():
|
||||
clip_embed = self.clip_image_encoder.encode(image_input)
|
||||
return x, clip_embed, input_info
|
||||
|
||||
def embed_time(self, t):
|
||||
t_embed = timestep_embedding(t, self.time_embed_dim, repeat_only=False) # B,TED
|
||||
t_embed = self.time_embed(t_embed) # B,TED
|
||||
return t_embed
|
||||
|
||||
def get_target_view_feats(self, x_input, spatial_volume, clip_embed, t_embed, v_embed, target_index):
|
||||
"""
|
||||
@param x_input: B,4,H,W
|
||||
@param spatial_volume: B,C,V,V,V
|
||||
@param clip_embed: B,1,768
|
||||
@param t_embed: B,t_dim
|
||||
@param v_embed: B,N,v_dim
|
||||
@param target_index: B,TN
|
||||
@return:
|
||||
tensors of size B*TN,*
|
||||
"""
|
||||
B, _, H, W = x_input.shape
|
||||
frustum_volume_feats, frustum_volume_depth = self.spatial_volume.construct_view_frustum_volume(spatial_volume, t_embed, v_embed, self.poses, self.Ks, target_index)
|
||||
|
||||
# clip
|
||||
TN = target_index.shape[1]
|
||||
v_embed_ = v_embed[torch.arange(B)[:,None], target_index].view(B*TN, self.viewpoint_dim) # B*TN,v_dim
|
||||
clip_embed_ = clip_embed.unsqueeze(1).repeat(1,TN,1,1).view(B*TN,1,768)
|
||||
clip_embed_ = self.cc_projection(torch.cat([clip_embed_, v_embed_.unsqueeze(1)], -1)) # B*TN,1,768
|
||||
|
||||
x_input_ = x_input.unsqueeze(1).repeat(1, TN, 1, 1, 1).view(B * TN, 4, H, W)
|
||||
|
||||
x_concat = x_input_
|
||||
return clip_embed_, frustum_volume_feats, x_concat
|
||||
|
||||
def training_step(self, batch):
|
||||
B = batch['image'].shape[0]
|
||||
time_steps = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
|
||||
|
||||
x, clip_embed, input_info = self.prepare(batch)
|
||||
x_noisy, noise = self.add_noise(x, time_steps) # B,N,4,H,W
|
||||
|
||||
N = self.view_num
|
||||
target_index = torch.randint(0, N, (B, 1), device=self.device).long() # B, 1
|
||||
v_embed = self.get_viewpoint_embedding(B, input_info['elevation']) # N,v_dim
|
||||
|
||||
t_embed = self.embed_time(time_steps)
|
||||
spatial_volume = self.spatial_volume.construct_spatial_volume(x_noisy, t_embed, v_embed, self.poses, self.Ks)
|
||||
|
||||
clip_embed, volume_feats, x_concat = self.get_target_view_feats(input_info['x'], spatial_volume, clip_embed, t_embed, v_embed, target_index)
|
||||
|
||||
x_noisy_ = x_noisy[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W
|
||||
noise_predict = self.model(x_noisy_, time_steps, clip_embed, volume_feats, x_concat, is_train=True) # B,4,H,W
|
||||
|
||||
noise_target = noise[torch.arange(B)[:,None],target_index][:,0] # B,4,H,W
|
||||
# loss simple for diffusion
|
||||
loss_simple = torch.nn.functional.mse_loss(noise_target, noise_predict, reduction='none')
|
||||
loss = loss_simple.mean()
|
||||
self.log('sim', loss_simple.mean(), prog_bar=True, logger=True, on_step=True, on_epoch=True, rank_zero_only=True)
|
||||
|
||||
# log others
|
||||
lr = self.optimizers().param_groups[0]['lr']
|
||||
self.log('lr', lr, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
|
||||
self.log("step", self.global_step, prog_bar=True, logger=True, on_step=True, on_epoch=False, rank_zero_only=True)
|
||||
return loss
|
||||
|
||||
def add_noise(self, x_start, t):
|
||||
"""
|
||||
@param x_start: B,*
|
||||
@param t: B,
|
||||
@return:
|
||||
"""
|
||||
B = x_start.shape[0]
|
||||
noise = torch.randn_like(x_start) # B,*
|
||||
|
||||
sqrt_alphas_cumprod_ = self.sqrt_alphas_cumprod[t] # B,
|
||||
sqrt_one_minus_alphas_cumprod_ = self.sqrt_one_minus_alphas_cumprod[t] # B
|
||||
sqrt_alphas_cumprod_ = sqrt_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)])
|
||||
sqrt_one_minus_alphas_cumprod_ = sqrt_one_minus_alphas_cumprod_.view(B, *[1 for _ in range(len(x_start.shape)-1)])
|
||||
x_noisy = sqrt_alphas_cumprod_ * x_start + sqrt_one_minus_alphas_cumprod_ * noise
|
||||
return x_noisy, noise
|
||||
|
||||
def sample(self, batch, cfg_scale, batch_view_num, use_ddim=True,
|
||||
return_inter_results=False, inter_interval=50, inter_view_interval=2):
|
||||
_, clip_embed, input_info = self.prepare(batch)
|
||||
if use_ddim:
|
||||
x_sample, inter = self.ddim.sample(input_info, clip_embed, unconditional_scale=cfg_scale, log_every_t=inter_interval, batch_view_num=batch_view_num)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
N = x_sample.shape[1]
|
||||
x_sample = torch.stack([self.decode_first_stage(x_sample[:, ni]) for ni in range(N)], 1)
|
||||
if return_inter_results:
|
||||
torch.cuda.synchronize()
|
||||
torch.cuda.empty_cache()
|
||||
inter = torch.stack(inter['x_inter'], 2) # # B,N,T,C,H,W
|
||||
B,N,T,C,H,W = inter.shape
|
||||
inter_results = []
|
||||
for ni in tqdm(range(0, N, inter_view_interval)):
|
||||
inter_results_ = []
|
||||
for ti in range(T):
|
||||
inter_results_.append(self.decode_first_stage(inter[:, ni, ti]))
|
||||
inter_results.append(torch.stack(inter_results_, 1)) # B,T,3,H,W
|
||||
inter_results = torch.stack(inter_results,1) # B,N,T,3,H,W
|
||||
return x_sample, inter_results
|
||||
else:
|
||||
return x_sample
|
||||
|
||||
def log_image(self, x_sample, batch, step, output_dir, only_first_row=False):
|
||||
process = lambda x: ((torch.clip(x, min=-1, max=1).cpu().numpy() * 0.5 + 0.5) * 255).astype(np.uint8)
|
||||
B = x_sample.shape[0]
|
||||
N = x_sample.shape[1]
|
||||
image_cond = []
|
||||
for bi in range(B):
|
||||
img_pr_ = concat_images_list(process(batch['ref_image'][bi]),*[process(x_sample[bi, ni].permute(1, 2, 0)) for ni in range(N)])
|
||||
img_gt_ = concat_images_list(process(batch['ref_image'][bi]),*[process(batch['image'][bi, ni]) for ni in range(N)])
|
||||
if not only_first_row or bi==0:
|
||||
image_cond.append(concat_images_list(img_gt_, img_pr_, vert=True))
|
||||
else:
|
||||
image_cond.append(img_pr_)
|
||||
|
||||
|
||||
output_dir = Path(output_dir)
|
||||
imsave(str(output_dir/f'{step}.jpg'), concat_images_list(*image_cond, vert=True))
|
||||
|
||||
@torch.no_grad()
|
||||
def validation_step(self, batch, batch_idx):
|
||||
if batch_idx==0 and self.global_rank==0:
|
||||
self.eval()
|
||||
step = self.global_step
|
||||
batch_ = {}
|
||||
for k, v in batch.items(): batch_[k] = v[:self.output_num]
|
||||
x_sample = self.sample(batch_, self.cfg_scale, self.batch_view_num)
|
||||
output_dir = Path(self.image_dir) / 'images' / 'val'
|
||||
output_dir.mkdir(exist_ok=True, parents=True)
|
||||
self.log_image(x_sample, batch, step, output_dir=output_dir)
|
||||
|
||||
def configure_optimizers(self):
|
||||
lr = self.learning_rate
|
||||
print(f'setting learning rate to {lr:.4f} ...')
|
||||
paras = []
|
||||
if self.finetune_projection:
|
||||
paras.append({"params": self.cc_projection.parameters(), "lr": lr},)
|
||||
if self.finetune_unet:
|
||||
paras.append({"params": self.model.parameters(), "lr": lr},)
|
||||
else:
|
||||
paras.append({"params": self.model.get_trainable_parameters(), "lr": lr},)
|
||||
|
||||
paras.append({"params": self.time_embed.parameters(), "lr": lr*10.0},)
|
||||
paras.append({"params": self.spatial_volume.parameters(), "lr": lr*10.0},)
|
||||
|
||||
opt = torch.optim.AdamW(paras, lr=lr)
|
||||
|
||||
scheduler = instantiate_from_config(self.scheduler_config)
|
||||
print("Setting up LambdaLR scheduler...")
|
||||
scheduler = [{'scheduler': LambdaLR(opt, lr_lambda=scheduler.schedule), 'interval': 'step', 'frequency': 1}]
|
||||
return [opt], scheduler
|
||||
|
||||
class SyncDDIMSampler:
|
||||
def __init__(self, model: SyncMultiviewDiffusion, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., latent_size=32):
|
||||
super().__init__()
|
||||
self.model = model
|
||||
self.ddpm_num_timesteps = model.num_timesteps
|
||||
self.latent_size = latent_size
|
||||
self._make_schedule(ddim_num_steps, ddim_discretize, ddim_eta)
|
||||
self.eta = ddim_eta
|
||||
|
||||
def _make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
|
||||
self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose) # DT
|
||||
ddim_timesteps_ = torch.from_numpy(self.ddim_timesteps.astype(np.int64)) # DT
|
||||
|
||||
alphas_cumprod = self.model.alphas_cumprod # T
|
||||
assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
|
||||
self.ddim_alphas = alphas_cumprod[ddim_timesteps_].double() # DT
|
||||
self.ddim_alphas_prev = torch.cat([alphas_cumprod[0:1], alphas_cumprod[ddim_timesteps_[:-1]]], 0) # DT
|
||||
self.ddim_sigmas = ddim_eta * torch.sqrt((1 - self.ddim_alphas_prev) / (1 - self.ddim_alphas) * (1 - self.ddim_alphas / self.ddim_alphas_prev))
|
||||
|
||||
self.ddim_alphas_raw = self.model.alphas[ddim_timesteps_].float() # DT
|
||||
self.ddim_sigmas = self.ddim_sigmas.float()
|
||||
self.ddim_alphas = self.ddim_alphas.float()
|
||||
self.ddim_alphas_prev = self.ddim_alphas_prev.float()
|
||||
self.ddim_sqrt_one_minus_alphas = torch.sqrt(1. - self.ddim_alphas).float()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def denoise_apply_impl(self, x_target_noisy, index, noise_pred, is_step0=False):
|
||||
"""
|
||||
@param x_target_noisy: B,N,4,H,W
|
||||
@param index: index
|
||||
@param noise_pred: B,N,4,H,W
|
||||
@param is_step0: bool
|
||||
@return:
|
||||
"""
|
||||
device = x_target_noisy.device
|
||||
B,N,_,H,W = x_target_noisy.shape
|
||||
|
||||
# apply noise
|
||||
a_t = self.ddim_alphas[index].to(device).float().view(1,1,1,1,1)
|
||||
a_prev = self.ddim_alphas_prev[index].to(device).float().view(1,1,1,1,1)
|
||||
sqrt_one_minus_at = self.ddim_sqrt_one_minus_alphas[index].to(device).float().view(1,1,1,1,1)
|
||||
sigma_t = self.ddim_sigmas[index].to(device).float().view(1,1,1,1,1)
|
||||
|
||||
pred_x0 = (x_target_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt()
|
||||
dir_xt = torch.clamp(1. - a_prev - sigma_t**2, min=1e-7).sqrt() * noise_pred
|
||||
x_prev = a_prev.sqrt() * pred_x0 + dir_xt
|
||||
if not is_step0:
|
||||
noise = sigma_t * torch.randn_like(x_target_noisy)
|
||||
x_prev = x_prev + noise
|
||||
return x_prev
|
||||
|
||||
@torch.no_grad()
|
||||
def denoise_apply(self, x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=1, is_step0=False):
|
||||
"""
|
||||
@param x_target_noisy: B,N,4,H,W
|
||||
@param input_info:
|
||||
@param clip_embed: B,M,768
|
||||
@param time_steps: B,
|
||||
@param index: int
|
||||
@param unconditional_scale:
|
||||
@param batch_view_num: int
|
||||
@param is_step0: bool
|
||||
@return:
|
||||
"""
|
||||
x_input, elevation_input = input_info['x'], input_info['elevation']
|
||||
B, N, C, H, W = x_target_noisy.shape
|
||||
|
||||
# construct source data
|
||||
v_embed = self.model.get_viewpoint_embedding(B, elevation_input) # B,N,v_dim
|
||||
t_embed = self.model.embed_time(time_steps) # B,t_dim
|
||||
spatial_volume = self.model.spatial_volume.construct_spatial_volume(x_target_noisy, t_embed, v_embed, self.model.poses, self.model.Ks)
|
||||
|
||||
e_t = []
|
||||
target_indices = torch.arange(N) # N
|
||||
for ni in range(0, N, batch_view_num):
|
||||
x_target_noisy_ = x_target_noisy[:, ni:ni + batch_view_num]
|
||||
VN = x_target_noisy_.shape[1]
|
||||
x_target_noisy_ = x_target_noisy_.reshape(B*VN,C,H,W)
|
||||
|
||||
time_steps_ = repeat_to_batch(time_steps, B, VN)
|
||||
target_indices_ = target_indices[ni:ni+batch_view_num].unsqueeze(0).repeat(B,1)
|
||||
clip_embed_, volume_feats_, x_concat_ = self.model.get_target_view_feats(x_input, spatial_volume, clip_embed, t_embed, v_embed, target_indices_)
|
||||
if unconditional_scale!=1.0:
|
||||
noise = self.model.model.predict_with_unconditional_scale(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, unconditional_scale)
|
||||
else:
|
||||
noise = self.model.model(x_target_noisy_, time_steps_, clip_embed_, volume_feats_, x_concat_, is_train=False)
|
||||
e_t.append(noise.view(B,VN,4,H,W))
|
||||
|
||||
e_t = torch.cat(e_t, 1)
|
||||
x_prev = self.denoise_apply_impl(x_target_noisy, index, e_t, is_step0)
|
||||
return x_prev
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, input_info, clip_embed, unconditional_scale=1.0, log_every_t=50, batch_view_num=1):
|
||||
"""
|
||||
@param input_info: x, elevation
|
||||
@param clip_embed: B,M,768
|
||||
@param unconditional_scale:
|
||||
@param log_every_t:
|
||||
@param batch_view_num:
|
||||
@return:
|
||||
"""
|
||||
print(f"unconditional scale {unconditional_scale:.1f}")
|
||||
C, H, W = 4, self.latent_size, self.latent_size
|
||||
B = clip_embed.shape[0]
|
||||
N = self.model.view_num
|
||||
device = self.model.device
|
||||
x_target_noisy = torch.randn([B, N, C, H, W], device=device)
|
||||
|
||||
timesteps = self.ddim_timesteps
|
||||
intermediates = {'x_inter': []}
|
||||
time_range = np.flip(timesteps)
|
||||
total_steps = timesteps.shape[0]
|
||||
|
||||
iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)
|
||||
for i, step in enumerate(iterator):
|
||||
index = total_steps - i - 1 # index in ddim state
|
||||
time_steps = torch.full((B,), step, device=device, dtype=torch.long)
|
||||
x_target_noisy = self.denoise_apply(x_target_noisy, input_info, clip_embed, time_steps, index, unconditional_scale, batch_view_num=batch_view_num, is_step0=index==0)
|
||||
if index % log_every_t == 0 or index == total_steps - 1:
|
||||
intermediates['x_inter'].append(x_target_noisy)
|
||||
|
||||
return x_target_noisy, intermediates
|
||||
@@ -0,0 +1,142 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.attention import default, zero_module, checkpoint
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.openaimodel import UNetModel
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import timestep_embedding
|
||||
|
||||
class DepthAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim, heads, dim_head, output_bias=True):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.dim_head = dim_head
|
||||
|
||||
self.to_q = nn.Conv2d(query_dim, inner_dim, 1, 1, bias=False)
|
||||
self.to_k = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
|
||||
self.to_v = nn.Conv3d(context_dim, inner_dim, 1, 1, bias=False)
|
||||
if output_bias:
|
||||
self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1)
|
||||
else:
|
||||
self.to_out = nn.Conv2d(inner_dim, query_dim, 1, 1, bias=False)
|
||||
|
||||
def forward(self, x, context):
|
||||
"""
|
||||
|
||||
@param x: b,f0,h,w
|
||||
@param context: b,f1,d,h,w
|
||||
@return:
|
||||
"""
|
||||
hn, hd = self.heads, self.dim_head
|
||||
b, _, h, w = x.shape
|
||||
b, _, d, h, w = context.shape
|
||||
|
||||
q = self.to_q(x).reshape(b,hn,hd,h,w) # b,t,h,w
|
||||
k = self.to_k(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
|
||||
v = self.to_v(context).reshape(b,hn,hd,d,h,w) # b,t,d,h,w
|
||||
|
||||
sim = torch.sum(q.unsqueeze(3) * k, 2) * self.scale # b,hn,d,h,w
|
||||
attn = sim.softmax(dim=2)
|
||||
|
||||
# b,hn,hd,d,h,w * b,hn,1,d,h,w
|
||||
out = torch.sum(v * attn.unsqueeze(2), 3) # b,hn,hd,h,w
|
||||
out = out.reshape(b,hn*hd,h,w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class DepthTransformer(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
|
||||
super().__init__()
|
||||
inner_dim = n_heads * d_head
|
||||
self.proj_in = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1),
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.SiLU(True),
|
||||
)
|
||||
self.proj_context = nn.Sequential(
|
||||
nn.Conv3d(context_dim, context_dim, 1, 1, bias=False), # no bias
|
||||
nn.GroupNorm(8, context_dim),
|
||||
nn.ReLU(True), # only relu, because we want input is 0, output is 0
|
||||
)
|
||||
self.depth_attn = DepthAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim, output_bias=False) # is a self-attention if not self.disable_self_attn
|
||||
self.proj_out = nn.Sequential(
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(inner_dim, inner_dim, 3, 1, 1, bias=False),
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
zero_module(nn.Conv2d(inner_dim, dim, 3, 1, 1, bias=False)),
|
||||
)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context):
|
||||
x_in = x
|
||||
x = self.proj_in(x)
|
||||
context = self.proj_context(context)
|
||||
x = self.depth_attn(x, context)
|
||||
x = self.proj_out(x) + x_in
|
||||
return x
|
||||
|
||||
|
||||
class DepthWiseAttention(UNetModel):
|
||||
def __init__(self, volume_dims=(5,16,32,64), *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# num_heads = 4
|
||||
model_channels = kwargs['model_channels']
|
||||
channel_mult = kwargs['channel_mult']
|
||||
d0,d1,d2,d3 = volume_dims
|
||||
|
||||
# 4
|
||||
ch = model_channels*channel_mult[2]
|
||||
self.middle_conditions = DepthTransformer(ch, 4, d3 // 2, context_dim=d3)
|
||||
|
||||
self.output_conditions=nn.ModuleList()
|
||||
self.output_b2c = {3:0,4:1,5:2,6:3,7:4,8:5,9:6,10:7,11:8}
|
||||
# 8
|
||||
ch = model_channels*channel_mult[2]
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 0
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d2 // 2, context_dim=d2)) # 1
|
||||
# 16
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 2
|
||||
ch = model_channels*channel_mult[1]
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 3
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d1 // 2, context_dim=d1)) # 4
|
||||
# 32
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 5
|
||||
ch = model_channels*channel_mult[0]
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 6
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 7
|
||||
self.output_conditions.append(DepthTransformer(ch, 4, d0 // 2, context_dim=d0)) # 8
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, source_dict=None, **kwargs):
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
|
||||
emb = self.time_embed(t_emb)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for index, module in enumerate(self.input_blocks):
|
||||
h = module(h, emb, context)
|
||||
hs.append(h)
|
||||
|
||||
h = self.middle_block(h, emb, context)
|
||||
h = self.middle_conditions(h, context=source_dict[h.shape[-1]])
|
||||
|
||||
for index, module in enumerate(self.output_blocks):
|
||||
h = torch.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
if index in self.output_b2c:
|
||||
layer = self.output_conditions[self.output_b2c[index]]
|
||||
h = layer(h, context=source_dict[h.shape[-1]])
|
||||
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
def get_trainable_parameters(self):
|
||||
paras = [para for para in self.middle_conditions.parameters()] + [para for para in self.output_conditions.parameters()]
|
||||
return paras
|
||||
@@ -0,0 +1,186 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
class Image2DResBlockWithTV(nn.Module):
|
||||
def __init__(self, dim, tdim, vdim):
|
||||
super().__init__()
|
||||
norm = lambda c: nn.GroupNorm(8, c)
|
||||
self.time_embed = nn.Conv2d(tdim, dim, 1, 1)
|
||||
self.view_embed = nn.Conv2d(vdim, dim, 1, 1)
|
||||
self.conv = nn.Sequential(
|
||||
norm(dim),
|
||||
nn.SiLU(True),
|
||||
nn.Conv2d(dim, dim, 3, 1, 1),
|
||||
norm(dim),
|
||||
nn.SiLU(True),
|
||||
nn.Conv2d(dim, dim, 3, 1, 1),
|
||||
)
|
||||
|
||||
def forward(self, x, t, v):
|
||||
return x+self.conv(x+self.time_embed(t)+self.view_embed(v))
|
||||
|
||||
|
||||
class NoisyTargetViewEncoder(nn.Module):
|
||||
def __init__(self, time_embed_dim, viewpoint_dim, run_dim=16, output_dim=8):
|
||||
super().__init__()
|
||||
|
||||
self.init_conv = nn.Conv2d(4, run_dim, 3, 1, 1)
|
||||
self.out_conv0 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
|
||||
self.out_conv1 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
|
||||
self.out_conv2 = Image2DResBlockWithTV(run_dim, time_embed_dim, viewpoint_dim)
|
||||
self.final_out = nn.Sequential(
|
||||
nn.GroupNorm(8, run_dim),
|
||||
nn.SiLU(True),
|
||||
nn.Conv2d(run_dim, output_dim, 3, 1, 1)
|
||||
)
|
||||
|
||||
def forward(self, x, t, v):
|
||||
B, DT = t.shape
|
||||
t = t.view(B, DT, 1, 1)
|
||||
B, DV = v.shape
|
||||
v = v.view(B, DV, 1, 1)
|
||||
|
||||
x = self.init_conv(x)
|
||||
x = self.out_conv0(x, t, v)
|
||||
x = self.out_conv1(x, t, v)
|
||||
x = self.out_conv2(x, t, v)
|
||||
x = self.final_out(x)
|
||||
return x
|
||||
|
||||
class SpatialUpTimeBlock(nn.Module):
|
||||
def __init__(self, x_in_dim, t_in_dim, out_dim):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
|
||||
self.norm = norm_act(x_in_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.ConvTranspose3d(x_in_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
|
||||
|
||||
def forward(self, x, t):
|
||||
x = x + self.t_conv(t)
|
||||
return self.conv(self.silu(self.norm(x)))
|
||||
|
||||
class SpatialTimeBlock(nn.Module):
|
||||
def __init__(self, x_in_dim, t_in_dim, out_dim, stride):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
self.t_conv = nn.Conv3d(t_in_dim, x_in_dim, 1, 1) # 16
|
||||
self.bn = norm_act(x_in_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.Conv3d(x_in_dim, out_dim, 3, stride=stride, padding=1)
|
||||
|
||||
def forward(self, x, t):
|
||||
x = x + self.t_conv(t)
|
||||
return self.conv(self.silu(self.bn(x)))
|
||||
|
||||
class SpatialTime3DNet(nn.Module):
|
||||
def __init__(self, time_dim=256, input_dim=128, dims=(32, 64, 128, 256)):
|
||||
super().__init__()
|
||||
d0, d1, d2, d3 = dims
|
||||
dt = time_dim
|
||||
|
||||
self.init_conv = nn.Conv3d(input_dim, d0, 3, 1, 1) # 32
|
||||
self.conv0 = SpatialTimeBlock(d0, dt, d0, stride=1)
|
||||
|
||||
self.conv1 = SpatialTimeBlock(d0, dt, d1, stride=2)
|
||||
self.conv2_0 = SpatialTimeBlock(d1, dt, d1, stride=1)
|
||||
self.conv2_1 = SpatialTimeBlock(d1, dt, d1, stride=1)
|
||||
|
||||
self.conv3 = SpatialTimeBlock(d1, dt, d2, stride=2)
|
||||
self.conv4_0 = SpatialTimeBlock(d2, dt, d2, stride=1)
|
||||
self.conv4_1 = SpatialTimeBlock(d2, dt, d2, stride=1)
|
||||
|
||||
self.conv5 = SpatialTimeBlock(d2, dt, d3, stride=2)
|
||||
self.conv6_0 = SpatialTimeBlock(d3, dt, d3, stride=1)
|
||||
self.conv6_1 = SpatialTimeBlock(d3, dt, d3, stride=1)
|
||||
|
||||
self.conv7 = SpatialUpTimeBlock(d3, dt, d2)
|
||||
self.conv8 = SpatialUpTimeBlock(d2, dt, d1)
|
||||
self.conv9 = SpatialUpTimeBlock(d1, dt, d0)
|
||||
|
||||
def forward(self, x, t):
|
||||
B, C = t.shape
|
||||
t = t.view(B, C, 1, 1, 1)
|
||||
|
||||
x = self.init_conv(x)
|
||||
conv0 = self.conv0(x, t)
|
||||
|
||||
x = self.conv1(conv0, t)
|
||||
x = self.conv2_0(x, t)
|
||||
conv2 = self.conv2_1(x, t)
|
||||
|
||||
x = self.conv3(conv2, t)
|
||||
x = self.conv4_0(x, t)
|
||||
conv4 = self.conv4_1(x, t)
|
||||
|
||||
x = self.conv5(conv4, t)
|
||||
x = self.conv6_0(x, t)
|
||||
x = self.conv6_1(x, t)
|
||||
|
||||
x = conv4 + self.conv7(x, t)
|
||||
x = conv2 + self.conv8(x, t)
|
||||
x = conv0 + self.conv9(x, t)
|
||||
return x
|
||||
|
||||
class FrustumTVBlock(nn.Module):
|
||||
def __init__(self, x_dim, t_dim, v_dim, out_dim, stride):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
|
||||
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
|
||||
self.bn = norm_act(x_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.Conv3d(x_dim, out_dim, 3, stride=stride, padding=1)
|
||||
|
||||
def forward(self, x, t, v):
|
||||
x = x + self.t_conv(t) + self.v_conv(v)
|
||||
return self.conv(self.silu(self.bn(x)))
|
||||
|
||||
class FrustumTVUpBlock(nn.Module):
|
||||
def __init__(self, x_dim, t_dim, v_dim, out_dim):
|
||||
super().__init__()
|
||||
norm_act = lambda c: nn.GroupNorm(8, c)
|
||||
self.t_conv = nn.Conv3d(t_dim, x_dim, 1, 1) # 16
|
||||
self.v_conv = nn.Conv3d(v_dim, x_dim, 1, 1) # 16
|
||||
self.norm = norm_act(x_dim)
|
||||
self.silu = nn.SiLU(True)
|
||||
self.conv = nn.ConvTranspose3d(x_dim, out_dim, kernel_size=3, padding=1, output_padding=1, stride=2)
|
||||
|
||||
def forward(self, x, t, v):
|
||||
x = x + self.t_conv(t) + self.v_conv(v)
|
||||
return self.conv(self.silu(self.norm(x)))
|
||||
|
||||
class FrustumTV3DNet(nn.Module):
|
||||
def __init__(self, in_dim, t_dim, v_dim, dims=(32, 64, 128, 256)):
|
||||
super().__init__()
|
||||
self.conv0 = nn.Conv3d(in_dim, dims[0], 3, 1, 1) # 32
|
||||
|
||||
self.conv1 = FrustumTVBlock(dims[0], t_dim, v_dim, dims[1], 2)
|
||||
self.conv2 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[1], 1)
|
||||
|
||||
self.conv3 = FrustumTVBlock(dims[1], t_dim, v_dim, dims[2], 2)
|
||||
self.conv4 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[2], 1)
|
||||
|
||||
self.conv5 = FrustumTVBlock(dims[2], t_dim, v_dim, dims[3], 2)
|
||||
self.conv6 = FrustumTVBlock(dims[3], t_dim, v_dim, dims[3], 1)
|
||||
|
||||
self.up0 = FrustumTVUpBlock(dims[3], t_dim, v_dim, dims[2])
|
||||
self.up1 = FrustumTVUpBlock(dims[2], t_dim, v_dim, dims[1])
|
||||
self.up2 = FrustumTVUpBlock(dims[1], t_dim, v_dim, dims[0])
|
||||
|
||||
def forward(self, x, t, v):
|
||||
B,DT = t.shape
|
||||
t = t.view(B,DT,1,1,1)
|
||||
B,DV = v.shape
|
||||
v = v.view(B,DV,1,1,1)
|
||||
|
||||
b, _, d, h, w = x.shape
|
||||
x0 = self.conv0(x)
|
||||
x1 = self.conv2(self.conv1(x0, t, v), t, v)
|
||||
x2 = self.conv4(self.conv3(x1, t, v), t, v)
|
||||
x3 = self.conv6(self.conv5(x2, t, v), t, v)
|
||||
|
||||
x2 = self.up0(x3, t, v) + x2
|
||||
x1 = self.up1(x2, t, v) + x1
|
||||
x0 = self.up2(x1, t, v) + x0
|
||||
return {w: x0, w//2: x1, w//4: x2, w//8: x3}
|
||||
@@ -0,0 +1,103 @@
|
||||
import torch
|
||||
from kornia import create_meshgrid
|
||||
|
||||
|
||||
def project_and_normalize(ref_grid, src_proj, length):
|
||||
"""
|
||||
|
||||
@param ref_grid: b 3 n
|
||||
@param src_proj: b 4 4
|
||||
@param length: int
|
||||
@return: b, n, 2
|
||||
"""
|
||||
src_grid = src_proj[:, :3, :3] @ ref_grid + src_proj[:, :3, 3:] # b 3 n
|
||||
div_val = src_grid[:, -1:]
|
||||
div_val[div_val<1e-4] = 1e-4
|
||||
src_grid = src_grid[:, :2] / div_val # divide by depth (b, 2, n)
|
||||
src_grid[:, 0] = src_grid[:, 0]/((length - 1) / 2) - 1 # scale to -1~1
|
||||
src_grid[:, 1] = src_grid[:, 1]/((length - 1) / 2) - 1 # scale to -1~1
|
||||
src_grid = src_grid.permute(0, 2, 1) # (b, n, 2)
|
||||
return src_grid
|
||||
|
||||
|
||||
def construct_project_matrix(x_ratio, y_ratio, Ks, poses):
|
||||
"""
|
||||
@param x_ratio: float
|
||||
@param y_ratio: float
|
||||
@param Ks: b,3,3
|
||||
@param poses: b,3,4
|
||||
@return:
|
||||
"""
|
||||
rfn = Ks.shape[0]
|
||||
scale_m = torch.tensor([x_ratio, y_ratio, 1.0], dtype=torch.float32, device=Ks.device)
|
||||
scale_m = torch.diag(scale_m)
|
||||
ref_prj = scale_m[None, :, :] @ Ks @ poses # rfn,3,4
|
||||
pad_vals = torch.zeros([rfn, 1, 4], dtype=torch.float32, device=ref_prj.device)
|
||||
pad_vals[:, :, 3] = 1.0
|
||||
ref_prj = torch.cat([ref_prj, pad_vals], 1) # rfn,4,4
|
||||
return ref_prj
|
||||
|
||||
def get_warp_coordinates(volume_xyz, warp_size, input_size, Ks, warp_pose):
|
||||
B, _, D, H, W = volume_xyz.shape
|
||||
ratio = warp_size / input_size
|
||||
warp_proj = construct_project_matrix(ratio, ratio, Ks, warp_pose) # B,4,4
|
||||
warp_coords = project_and_normalize(volume_xyz.view(B,3,D*H*W), warp_proj, warp_size).view(B, D, H, W, 2)
|
||||
return warp_coords
|
||||
|
||||
|
||||
def create_target_volume(depth_size, volume_size, input_image_size, pose_target, K, near=None, far=None):
|
||||
device, dtype = pose_target.device, pose_target.dtype
|
||||
|
||||
# compute a depth range on the unit sphere
|
||||
H, W, D, B = volume_size, volume_size, depth_size, pose_target.shape[0]
|
||||
if near is not None and far is not None :
|
||||
# near, far b,1,h,w
|
||||
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
|
||||
depth_values = depth_values.view(1, D, 1, 1) # 1,d,1,1
|
||||
depth_values = depth_values * (far - near) + near # b d h w
|
||||
depth_values = depth_values.view(B, 1, D, H * W)
|
||||
else:
|
||||
near, far = near_far_from_unit_sphere_using_camera_poses(pose_target) # b 1
|
||||
depth_values = torch.linspace(0, 1, steps=depth_size).to(near.device).to(near.dtype) # d
|
||||
depth_values = depth_values[None,:,None] * (far[:,None,:] - near[:,None,:]) + near[:,None,:] # b d 1
|
||||
depth_values = depth_values.view(B, 1, D, 1).expand(B, 1, D, H*W)
|
||||
|
||||
ratio = volume_size / input_image_size
|
||||
|
||||
# creat a grid on the target (reference) view
|
||||
# H, W, D, B = volume_size, volume_size, depth_values.shape[1], depth_values.shape[0]
|
||||
|
||||
# creat mesh grid: note reference also means target
|
||||
ref_grid = create_meshgrid(H, W, normalized_coordinates=False) # (1, H, W, 2)
|
||||
ref_grid = ref_grid.to(device).to(dtype)
|
||||
ref_grid = ref_grid.permute(0, 3, 1, 2) # (1, 2, H, W)
|
||||
ref_grid = ref_grid.reshape(1, 2, H*W) # (1, 2, H*W)
|
||||
ref_grid = ref_grid.expand(B, -1, -1) # (B, 2, H*W)
|
||||
ref_grid = torch.cat((ref_grid, torch.ones(B, 1, H*W, dtype=ref_grid.dtype, device=ref_grid.device)), dim=1) # (B, 3, H*W)
|
||||
ref_grid = ref_grid.unsqueeze(2) * depth_values # (B, 3, D, H*W)
|
||||
|
||||
# unproject to space and transfer to world coordinates.
|
||||
Ks = K
|
||||
ref_proj = construct_project_matrix(ratio, ratio, Ks, pose_target) # B,4,4
|
||||
ref_proj_inv = torch.inverse(ref_proj) # B,4,4
|
||||
ref_grid = ref_proj_inv[:,:3,:3] @ ref_grid.view(B,3,D*H*W) + ref_proj_inv[:,:3,3:] # B,3,3 @ B,3,DHW + B,3,1 => B,3,DHW
|
||||
return ref_grid.reshape(B,3,D,H,W), depth_values.view(B,1,D,H,W)
|
||||
|
||||
def near_far_from_unit_sphere_using_camera_poses(camera_poses):
|
||||
"""
|
||||
@param camera_poses: b 3 4
|
||||
@return:
|
||||
near: b,1
|
||||
far: b,1
|
||||
"""
|
||||
R_w2c = camera_poses[..., :3, :3] # b 3 3
|
||||
t_w2c = camera_poses[..., :3, 3:] # b 3 1
|
||||
camera_origin = -R_w2c.permute(0,2,1) @ t_w2c # b 3 1
|
||||
# R_w2c.T @ (0,0,1) = z_dir
|
||||
camera_orient = R_w2c.permute(0,2,1)[...,:3,2:3] # b 3 1
|
||||
camera_origin, camera_orient = camera_origin[...,0], camera_orient[..., 0] # b 3
|
||||
a = torch.sum(camera_orient ** 2, dim=-1, keepdim=True) # b 1
|
||||
b = -torch.sum(camera_orient * camera_origin, dim=-1, keepdim=True) # b 1
|
||||
mid = b / a # b 1
|
||||
near, far = mid - 1.0, mid + 1.0
|
||||
return near, far
|
||||
336
modelscope/models/cv/image_to_3d/ldm/modules/attention.py
Normal file
336
modelscope/models/cv/image_to_3d/ldm/modules/attention.py
Normal file
@@ -0,0 +1,336 @@
|
||||
from inspect import isfunction
|
||||
import math
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, einsum
|
||||
from einops import rearrange, repeat
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import checkpoint
|
||||
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def uniq(arr):
|
||||
return{el: True for el in arr}.keys()
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def max_neg_value(t):
|
||||
return -torch.finfo(t.dtype).max
|
||||
|
||||
|
||||
def init_(tensor):
|
||||
dim = tensor.shape[-1]
|
||||
std = 1 / math.sqrt(dim)
|
||||
tensor.uniform_(-std, std)
|
||||
return tensor
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
# feedforward
|
||||
class ConvGEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim_in, dim_out * 2, 1, 1, 0)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def Normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x)
|
||||
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
|
||||
k = k.softmax(dim=-1)
|
||||
context = torch.einsum('bhdn,bhen->bhde', k, v)
|
||||
out = torch.einsum('bhde,bhdn->bhen', context, q)
|
||||
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class SpatialSelfAttention(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = rearrange(q, 'b c h w -> b (h w) c')
|
||||
k = rearrange(k, 'b c h w -> b c (h w)')
|
||||
w_ = torch.einsum('bij,bjk->bik', q, k)
|
||||
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = rearrange(v, 'b c h w -> b c (h w)')
|
||||
w_ = rearrange(w_, 'b i j -> b j i')
|
||||
h_ = torch.einsum('bij,bjk->bik', v, w_)
|
||||
h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class CrossAttention(nn.Module):
|
||||
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = default(context_dim, query_dim)
|
||||
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(
|
||||
nn.Linear(inner_dim, query_dim),
|
||||
nn.Dropout(dropout)
|
||||
)
|
||||
|
||||
def forward(self, x, context=None, mask=None):
|
||||
h = self.heads
|
||||
|
||||
q = self.to_q(x)
|
||||
context = default(context, x)
|
||||
k = self.to_k(context)
|
||||
v = self.to_v(context)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
|
||||
|
||||
sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
|
||||
|
||||
if exists(mask):
|
||||
mask = mask>0
|
||||
mask = rearrange(mask, 'b ... -> b (...)')
|
||||
max_neg_value = -torch.finfo(sim.dtype).max
|
||||
mask = repeat(mask, 'b j -> (b h) () j', h=h)
|
||||
sim.masked_fill_(~mask, max_neg_value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum('b i j, b j d -> b i d', attn, v)
|
||||
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
|
||||
return self.to_out(out)
|
||||
|
||||
class BasicSpatialTransformer(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, context_dim=None, checkpoint=True):
|
||||
super().__init__()
|
||||
inner_dim = n_heads * d_head
|
||||
self.proj_in = nn.Sequential(
|
||||
nn.GroupNorm(8, dim),
|
||||
nn.Conv2d(dim, inner_dim, kernel_size=1, stride=1, padding=0),
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
)
|
||||
self.attn = CrossAttention(query_dim=inner_dim, heads=n_heads, dim_head=d_head, context_dim=context_dim) # is a self-attention if not self.disable_self_attn
|
||||
self.out_conv = nn.Sequential(
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(inner_dim, inner_dim, 1, 1),
|
||||
)
|
||||
self.proj_out = nn.Sequential(
|
||||
nn.GroupNorm(8, inner_dim),
|
||||
nn.ReLU(True),
|
||||
zero_module(nn.Conv2d(inner_dim, dim, kernel_size=1, stride=1, padding=0)),
|
||||
)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context):
|
||||
# input
|
||||
b,_,h,w = x.shape
|
||||
x_in = x
|
||||
x = self.proj_in(x)
|
||||
|
||||
# attention
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
context = rearrange(context, 'b c h w -> b (h w) c').contiguous()
|
||||
x = self.attn(x, context) + x
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
|
||||
# output
|
||||
x = self.out_conv(x) + x
|
||||
x = self.proj_out(x) + x_in
|
||||
return x
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, disable_self_attn=False):
|
||||
super().__init__()
|
||||
self.disable_self_attn = disable_self_attn
|
||||
self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout,
|
||||
context_dim=context_dim if self.disable_self_attn else None) # is a self-attention if not self.disable_self_attn
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
|
||||
heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def forward(self, x, context=None):
|
||||
return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
|
||||
|
||||
def _forward(self, x, context=None):
|
||||
x = self.attn1(self.norm1(x), context=context if self.disable_self_attn else None) + x
|
||||
x = self.attn2(self.norm2(x), context=context) + x
|
||||
x = self.ff(self.norm3(x)) + x
|
||||
return x
|
||||
|
||||
class ConvFeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Conv2d(dim, inner_dim, 1, 1, 0),
|
||||
nn.GELU()
|
||||
) if not glu else ConvGEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Conv2d(inner_dim, dim_out, 1, 1, 0)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data.
|
||||
First, project the input (aka embedding)
|
||||
and reshape to b, t, d.
|
||||
Then apply standard transformer action.
|
||||
Finally, reshape to image
|
||||
"""
|
||||
def __init__(self, in_channels, n_heads, d_head,
|
||||
depth=1, dropout=0., context_dim=None,
|
||||
disable_self_attn=False):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = Normalize(in_channels)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels,
|
||||
inner_dim,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim,
|
||||
disable_self_attn=disable_self_attn)
|
||||
for d in range(depth)]
|
||||
)
|
||||
|
||||
self.proj_out = zero_module(nn.Conv2d(inner_dim,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0))
|
||||
|
||||
def forward(self, x, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
b, c, h, w = x.shape
|
||||
x_in = x
|
||||
x = self.norm(x)
|
||||
x = self.proj_in(x)
|
||||
x = rearrange(x, 'b c h w -> b (h w) c').contiguous()
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, context=context)
|
||||
x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous()
|
||||
x = self.proj_out(x)
|
||||
return x + x_in
|
||||
@@ -0,0 +1,835 @@
|
||||
# pytorch_diffusion + derived encoder decoder
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import rearrange
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.attention import LinearAttention
|
||||
|
||||
|
||||
def get_timestep_embedding(timesteps, embedding_dim):
|
||||
"""
|
||||
This matches the implementation in Denoising Diffusion Probabilistic Models:
|
||||
From Fairseq.
|
||||
Build sinusoidal embeddings.
|
||||
This matches the implementation in tensor2tensor, but differs slightly
|
||||
from the description in Section 3.5 of "Attention Is All You Need".
|
||||
"""
|
||||
assert len(timesteps.shape) == 1
|
||||
|
||||
half_dim = embedding_dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
|
||||
emb = emb.to(device=timesteps.device)
|
||||
emb = timesteps.float()[:, None] * emb[None, :]
|
||||
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
|
||||
if embedding_dim % 2 == 1: # zero pad
|
||||
emb = torch.nn.functional.pad(emb, (0,1,0,0))
|
||||
return emb
|
||||
|
||||
|
||||
def nonlinearity(x):
|
||||
# swish
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
def Normalize(in_channels, num_groups=32):
|
||||
return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
if self.with_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels, with_conv):
|
||||
super().__init__()
|
||||
self.with_conv = with_conv
|
||||
if self.with_conv:
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
if self.with_conv:
|
||||
pad = (0,1,0,1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
else:
|
||||
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
||||
dropout, temb_channels=512):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
out_channels = in_channels if out_channels is None else out_channels
|
||||
self.out_channels = out_channels
|
||||
self.use_conv_shortcut = conv_shortcut
|
||||
|
||||
self.norm1 = Normalize(in_channels)
|
||||
self.conv1 = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if temb_channels > 0:
|
||||
self.temb_proj = torch.nn.Linear(temb_channels,
|
||||
out_channels)
|
||||
self.norm2 = Normalize(out_channels)
|
||||
self.dropout = torch.nn.Dropout(dropout)
|
||||
self.conv2 = torch.nn.Conv2d(out_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
else:
|
||||
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
x = self.conv_shortcut(x)
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x+h
|
||||
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
def __init__(self, in_channels):
|
||||
super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = Normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.k = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.v = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.proj_out = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b,c,h,w = q.shape
|
||||
q = q.reshape(b,c,h*w)
|
||||
q = q.permute(0,2,1) # b,hw,c
|
||||
k = k.reshape(b,c,h*w) # b,c,hw
|
||||
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = torch.nn.functional.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b,c,h*w)
|
||||
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
|
||||
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
||||
h_ = h_.reshape(b,c,h,w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_type="vanilla"):
|
||||
assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown'
|
||||
print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
|
||||
if attn_type == "vanilla":
|
||||
return AttnBlock(in_channels)
|
||||
elif attn_type == "none":
|
||||
return nn.Identity(in_channels)
|
||||
else:
|
||||
return LinAttnBlock(in_channels)
|
||||
|
||||
|
||||
class Model(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, use_timestep=True, use_linear_attn=False, attn_type="vanilla"):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = self.ch*4
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.use_timestep = use_timestep
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
self.temb = nn.Module()
|
||||
self.temb.dense = nn.ModuleList([
|
||||
torch.nn.Linear(self.ch,
|
||||
self.temb_ch),
|
||||
torch.nn.Linear(self.temb_ch,
|
||||
self.temb_ch),
|
||||
])
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
skip_in = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
if i_block == self.num_res_blocks:
|
||||
skip_in = ch*in_ch_mult[i_level]
|
||||
block.append(ResnetBlock(in_channels=block_in+skip_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, t=None, context=None):
|
||||
#assert x.shape[2] == x.shape[3] == self.resolution
|
||||
if context is not None:
|
||||
# assume aligned context, cat along channel axis
|
||||
x = torch.cat((x, context), dim=1)
|
||||
if self.use_timestep:
|
||||
# timestep embedding
|
||||
assert t is not None
|
||||
temb = get_timestep_embedding(t, self.ch)
|
||||
temb = self.temb.dense[0](temb)
|
||||
temb = nonlinearity(temb)
|
||||
temb = self.temb.dense[1](temb)
|
||||
else:
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](
|
||||
torch.cat([h, hs.pop()], dim=1), temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
def get_last_layer(self):
|
||||
return self.conv_out.weight
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla",
|
||||
**ignore_kwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
|
||||
# downsampling
|
||||
self.conv_in = torch.nn.Conv2d(in_channels,
|
||||
self.ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
curr_res = resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
self.in_ch_mult = in_ch_mult
|
||||
self.down = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_in = ch*in_ch_mult[i_level]
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
down = nn.Module()
|
||||
down.block = block
|
||||
down.attn = attn
|
||||
if i_level != self.num_resolutions-1:
|
||||
down.downsample = Downsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res // 2
|
||||
self.down.append(down)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
2*z_channels if double_z else z_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# downsampling
|
||||
hs = [self.conv_in(x)]
|
||||
for i_level in range(self.num_resolutions):
|
||||
for i_block in range(self.num_res_blocks):
|
||||
h = self.down[i_level].block[i_block](hs[-1], temb)
|
||||
if len(self.down[i_level].attn) > 0:
|
||||
h = self.down[i_level].attn[i_block](h)
|
||||
hs.append(h)
|
||||
if i_level != self.num_resolutions-1:
|
||||
hs.append(self.down[i_level].downsample(hs[-1]))
|
||||
|
||||
# middle
|
||||
h = hs[-1]
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels,
|
||||
resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False,
|
||||
attn_type="vanilla", **ignorekwargs):
|
||||
super().__init__()
|
||||
if use_linear_attn: attn_type = "linear"
|
||||
self.ch = ch
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.in_channels = in_channels
|
||||
self.give_pre_end = give_pre_end
|
||||
self.tanh_out = tanh_out
|
||||
|
||||
# compute in_ch_mult, block_in and curr_res at lowest res
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
block_in = ch*ch_mult[self.num_resolutions-1]
|
||||
curr_res = resolution // 2**(self.num_resolutions-1)
|
||||
self.z_shape = (1,z_channels,curr_res,curr_res)
|
||||
print("Working with z of shape {} = {} dimensions.".format(
|
||||
self.z_shape, np.prod(self.z_shape)))
|
||||
|
||||
# z to block_in
|
||||
self.conv_in = torch.nn.Conv2d(z_channels,
|
||||
block_in,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
# middle
|
||||
self.mid = nn.Module()
|
||||
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
|
||||
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_in,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout)
|
||||
|
||||
# upsampling
|
||||
self.up = nn.ModuleList()
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
block = nn.ModuleList()
|
||||
attn = nn.ModuleList()
|
||||
block_out = ch*ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
if curr_res in attn_resolutions:
|
||||
attn.append(make_attn(block_in, attn_type=attn_type))
|
||||
up = nn.Module()
|
||||
up.block = block
|
||||
up.attn = attn
|
||||
if i_level != 0:
|
||||
up.upsample = Upsample(block_in, resamp_with_conv)
|
||||
curr_res = curr_res * 2
|
||||
self.up.insert(0, up) # prepend to get consistent order
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_ch,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, z):
|
||||
#assert z.shape[1:] == self.z_shape[1:]
|
||||
self.last_z_shape = z.shape
|
||||
|
||||
# timestep embedding
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# upsampling
|
||||
for i_level in reversed(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
class SimpleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, *args, **kwargs):
|
||||
super().__init__()
|
||||
self.model = nn.ModuleList([nn.Conv2d(in_channels, in_channels, 1),
|
||||
ResnetBlock(in_channels=in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=2 * in_channels,
|
||||
out_channels=4 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
ResnetBlock(in_channels=4 * in_channels,
|
||||
out_channels=2 * in_channels,
|
||||
temb_channels=0, dropout=0.0),
|
||||
nn.Conv2d(2*in_channels, in_channels, 1),
|
||||
Upsample(in_channels, with_conv=True)])
|
||||
# end
|
||||
self.norm_out = Normalize(in_channels)
|
||||
self.conv_out = torch.nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
for i, layer in enumerate(self.model):
|
||||
if i in [1,2,3]:
|
||||
x = layer(x, None)
|
||||
else:
|
||||
x = layer(x)
|
||||
|
||||
h = self.norm_out(x)
|
||||
h = nonlinearity(h)
|
||||
x = self.conv_out(h)
|
||||
return x
|
||||
|
||||
|
||||
class UpsampleDecoder(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, ch, num_res_blocks, resolution,
|
||||
ch_mult=(2,2), dropout=0.0):
|
||||
super().__init__()
|
||||
# upsampling
|
||||
self.temb_ch = 0
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
block_in = in_channels
|
||||
curr_res = resolution // 2 ** (self.num_resolutions - 1)
|
||||
self.res_blocks = nn.ModuleList()
|
||||
self.upsample_blocks = nn.ModuleList()
|
||||
for i_level in range(self.num_resolutions):
|
||||
res_block = []
|
||||
block_out = ch * ch_mult[i_level]
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
res_block.append(ResnetBlock(in_channels=block_in,
|
||||
out_channels=block_out,
|
||||
temb_channels=self.temb_ch,
|
||||
dropout=dropout))
|
||||
block_in = block_out
|
||||
self.res_blocks.append(nn.ModuleList(res_block))
|
||||
if i_level != self.num_resolutions - 1:
|
||||
self.upsample_blocks.append(Upsample(block_in, True))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
# end
|
||||
self.norm_out = Normalize(block_in)
|
||||
self.conv_out = torch.nn.Conv2d(block_in,
|
||||
out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
# upsampling
|
||||
h = x
|
||||
for k, i_level in enumerate(range(self.num_resolutions)):
|
||||
for i_block in range(self.num_res_blocks + 1):
|
||||
h = self.res_blocks[i_level][i_block](h, None)
|
||||
if i_level != self.num_resolutions - 1:
|
||||
h = self.upsample_blocks[k](h)
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
||||
class LatentRescaler(nn.Module):
|
||||
def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
|
||||
super().__init__()
|
||||
# residual block, interpolate, residual block
|
||||
self.factor = factor
|
||||
self.conv_in = nn.Conv2d(in_channels,
|
||||
mid_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1)
|
||||
self.res_block1 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
self.attn = AttnBlock(mid_channels)
|
||||
self.res_block2 = nn.ModuleList([ResnetBlock(in_channels=mid_channels,
|
||||
out_channels=mid_channels,
|
||||
temb_channels=0,
|
||||
dropout=0.0) for _ in range(depth)])
|
||||
|
||||
self.conv_out = nn.Conv2d(mid_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_in(x)
|
||||
for block in self.res_block1:
|
||||
x = block(x, None)
|
||||
x = torch.nn.functional.interpolate(x, size=(int(round(x.shape[2]*self.factor)), int(round(x.shape[3]*self.factor))))
|
||||
x = self.attn(x)
|
||||
for block in self.res_block2:
|
||||
x = block(x, None)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleEncoder(nn.Module):
|
||||
def __init__(self, in_channels, ch, resolution, out_ch, num_res_blocks,
|
||||
attn_resolutions, dropout=0.0, resamp_with_conv=True,
|
||||
ch_mult=(1,2,4,8), rescale_factor=1.0, rescale_module_depth=1):
|
||||
super().__init__()
|
||||
intermediate_chn = ch * ch_mult[-1]
|
||||
self.encoder = Encoder(in_channels=in_channels, num_res_blocks=num_res_blocks, ch=ch, ch_mult=ch_mult,
|
||||
z_channels=intermediate_chn, double_z=False, resolution=resolution,
|
||||
attn_resolutions=attn_resolutions, dropout=dropout, resamp_with_conv=resamp_with_conv,
|
||||
out_ch=None)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=intermediate_chn,
|
||||
mid_channels=intermediate_chn, out_channels=out_ch, depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
x = self.rescaler(x)
|
||||
return x
|
||||
|
||||
|
||||
class MergedRescaleDecoder(nn.Module):
|
||||
def __init__(self, z_channels, out_ch, resolution, num_res_blocks, attn_resolutions, ch, ch_mult=(1,2,4,8),
|
||||
dropout=0.0, resamp_with_conv=True, rescale_factor=1.0, rescale_module_depth=1):
|
||||
super().__init__()
|
||||
tmp_chn = z_channels*ch_mult[-1]
|
||||
self.decoder = Decoder(out_ch=out_ch, z_channels=tmp_chn, attn_resolutions=attn_resolutions, dropout=dropout,
|
||||
resamp_with_conv=resamp_with_conv, in_channels=None, num_res_blocks=num_res_blocks,
|
||||
ch_mult=ch_mult, resolution=resolution, ch=ch)
|
||||
self.rescaler = LatentRescaler(factor=rescale_factor, in_channels=z_channels, mid_channels=tmp_chn,
|
||||
out_channels=tmp_chn, depth=rescale_module_depth)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsampler(nn.Module):
|
||||
def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
|
||||
super().__init__()
|
||||
assert out_size >= in_size
|
||||
num_blocks = int(np.log2(out_size//in_size))+1
|
||||
factor_up = 1.+ (out_size % in_size)
|
||||
print(f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}")
|
||||
self.rescaler = LatentRescaler(factor=factor_up, in_channels=in_channels, mid_channels=2*in_channels,
|
||||
out_channels=in_channels)
|
||||
self.decoder = Decoder(out_ch=out_channels, resolution=out_size, z_channels=in_channels, num_res_blocks=2,
|
||||
attn_resolutions=[], in_channels=None, ch=in_channels,
|
||||
ch_mult=[ch_mult for _ in range(num_blocks)])
|
||||
|
||||
def forward(self, x):
|
||||
x = self.rescaler(x)
|
||||
x = self.decoder(x)
|
||||
return x
|
||||
|
||||
|
||||
class Resize(nn.Module):
|
||||
def __init__(self, in_channels=None, learned=False, mode="bilinear"):
|
||||
super().__init__()
|
||||
self.with_conv = learned
|
||||
self.mode = mode
|
||||
if self.with_conv:
|
||||
print(f"Note: {self.__class__.__name} uses learned downsampling and will ignore the fixed {mode} mode")
|
||||
raise NotImplementedError()
|
||||
assert in_channels is not None
|
||||
# no asymmetric padding in torch conv, must do it ourselves
|
||||
self.conv = torch.nn.Conv2d(in_channels,
|
||||
in_channels,
|
||||
kernel_size=4,
|
||||
stride=2,
|
||||
padding=1)
|
||||
|
||||
def forward(self, x, scale_factor=1.0):
|
||||
if scale_factor==1.0:
|
||||
return x
|
||||
else:
|
||||
x = torch.nn.functional.interpolate(x, mode=self.mode, align_corners=False, scale_factor=scale_factor)
|
||||
return x
|
||||
|
||||
class FirstStagePostProcessor(nn.Module):
|
||||
|
||||
def __init__(self, ch_mult:list, in_channels,
|
||||
pretrained_model:nn.Module=None,
|
||||
reshape=False,
|
||||
n_channels=None,
|
||||
dropout=0.,
|
||||
pretrained_config=None):
|
||||
super().__init__()
|
||||
if pretrained_config is None:
|
||||
assert pretrained_model is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
||||
self.pretrained_model = pretrained_model
|
||||
else:
|
||||
assert pretrained_config is not None, 'Either "pretrained_model" or "pretrained_config" must not be None'
|
||||
self.instantiate_pretrained(pretrained_config)
|
||||
|
||||
self.do_reshape = reshape
|
||||
|
||||
if n_channels is None:
|
||||
n_channels = self.pretrained_model.encoder.ch
|
||||
|
||||
self.proj_norm = Normalize(in_channels,num_groups=in_channels//2)
|
||||
self.proj = nn.Conv2d(in_channels,n_channels,kernel_size=3,
|
||||
stride=1,padding=1)
|
||||
|
||||
blocks = []
|
||||
downs = []
|
||||
ch_in = n_channels
|
||||
for m in ch_mult:
|
||||
blocks.append(ResnetBlock(in_channels=ch_in,out_channels=m*n_channels,dropout=dropout))
|
||||
ch_in = m * n_channels
|
||||
downs.append(Downsample(ch_in, with_conv=False))
|
||||
|
||||
self.model = nn.ModuleList(blocks)
|
||||
self.downsampler = nn.ModuleList(downs)
|
||||
|
||||
|
||||
def instantiate_pretrained(self, config):
|
||||
model = instantiate_from_config(config)
|
||||
self.pretrained_model = model.eval()
|
||||
# self.pretrained_model.train = False
|
||||
for param in self.pretrained_model.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_with_pretrained(self,x):
|
||||
c = self.pretrained_model.encode(x)
|
||||
if isinstance(c, DiagonalGaussianDistribution):
|
||||
c = c.mode()
|
||||
return c
|
||||
|
||||
def forward(self,x):
|
||||
z_fs = self.encode_with_pretrained(x)
|
||||
z = self.proj_norm(z_fs)
|
||||
z = self.proj(z)
|
||||
z = nonlinearity(z)
|
||||
|
||||
for submodel, downmodel in zip(self.model,self.downsampler):
|
||||
z = submodel(z,temb=None)
|
||||
z = downmodel(z)
|
||||
|
||||
if self.do_reshape:
|
||||
z = rearrange(z,'b c h w -> b (h w) c')
|
||||
return z
|
||||
|
||||
@@ -0,0 +1,996 @@
|
||||
from abc import abstractmethod
|
||||
from functools import partial
|
||||
import math
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
import torch as th
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import (
|
||||
checkpoint,
|
||||
conv_nd,
|
||||
linear,
|
||||
avg_pool_nd,
|
||||
zero_module,
|
||||
normalization,
|
||||
timestep_embedding,
|
||||
)
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.attention import SpatialTransformer
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import exists
|
||||
|
||||
|
||||
# dummy replace
|
||||
def convert_module_to_f16(x):
|
||||
pass
|
||||
|
||||
def convert_module_to_f32(x):
|
||||
pass
|
||||
|
||||
|
||||
## go
|
||||
class AttentionPool2d(nn.Module):
|
||||
"""
|
||||
Adapted from CLIP: https://github.com/openai/CLIP/blob/main/clip/model.py
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spacial_dim: int,
|
||||
embed_dim: int,
|
||||
num_heads_channels: int,
|
||||
output_dim: int = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(th.randn(embed_dim, spacial_dim ** 2 + 1) / embed_dim ** 0.5)
|
||||
self.qkv_proj = conv_nd(1, embed_dim, 3 * embed_dim, 1)
|
||||
self.c_proj = conv_nd(1, embed_dim, output_dim or embed_dim, 1)
|
||||
self.num_heads = embed_dim // num_heads_channels
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, *_spatial = x.shape
|
||||
x = x.reshape(b, c, -1) # NC(HW)
|
||||
x = th.cat([x.mean(dim=-1, keepdim=True), x], dim=-1) # NC(HW+1)
|
||||
x = x + self.positional_embedding[None, :, :].to(x.dtype) # NC(HW+1)
|
||||
x = self.qkv_proj(x)
|
||||
x = self.attention(x)
|
||||
x = self.c_proj(x)
|
||||
return x[:, :, 0]
|
||||
|
||||
|
||||
class TimestepBlock(nn.Module):
|
||||
"""
|
||||
Any module where forward() takes timestep embeddings as a second argument.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the module to `x` given `emb` timestep embeddings.
|
||||
"""
|
||||
|
||||
|
||||
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
|
||||
"""
|
||||
A sequential module that passes timestep embeddings to the children that
|
||||
support it as an extra input.
|
||||
"""
|
||||
|
||||
def forward(self, x, emb, context=None):
|
||||
for layer in self:
|
||||
if isinstance(layer, TimestepBlock):
|
||||
x = layer(x, emb)
|
||||
elif isinstance(layer, SpatialTransformer):
|
||||
x = layer(x, context)
|
||||
else:
|
||||
x = layer(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
"""
|
||||
An upsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
upsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
if use_conv:
|
||||
self.conv = conv_nd(dims, self.channels, self.out_channels, 3, padding=padding)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
if self.dims == 3:
|
||||
x = F.interpolate(
|
||||
x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
|
||||
)
|
||||
else:
|
||||
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
||||
if self.use_conv:
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
class TransposedUpsample(nn.Module):
|
||||
'Learned 2x upsampling without padding'
|
||||
def __init__(self, channels, out_channels=None, ks=5):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
|
||||
self.up = nn.ConvTranspose2d(self.channels,self.out_channels,kernel_size=ks,stride=2)
|
||||
|
||||
def forward(self,x):
|
||||
return self.up(x)
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
"""
|
||||
A downsampling layer with an optional convolution.
|
||||
:param channels: channels in the inputs and outputs.
|
||||
:param use_conv: a bool determining if a convolution is applied.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
|
||||
downsampling occurs in the inner-two dimensions.
|
||||
"""
|
||||
|
||||
def __init__(self, channels, use_conv, dims=2, out_channels=None,padding=1):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.dims = dims
|
||||
stride = 2 if dims != 3 else (1, 2, 2)
|
||||
if use_conv:
|
||||
self.op = conv_nd(
|
||||
dims, self.channels, self.out_channels, 3, stride=stride, padding=padding
|
||||
)
|
||||
else:
|
||||
assert self.channels == self.out_channels
|
||||
self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
|
||||
|
||||
def forward(self, x):
|
||||
assert x.shape[1] == self.channels
|
||||
return self.op(x)
|
||||
|
||||
|
||||
class ResBlock(TimestepBlock):
|
||||
"""
|
||||
A residual block that can optionally change the number of channels.
|
||||
:param channels: the number of input channels.
|
||||
:param emb_channels: the number of timestep embedding channels.
|
||||
:param dropout: the rate of dropout.
|
||||
:param out_channels: if specified, the number of out channels.
|
||||
:param use_conv: if True and out_channels is specified, use a spatial
|
||||
convolution instead of a smaller 1x1 convolution to change the
|
||||
channels in the skip connection.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param use_checkpoint: if True, use gradient checkpointing on this module.
|
||||
:param up: if True, use this block for upsampling.
|
||||
:param down: if True, use this block for downsampling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
emb_channels,
|
||||
dropout,
|
||||
out_channels=None,
|
||||
use_conv=False,
|
||||
use_scale_shift_norm=False,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
up=False,
|
||||
down=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.emb_channels = emb_channels
|
||||
self.dropout = dropout
|
||||
self.out_channels = out_channels or channels
|
||||
self.use_conv = use_conv
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.use_scale_shift_norm = use_scale_shift_norm
|
||||
|
||||
self.in_layers = nn.Sequential(
|
||||
normalization(channels),
|
||||
nn.SiLU(),
|
||||
conv_nd(dims, channels, self.out_channels, 3, padding=1),
|
||||
)
|
||||
|
||||
self.updown = up or down
|
||||
|
||||
if up:
|
||||
self.h_upd = Upsample(channels, False, dims)
|
||||
self.x_upd = Upsample(channels, False, dims)
|
||||
elif down:
|
||||
self.h_upd = Downsample(channels, False, dims)
|
||||
self.x_upd = Downsample(channels, False, dims)
|
||||
else:
|
||||
self.h_upd = self.x_upd = nn.Identity()
|
||||
|
||||
self.emb_layers = nn.Sequential(
|
||||
nn.SiLU(),
|
||||
linear(
|
||||
emb_channels,
|
||||
2 * self.out_channels if use_scale_shift_norm else self.out_channels,
|
||||
),
|
||||
)
|
||||
self.out_layers = nn.Sequential(
|
||||
normalization(self.out_channels),
|
||||
nn.SiLU(),
|
||||
nn.Dropout(p=dropout),
|
||||
zero_module(
|
||||
conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
|
||||
),
|
||||
)
|
||||
|
||||
if self.out_channels == channels:
|
||||
self.skip_connection = nn.Identity()
|
||||
elif use_conv:
|
||||
self.skip_connection = conv_nd(
|
||||
dims, channels, self.out_channels, 3, padding=1
|
||||
)
|
||||
else:
|
||||
self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
|
||||
|
||||
def forward(self, x, emb):
|
||||
"""
|
||||
Apply the block to a Tensor, conditioned on a timestep embedding.
|
||||
:param x: an [N x C x ...] Tensor of features.
|
||||
:param emb: an [N x emb_channels] Tensor of timestep embeddings.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
return checkpoint(
|
||||
self._forward, (x, emb), self.parameters(), self.use_checkpoint
|
||||
)
|
||||
|
||||
|
||||
def _forward(self, x, emb):
|
||||
if self.updown:
|
||||
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
||||
h = in_rest(x)
|
||||
h = self.h_upd(h)
|
||||
x = self.x_upd(x)
|
||||
h = in_conv(h)
|
||||
else:
|
||||
h = self.in_layers(x)
|
||||
emb_out = self.emb_layers(emb).type(h.dtype)
|
||||
while len(emb_out.shape) < len(h.shape):
|
||||
emb_out = emb_out[..., None]
|
||||
if self.use_scale_shift_norm: # False
|
||||
out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
|
||||
scale, shift = th.chunk(emb_out, 2, dim=1)
|
||||
h = out_norm(h) * (1 + scale) + shift
|
||||
h = out_rest(h)
|
||||
else:
|
||||
h = h + emb_out
|
||||
h = self.out_layers(h)
|
||||
return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other.
|
||||
Originally ported from here, but adapted to the N-d case.
|
||||
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
use_checkpoint=False,
|
||||
use_new_attention_order=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
if num_head_channels == -1:
|
||||
self.num_heads = num_heads
|
||||
else:
|
||||
assert (
|
||||
channels % num_head_channels == 0
|
||||
), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
|
||||
self.num_heads = channels // num_head_channels
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.norm = normalization(channels)
|
||||
self.qkv = conv_nd(1, channels, channels * 3, 1)
|
||||
if use_new_attention_order:
|
||||
# split qkv before split heads
|
||||
self.attention = QKVAttention(self.num_heads)
|
||||
else:
|
||||
# split heads before split qkv
|
||||
self.attention = QKVAttentionLegacy(self.num_heads)
|
||||
|
||||
self.proj_out = zero_module(conv_nd(1, channels, channels, 1))
|
||||
|
||||
def forward(self, x):
|
||||
return checkpoint(self._forward, (x,), self.parameters(), True) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!!
|
||||
#return pt_checkpoint(self._forward, x) # pytorch
|
||||
|
||||
def _forward(self, x):
|
||||
b, c, *spatial = x.shape
|
||||
x = x.reshape(b, c, -1)
|
||||
qkv = self.qkv(self.norm(x))
|
||||
h = self.attention(qkv)
|
||||
h = self.proj_out(h)
|
||||
return (x + h).reshape(b, c, *spatial)
|
||||
|
||||
|
||||
def count_flops_attn(model, _x, y):
|
||||
"""
|
||||
A counter for the `thop` package to count the operations in an
|
||||
attention operation.
|
||||
Meant to be used like:
|
||||
macs, params = thop.profile(
|
||||
model,
|
||||
inputs=(inputs, timestamps),
|
||||
custom_ops={QKVAttention: QKVAttention.count_flops},
|
||||
)
|
||||
"""
|
||||
b, c, *spatial = y[0].shape
|
||||
num_spatial = int(np.prod(spatial))
|
||||
# We perform two matmuls with the same number of ops.
|
||||
# The first computes the weight matrix, the second computes
|
||||
# the combination of the value vectors.
|
||||
matmul_ops = 2 * b * (num_spatial ** 2) * c
|
||||
model.total_ops += th.DoubleTensor([matmul_ops])
|
||||
|
||||
|
||||
class QKVAttentionLegacy(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention. Matches legacy QKVAttention + input/ouput heads shaping
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts", q * scale, k * scale
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v)
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class QKVAttention(nn.Module):
|
||||
"""
|
||||
A module which performs QKV attention and splits in a different order.
|
||||
"""
|
||||
|
||||
def __init__(self, n_heads):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
|
||||
def forward(self, qkv):
|
||||
"""
|
||||
Apply QKV attention.
|
||||
:param qkv: an [N x (3 * H * C) x T] tensor of Qs, Ks, and Vs.
|
||||
:return: an [N x (H * C) x T] tensor after attention.
|
||||
"""
|
||||
bs, width, length = qkv.shape
|
||||
assert width % (3 * self.n_heads) == 0
|
||||
ch = width // (3 * self.n_heads)
|
||||
q, k, v = qkv.chunk(3, dim=1)
|
||||
scale = 1 / math.sqrt(math.sqrt(ch))
|
||||
weight = th.einsum(
|
||||
"bct,bcs->bts",
|
||||
(q * scale).view(bs * self.n_heads, ch, length),
|
||||
(k * scale).view(bs * self.n_heads, ch, length),
|
||||
) # More stable with f16 than dividing afterwards
|
||||
weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
|
||||
a = th.einsum("bts,bcs->bct", weight, v.reshape(bs * self.n_heads, ch, length))
|
||||
return a.reshape(bs, -1, length)
|
||||
|
||||
@staticmethod
|
||||
def count_flops(model, _x, y):
|
||||
return count_flops_attn(model, _x, y)
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""
|
||||
The full UNet model with attention and timestep embedding.
|
||||
:param in_channels: channels in the input Tensor.
|
||||
:param model_channels: base channel count for the model.
|
||||
:param out_channels: channels in the output Tensor.
|
||||
:param num_res_blocks: number of residual blocks per downsample.
|
||||
:param attention_resolutions: a collection of downsample rates at which
|
||||
attention will take place. May be a set, list, or tuple.
|
||||
For example, if this contains 4, then at 4x downsampling, attention
|
||||
will be used.
|
||||
:param dropout: the dropout probability.
|
||||
:param channel_mult: channel multiplier for each level of the UNet.
|
||||
:param conv_resample: if True, use learned convolutions for upsampling and
|
||||
downsampling.
|
||||
:param dims: determines if the signal is 1D, 2D, or 3D.
|
||||
:param num_classes: if specified (as an int), then this model will be
|
||||
class-conditional with `num_classes` classes.
|
||||
:param use_checkpoint: use gradient checkpointing to reduce memory usage.
|
||||
:param num_heads: the number of attention heads in each attention layer.
|
||||
:param num_heads_channels: if specified, ignore num_heads and instead use
|
||||
a fixed channel width per attention head.
|
||||
:param num_heads_upsample: works with num_heads to set a different number
|
||||
of heads for upsampling. Deprecated.
|
||||
:param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
|
||||
:param resblock_updown: use residual blocks for up/downsampling.
|
||||
:param use_new_attention_order: use a different attention pattern for potentially
|
||||
increased efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
num_classes=None,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=-1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
use_spatial_transformer=False, # custom transformer support
|
||||
transformer_depth=1, # custom transformer support
|
||||
context_dim=None, # custom transformer support
|
||||
n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
|
||||
legacy=True,
|
||||
disable_self_attentions=None,
|
||||
num_attention_blocks=None
|
||||
):
|
||||
super().__init__()
|
||||
if use_spatial_transformer:
|
||||
assert context_dim is not None, 'Fool!! You forgot to include the dimension of your cross-attention conditioning...'
|
||||
|
||||
if context_dim is not None:
|
||||
assert use_spatial_transformer, 'Fool!! You forgot to use the spatial transformer for your cross-attention conditioning...'
|
||||
from omegaconf.listconfig import ListConfig
|
||||
if type(context_dim) == ListConfig:
|
||||
context_dim = list(context_dim)
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
if num_heads == -1:
|
||||
assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
if num_head_channels == -1:
|
||||
assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
|
||||
|
||||
self.image_size = image_size
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
if isinstance(num_res_blocks, int):
|
||||
self.num_res_blocks = len(channel_mult) * [num_res_blocks]
|
||||
else:
|
||||
if len(num_res_blocks) != len(channel_mult):
|
||||
raise ValueError("provide num_res_blocks either as an int (globally constant) or "
|
||||
"as a list/tuple (per-level) with the same length as channel_mult")
|
||||
self.num_res_blocks = num_res_blocks
|
||||
#self.num_res_blocks = num_res_blocks
|
||||
if disable_self_attentions is not None:
|
||||
# should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
|
||||
assert len(disable_self_attentions) == len(channel_mult)
|
||||
if num_attention_blocks is not None:
|
||||
assert len(num_attention_blocks) == len(self.num_res_blocks)
|
||||
assert all(map(lambda i: self.num_res_blocks[i] >= num_attention_blocks[i], range(len(num_attention_blocks))))
|
||||
print(f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
|
||||
f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
|
||||
f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
|
||||
f"attention will still not be set.") # todo: convert to warning
|
||||
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.num_classes = num_classes
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
self.predict_codebook_ids = n_embed is not None
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
if self.num_classes is not None:
|
||||
self.label_emb = nn.Embedding(num_classes, time_embed_dim)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
) # 0
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for nr in range(self.num_res_blocks[level]):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions: # always True
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or nr < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer( # always uses a self-attn
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
|
||||
self.output_blocks = nn.ModuleList([])
|
||||
for level, mult in list(enumerate(channel_mult))[::-1]:
|
||||
for i in range(self.num_res_blocks[level] + 1):
|
||||
ich = input_block_chans.pop()
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch + ich,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=model_channels * mult,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = model_channels * mult
|
||||
if ds in attention_resolutions:
|
||||
if num_head_channels == -1:
|
||||
dim_head = ch // num_heads
|
||||
else:
|
||||
num_heads = ch // num_head_channels
|
||||
dim_head = num_head_channels
|
||||
if legacy:
|
||||
#num_heads = 1
|
||||
dim_head = ch // num_heads if use_spatial_transformer else num_head_channels
|
||||
if exists(disable_self_attentions):
|
||||
disabled_sa = disable_self_attentions[level]
|
||||
else:
|
||||
disabled_sa = False
|
||||
|
||||
if not exists(num_attention_blocks) or i < num_attention_blocks[level]:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads_upsample,
|
||||
num_head_channels=dim_head,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
) if not use_spatial_transformer else SpatialTransformer(
|
||||
ch, num_heads, dim_head, depth=transformer_depth, context_dim=context_dim,
|
||||
disable_self_attn=disabled_sa
|
||||
)
|
||||
)
|
||||
if level and i == self.num_res_blocks[level]:
|
||||
out_ch = ch
|
||||
layers.append(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
up=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
|
||||
)
|
||||
ds //= 2
|
||||
self.output_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
|
||||
)
|
||||
if self.predict_codebook_ids:
|
||||
self.id_predictor = nn.Sequential(
|
||||
normalization(ch),
|
||||
conv_nd(dims, model_channels, n_embed, 1),
|
||||
#nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
|
||||
)
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
self.output_blocks.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
self.output_blocks.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps=None, context=None, y=None,**kwargs):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:param context: conditioning plugged in via crossattn
|
||||
:param y: an [N] Tensor of labels, if class-conditional.
|
||||
:return: an [N x C x ...] Tensor of outputs.
|
||||
"""
|
||||
assert (y is not None) == (
|
||||
self.num_classes is not None
|
||||
), "must specify y if and only if the model is class-conditional"
|
||||
hs = []
|
||||
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False) # N
|
||||
emb = self.time_embed(t_emb) #
|
||||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape == (x.shape[0],)
|
||||
emb = emb + self.label_emb(y)
|
||||
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb, context) # conv
|
||||
hs.append(h)
|
||||
h = self.middle_block(h, emb, context)
|
||||
for module in self.output_blocks:
|
||||
h = th.cat([h, hs.pop()], dim=1)
|
||||
h = module(h, emb, context)
|
||||
h = h.type(x.dtype)
|
||||
if self.predict_codebook_ids:
|
||||
return self.id_predictor(h)
|
||||
else:
|
||||
return self.out(h)
|
||||
|
||||
|
||||
class EncoderUNetModel(nn.Module):
|
||||
"""
|
||||
The half UNet model with attention and timestep embedding.
|
||||
For usage, see UNet.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
in_channels,
|
||||
model_channels,
|
||||
out_channels,
|
||||
num_res_blocks,
|
||||
attention_resolutions,
|
||||
dropout=0,
|
||||
channel_mult=(1, 2, 4, 8),
|
||||
conv_resample=True,
|
||||
dims=2,
|
||||
use_checkpoint=False,
|
||||
use_fp16=False,
|
||||
num_heads=1,
|
||||
num_head_channels=-1,
|
||||
num_heads_upsample=-1,
|
||||
use_scale_shift_norm=False,
|
||||
resblock_updown=False,
|
||||
use_new_attention_order=False,
|
||||
pool="adaptive",
|
||||
*args,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if num_heads_upsample == -1:
|
||||
num_heads_upsample = num_heads
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.model_channels = model_channels
|
||||
self.out_channels = out_channels
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.attention_resolutions = attention_resolutions
|
||||
self.dropout = dropout
|
||||
self.channel_mult = channel_mult
|
||||
self.conv_resample = conv_resample
|
||||
self.use_checkpoint = use_checkpoint
|
||||
self.dtype = th.float16 if use_fp16 else th.float32
|
||||
self.num_heads = num_heads
|
||||
self.num_head_channels = num_head_channels
|
||||
self.num_heads_upsample = num_heads_upsample
|
||||
|
||||
time_embed_dim = model_channels * 4
|
||||
self.time_embed = nn.Sequential(
|
||||
linear(model_channels, time_embed_dim),
|
||||
nn.SiLU(),
|
||||
linear(time_embed_dim, time_embed_dim),
|
||||
)
|
||||
|
||||
self.input_blocks = nn.ModuleList(
|
||||
[
|
||||
TimestepEmbedSequential(
|
||||
conv_nd(dims, in_channels, model_channels, 3, padding=1)
|
||||
)
|
||||
]
|
||||
)
|
||||
self._feature_size = model_channels
|
||||
input_block_chans = [model_channels]
|
||||
ch = model_channels
|
||||
ds = 1
|
||||
for level, mult in enumerate(channel_mult):
|
||||
for _ in range(num_res_blocks):
|
||||
layers = [
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=mult * model_channels,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
)
|
||||
]
|
||||
ch = mult * model_channels
|
||||
if ds in attention_resolutions:
|
||||
layers.append(
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
)
|
||||
)
|
||||
self.input_blocks.append(TimestepEmbedSequential(*layers))
|
||||
self._feature_size += ch
|
||||
input_block_chans.append(ch)
|
||||
if level != len(channel_mult) - 1:
|
||||
out_ch = ch
|
||||
self.input_blocks.append(
|
||||
TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
out_channels=out_ch,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
down=True,
|
||||
)
|
||||
if resblock_updown
|
||||
else Downsample(
|
||||
ch, conv_resample, dims=dims, out_channels=out_ch
|
||||
)
|
||||
)
|
||||
)
|
||||
ch = out_ch
|
||||
input_block_chans.append(ch)
|
||||
ds *= 2
|
||||
self._feature_size += ch
|
||||
|
||||
self.middle_block = TimestepEmbedSequential(
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
AttentionBlock(
|
||||
ch,
|
||||
use_checkpoint=use_checkpoint,
|
||||
num_heads=num_heads,
|
||||
num_head_channels=num_head_channels,
|
||||
use_new_attention_order=use_new_attention_order,
|
||||
),
|
||||
ResBlock(
|
||||
ch,
|
||||
time_embed_dim,
|
||||
dropout,
|
||||
dims=dims,
|
||||
use_checkpoint=use_checkpoint,
|
||||
use_scale_shift_norm=use_scale_shift_norm,
|
||||
),
|
||||
)
|
||||
self._feature_size += ch
|
||||
self.pool = pool
|
||||
if pool == "adaptive":
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
zero_module(conv_nd(dims, ch, out_channels, 1)),
|
||||
nn.Flatten(),
|
||||
)
|
||||
elif pool == "attention":
|
||||
assert num_head_channels != -1
|
||||
self.out = nn.Sequential(
|
||||
normalization(ch),
|
||||
nn.SiLU(),
|
||||
AttentionPool2d(
|
||||
(image_size // ds), ch, num_head_channels, out_channels
|
||||
),
|
||||
)
|
||||
elif pool == "spatial":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
nn.ReLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
elif pool == "spatial_v2":
|
||||
self.out = nn.Sequential(
|
||||
nn.Linear(self._feature_size, 2048),
|
||||
normalization(2048),
|
||||
nn.SiLU(),
|
||||
nn.Linear(2048, self.out_channels),
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unexpected {pool} pooling")
|
||||
|
||||
def convert_to_fp16(self):
|
||||
"""
|
||||
Convert the torso of the model to float16.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f16)
|
||||
self.middle_block.apply(convert_module_to_f16)
|
||||
|
||||
def convert_to_fp32(self):
|
||||
"""
|
||||
Convert the torso of the model to float32.
|
||||
"""
|
||||
self.input_blocks.apply(convert_module_to_f32)
|
||||
self.middle_block.apply(convert_module_to_f32)
|
||||
|
||||
def forward(self, x, timesteps):
|
||||
"""
|
||||
Apply the model to an input batch.
|
||||
:param x: an [N x C x ...] Tensor of inputs.
|
||||
:param timesteps: a 1-D batch of timesteps.
|
||||
:return: an [N x K] Tensor of outputs.
|
||||
"""
|
||||
emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
|
||||
|
||||
results = []
|
||||
h = x.type(self.dtype)
|
||||
for module in self.input_blocks:
|
||||
h = module(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = self.middle_block(h, emb)
|
||||
if self.pool.startswith("spatial"):
|
||||
results.append(h.type(x.dtype).mean(dim=(2, 3)))
|
||||
h = th.cat(results, axis=-1)
|
||||
return self.out(h)
|
||||
else:
|
||||
h = h.type(x.dtype)
|
||||
return self.out(h)
|
||||
|
||||
@@ -0,0 +1,267 @@
|
||||
# adopted from
|
||||
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
||||
# and
|
||||
# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
# and
|
||||
# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
|
||||
#
|
||||
# thanks!
|
||||
|
||||
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from einops import repeat
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
|
||||
|
||||
def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
if schedule == "linear":
|
||||
betas = (
|
||||
torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
|
||||
)
|
||||
|
||||
elif schedule == "cosine":
|
||||
timesteps = (
|
||||
torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
|
||||
)
|
||||
alphas = timesteps / (1 + cosine_s) * np.pi / 2
|
||||
alphas = torch.cos(alphas).pow(2)
|
||||
alphas = alphas / alphas[0]
|
||||
betas = 1 - alphas[1:] / alphas[:-1]
|
||||
betas = np.clip(betas, a_min=0, a_max=0.999)
|
||||
|
||||
elif schedule == "sqrt_linear":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
|
||||
elif schedule == "sqrt":
|
||||
betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5
|
||||
else:
|
||||
raise ValueError(f"schedule '{schedule}' unknown.")
|
||||
return betas.numpy()
|
||||
|
||||
|
||||
def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True):
|
||||
if ddim_discr_method == 'uniform':
|
||||
c = num_ddpm_timesteps // num_ddim_timesteps
|
||||
ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c)))
|
||||
elif ddim_discr_method == 'quad':
|
||||
ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int)
|
||||
else:
|
||||
raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"')
|
||||
|
||||
# assert ddim_timesteps.shape[0] == num_ddim_timesteps
|
||||
# add one to get the final alpha values right (the ones from first scale to data during sampling)
|
||||
steps_out = ddim_timesteps + 1
|
||||
if verbose:
|
||||
print(f'Selected timesteps for ddim sampler: {steps_out}')
|
||||
return steps_out
|
||||
|
||||
|
||||
def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True):
|
||||
# select alphas for computing the variance schedule
|
||||
alphas = alphacums[ddim_timesteps]
|
||||
alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist())
|
||||
|
||||
# according the the formula provided in https://arxiv.org/abs/2010.02502
|
||||
sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))
|
||||
if verbose:
|
||||
print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}')
|
||||
print(f'For the chosen value of eta, which is {eta}, '
|
||||
f'this results in the following sigma_t schedule for ddim sampler {sigmas}')
|
||||
return sigmas, alphas, alphas_prev
|
||||
|
||||
|
||||
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
||||
"""
|
||||
Create a beta schedule that discretizes the given alpha_t_bar function,
|
||||
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
||||
:param num_diffusion_timesteps: the number of betas to produce.
|
||||
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
||||
produces the cumulative product of (1-beta) up to that
|
||||
part of the diffusion process.
|
||||
:param max_beta: the maximum beta to use; use values lower than 1 to
|
||||
prevent singularities.
|
||||
"""
|
||||
betas = []
|
||||
for i in range(num_diffusion_timesteps):
|
||||
t1 = i / num_diffusion_timesteps
|
||||
t2 = (i + 1) / num_diffusion_timesteps
|
||||
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
||||
return np.array(betas)
|
||||
|
||||
|
||||
def extract_into_tensor(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def checkpoint(func, inputs, params, flag):
|
||||
"""
|
||||
Evaluate a function without caching intermediate activations, allowing for
|
||||
reduced memory at the expense of extra compute in the backward pass.
|
||||
:param func: the function to evaluate.
|
||||
:param inputs: the argument sequence to pass to `func`.
|
||||
:param params: a sequence of parameters `func` depends on but does not
|
||||
explicitly take as arguments.
|
||||
:param flag: if False, disable gradient checkpointing.
|
||||
"""
|
||||
if flag:
|
||||
args = tuple(inputs) + tuple(params)
|
||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||
else:
|
||||
return func(*inputs)
|
||||
|
||||
|
||||
class CheckpointFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, run_function, length, *args):
|
||||
ctx.run_function = run_function
|
||||
ctx.input_tensors = list(args[:length])
|
||||
ctx.input_params = list(args[length:])
|
||||
|
||||
with torch.no_grad():
|
||||
output_tensors = ctx.run_function(*ctx.input_tensors)
|
||||
return output_tensors
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, *output_grads):
|
||||
ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
|
||||
with torch.enable_grad():
|
||||
# Fixes a bug where the first op in run_function modifies the
|
||||
# Tensor storage in place, which is not allowed for detach()'d
|
||||
# Tensors.
|
||||
shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
|
||||
output_tensors = ctx.run_function(*shallow_copies)
|
||||
input_grads = torch.autograd.grad(
|
||||
output_tensors,
|
||||
ctx.input_tensors + ctx.input_params,
|
||||
output_grads,
|
||||
allow_unused=True,
|
||||
)
|
||||
del ctx.input_tensors
|
||||
del ctx.input_params
|
||||
del output_tensors
|
||||
return (None, None) + input_grads
|
||||
|
||||
|
||||
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
||||
These may be fractional.
|
||||
:param dim: the dimension of the output.
|
||||
:param max_period: controls the minimum frequency of the embeddings.
|
||||
:return: an [N x dim] Tensor of positional embeddings.
|
||||
"""
|
||||
if not repeat_only:
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
||||
).to(device=timesteps.device)
|
||||
args = timesteps[:, None].float() * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
else:
|
||||
embedding = repeat(timesteps, 'b -> b d', d=dim)
|
||||
return embedding
|
||||
|
||||
|
||||
def zero_module(module):
|
||||
"""
|
||||
Zero out the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().zero_()
|
||||
return module
|
||||
|
||||
|
||||
def scale_module(module, scale):
|
||||
"""
|
||||
Scale the parameters of a module and return it.
|
||||
"""
|
||||
for p in module.parameters():
|
||||
p.detach().mul_(scale)
|
||||
return module
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def normalization(channels):
|
||||
"""
|
||||
Make a standard normalization layer.
|
||||
:param channels: number of input channels.
|
||||
:return: an nn.Module for normalization.
|
||||
"""
|
||||
return GroupNorm32(32, channels)
|
||||
|
||||
|
||||
# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
|
||||
class SiLU(nn.Module):
|
||||
def forward(self, x):
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class GroupNorm32(nn.GroupNorm):
|
||||
def forward(self, x):
|
||||
return super().forward(x.float()).type(x.dtype)
|
||||
|
||||
def conv_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D convolution module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.Conv1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.Conv2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.Conv3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
def linear(*args, **kwargs):
|
||||
"""
|
||||
Create a linear module.
|
||||
"""
|
||||
return nn.Linear(*args, **kwargs)
|
||||
|
||||
|
||||
def avg_pool_nd(dims, *args, **kwargs):
|
||||
"""
|
||||
Create a 1D, 2D, or 3D average pooling module.
|
||||
"""
|
||||
if dims == 1:
|
||||
return nn.AvgPool1d(*args, **kwargs)
|
||||
elif dims == 2:
|
||||
return nn.AvgPool2d(*args, **kwargs)
|
||||
elif dims == 3:
|
||||
return nn.AvgPool3d(*args, **kwargs)
|
||||
raise ValueError(f"unsupported dimensions: {dims}")
|
||||
|
||||
|
||||
class HybridConditioner(nn.Module):
|
||||
|
||||
def __init__(self, c_concat_config, c_crossattn_config):
|
||||
super().__init__()
|
||||
self.concat_conditioner = instantiate_from_config(c_concat_config)
|
||||
self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
|
||||
|
||||
def forward(self, c_concat, c_crossattn):
|
||||
c_concat = self.concat_conditioner(c_concat)
|
||||
c_crossattn = self.crossattn_conditioner(c_crossattn)
|
||||
return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
noise = lambda: torch.randn(shape, device=device)
|
||||
return repeat_noise() if repeat else noise()
|
||||
@@ -0,0 +1,92 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
|
||||
|
||||
class AbstractDistribution:
|
||||
def sample(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def mode(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class DiracDistribution(AbstractDistribution):
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def sample(self):
|
||||
return self.value
|
||||
|
||||
def mode(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class DiagonalGaussianDistribution(object):
|
||||
def __init__(self, parameters, deterministic=False):
|
||||
self.parameters = parameters
|
||||
self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
|
||||
self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
|
||||
self.deterministic = deterministic
|
||||
self.std = torch.exp(0.5 * self.logvar)
|
||||
self.var = torch.exp(self.logvar)
|
||||
if self.deterministic:
|
||||
self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
|
||||
|
||||
def sample(self):
|
||||
x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
|
||||
return x
|
||||
|
||||
def kl(self, other=None):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
else:
|
||||
if other is None:
|
||||
return 0.5 * torch.sum(torch.pow(self.mean, 2)
|
||||
+ self.var - 1.0 - self.logvar,
|
||||
dim=[1, 2, 3])
|
||||
else:
|
||||
return 0.5 * torch.sum(
|
||||
torch.pow(self.mean - other.mean, 2) / other.var
|
||||
+ self.var / other.var - 1.0 - self.logvar + other.logvar,
|
||||
dim=[1, 2, 3])
|
||||
|
||||
def nll(self, sample, dims=[1,2,3]):
|
||||
if self.deterministic:
|
||||
return torch.Tensor([0.])
|
||||
logtwopi = np.log(2.0 * np.pi)
|
||||
return 0.5 * torch.sum(
|
||||
logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
|
||||
dim=dims)
|
||||
|
||||
def mode(self):
|
||||
return self.mean
|
||||
|
||||
|
||||
def normal_kl(mean1, logvar1, mean2, logvar2):
|
||||
"""
|
||||
source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
|
||||
Compute the KL divergence between two gaussians.
|
||||
Shapes are automatically broadcasted, so batches can be compared to
|
||||
scalars, among other use cases.
|
||||
"""
|
||||
tensor = None
|
||||
for obj in (mean1, logvar1, mean2, logvar2):
|
||||
if isinstance(obj, torch.Tensor):
|
||||
tensor = obj
|
||||
break
|
||||
assert tensor is not None, "at least one argument must be a Tensor"
|
||||
|
||||
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
||||
# Tensors, but it does not work for torch.exp().
|
||||
logvar1, logvar2 = [
|
||||
x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
|
||||
for x in (logvar1, logvar2)
|
||||
]
|
||||
|
||||
return 0.5 * (
|
||||
-1.0
|
||||
+ logvar2
|
||||
- logvar1
|
||||
+ torch.exp(logvar1 - logvar2)
|
||||
+ ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
|
||||
)
|
||||
@@ -0,0 +1 @@
|
||||
from .clip import *
|
||||
@@ -0,0 +1,200 @@
|
||||
import hashlib
|
||||
import os
|
||||
import urllib
|
||||
import warnings
|
||||
from typing import Any, Union, List
|
||||
from pkg_resources import packaging
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.encoders.clip.model import build_model
|
||||
|
||||
try:
|
||||
from torchvision.transforms import InterpolationMode
|
||||
BICUBIC = InterpolationMode.BICUBIC
|
||||
except ImportError:
|
||||
BICUBIC = Image.BICUBIC
|
||||
|
||||
|
||||
if packaging.version.parse(torch.__version__) < packaging.version.parse("1.7.1"):
|
||||
warnings.warn("PyTorch version 1.7.1 or higher is recommended")
|
||||
|
||||
|
||||
__all__ = ["available_models", "load"]
|
||||
|
||||
_MODELS = {
|
||||
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
|
||||
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
|
||||
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
|
||||
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
|
||||
"RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
|
||||
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
|
||||
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
|
||||
"ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
|
||||
"ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
|
||||
}
|
||||
|
||||
|
||||
def _download(url: str, root: str):
|
||||
os.makedirs(root, exist_ok=True)
|
||||
filename = os.path.basename(url)
|
||||
|
||||
expected_sha256 = url.split("/")[-2]
|
||||
download_target = os.path.join(root, filename)
|
||||
|
||||
if os.path.exists(download_target) and not os.path.isfile(download_target):
|
||||
raise RuntimeError(f"{download_target} exists and is not a regular file")
|
||||
|
||||
if os.path.isfile(download_target):
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
|
||||
return download_target
|
||||
else:
|
||||
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
|
||||
|
||||
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
|
||||
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True, unit_divisor=1024) as loop:
|
||||
while True:
|
||||
buffer = source.read(8192)
|
||||
if not buffer:
|
||||
break
|
||||
|
||||
output.write(buffer)
|
||||
loop.update(len(buffer))
|
||||
|
||||
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
|
||||
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not not match")
|
||||
|
||||
return download_target
|
||||
|
||||
|
||||
def _convert_image_to_rgb(image):
|
||||
return image.convert("RGB")
|
||||
|
||||
|
||||
def _transform(n_px):
|
||||
return Compose([
|
||||
Resize(n_px, interpolation=BICUBIC),
|
||||
CenterCrop(n_px),
|
||||
_convert_image_to_rgb,
|
||||
ToTensor(),
|
||||
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
|
||||
])
|
||||
|
||||
|
||||
def available_models() -> List[str]:
|
||||
"""Returns the names of available CLIP models"""
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit: bool = False, download_root: str = None):
|
||||
"""Load a CLIP model
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
|
||||
|
||||
device : Union[str, torch.device]
|
||||
The device to put the loaded model
|
||||
|
||||
jit : bool
|
||||
Whether to load the optimized JIT model or more hackable non-JIT model (default).
|
||||
|
||||
download_root: str
|
||||
path to download the model files; by default, it uses "~/.cache/clip"
|
||||
|
||||
Returns
|
||||
-------
|
||||
model : torch.nn.Module
|
||||
The CLIP model
|
||||
|
||||
preprocess : Callable[[PIL.Image], torch.Tensor]
|
||||
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
|
||||
"""
|
||||
if name in _MODELS:
|
||||
model_path = _download(_MODELS[name], download_root or os.path.expanduser("~/.cache/clip"))
|
||||
elif os.path.isfile(name):
|
||||
model_path = name
|
||||
else:
|
||||
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
|
||||
|
||||
with open(model_path, 'rb') as opened_file:
|
||||
try:
|
||||
# loading JIT archive
|
||||
model = torch.jit.load(opened_file, map_location=device if jit else "cpu").eval()
|
||||
state_dict = None
|
||||
except RuntimeError:
|
||||
# loading saved state dict
|
||||
if jit:
|
||||
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
|
||||
jit = False
|
||||
state_dict = torch.load(opened_file, map_location="cpu")
|
||||
|
||||
if not jit:
|
||||
model = build_model(state_dict or model.state_dict()).to(device)
|
||||
if str(device) == "cpu":
|
||||
model.float()
|
||||
return model, _transform(model.visual.input_resolution)
|
||||
|
||||
# patch the device names
|
||||
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
|
||||
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
|
||||
|
||||
def _node_get(node: torch._C.Node, key: str):
|
||||
"""Gets attributes of a node which is polymorphic over return type.
|
||||
|
||||
From https://github.com/pytorch/pytorch/pull/82628
|
||||
"""
|
||||
sel = node.kindOf(key)
|
||||
return getattr(node, sel)(key)
|
||||
|
||||
def patch_device(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("prim::Constant"):
|
||||
if "value" in node.attributeNames() and str(_node_get(node, "value")).startswith("cuda"):
|
||||
node.copyAttributes(device_node)
|
||||
|
||||
model.apply(patch_device)
|
||||
patch_device(model.encode_image)
|
||||
patch_device(model.encode_text)
|
||||
|
||||
# patch dtype to float32 on CPU
|
||||
if str(device) == "cpu":
|
||||
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
|
||||
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
|
||||
float_node = float_input.node()
|
||||
|
||||
def patch_float(module):
|
||||
try:
|
||||
graphs = [module.graph] if hasattr(module, "graph") else []
|
||||
except RuntimeError:
|
||||
graphs = []
|
||||
|
||||
if hasattr(module, "forward1"):
|
||||
graphs.append(module.forward1.graph)
|
||||
|
||||
for graph in graphs:
|
||||
for node in graph.findAllNodes("aten::to"):
|
||||
inputs = list(node.inputs())
|
||||
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
|
||||
if _node_get(inputs[i].node(), "value") == 5:
|
||||
inputs[i].node().copyAttributes(float_node)
|
||||
|
||||
model.apply(patch_float)
|
||||
patch_float(model.encode_image)
|
||||
patch_float(model.encode_text)
|
||||
|
||||
model.float()
|
||||
|
||||
return model, _transform(model.input_resolution.item())
|
||||
@@ -0,0 +1,436 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1):
|
||||
super().__init__()
|
||||
|
||||
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
|
||||
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
|
||||
|
||||
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
|
||||
self.downsample = None
|
||||
self.stride = stride
|
||||
|
||||
if stride > 1 or inplanes != planes * Bottleneck.expansion:
|
||||
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
|
||||
self.downsample = nn.Sequential(OrderedDict([
|
||||
("-1", nn.AvgPool2d(stride)),
|
||||
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
|
||||
("1", nn.BatchNorm2d(planes * self.expansion))
|
||||
]))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
identity = x
|
||||
|
||||
out = self.relu1(self.bn1(self.conv1(x)))
|
||||
out = self.relu2(self.bn2(self.conv2(out)))
|
||||
out = self.avgpool(out)
|
||||
out = self.bn3(self.conv3(out))
|
||||
|
||||
if self.downsample is not None:
|
||||
identity = self.downsample(x)
|
||||
|
||||
out += identity
|
||||
out = self.relu3(out)
|
||||
return out
|
||||
|
||||
|
||||
class AttentionPool2d(nn.Module):
|
||||
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
|
||||
super().__init__()
|
||||
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
|
||||
self.k_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.q_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.v_proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
|
||||
self.num_heads = num_heads
|
||||
|
||||
def forward(self, x):
|
||||
x = x.flatten(start_dim=2).permute(2, 0, 1) # NCHW -> (HW)NC
|
||||
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
|
||||
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
|
||||
x, _ = F.multi_head_attention_forward(
|
||||
query=x[:1], key=x, value=x,
|
||||
embed_dim_to_check=x.shape[-1],
|
||||
num_heads=self.num_heads,
|
||||
q_proj_weight=self.q_proj.weight,
|
||||
k_proj_weight=self.k_proj.weight,
|
||||
v_proj_weight=self.v_proj.weight,
|
||||
in_proj_weight=None,
|
||||
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
|
||||
bias_k=None,
|
||||
bias_v=None,
|
||||
add_zero_attn=False,
|
||||
dropout_p=0,
|
||||
out_proj_weight=self.c_proj.weight,
|
||||
out_proj_bias=self.c_proj.bias,
|
||||
use_separate_proj_weight=True,
|
||||
training=self.training,
|
||||
need_weights=False
|
||||
)
|
||||
return x.squeeze(0)
|
||||
|
||||
|
||||
class ModifiedResNet(nn.Module):
|
||||
"""
|
||||
A ResNet class that is similar to torchvision's but contains the following changes:
|
||||
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
|
||||
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
|
||||
- The final pooling layer is a QKV attention instead of an average pool
|
||||
"""
|
||||
|
||||
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
self.input_resolution = input_resolution
|
||||
|
||||
# the 3-layer stem
|
||||
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width // 2)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width // 2)
|
||||
self.relu2 = nn.ReLU(inplace=True)
|
||||
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(width)
|
||||
self.relu3 = nn.ReLU(inplace=True)
|
||||
self.avgpool = nn.AvgPool2d(2)
|
||||
|
||||
# residual layers
|
||||
self._inplanes = width # this is a *mutable* variable used during construction
|
||||
self.layer1 = self._make_layer(width, layers[0])
|
||||
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
|
||||
|
||||
embed_dim = width * 32 # the ResNet feature dimension
|
||||
self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim)
|
||||
|
||||
def _make_layer(self, planes, blocks, stride=1):
|
||||
layers = [Bottleneck(self._inplanes, planes, stride)]
|
||||
|
||||
self._inplanes = planes * Bottleneck.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(Bottleneck(self._inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
def stem(x):
|
||||
x = self.relu1(self.bn1(self.conv1(x)))
|
||||
x = self.relu2(self.bn2(self.conv2(x)))
|
||||
x = self.relu3(self.bn3(self.conv3(x)))
|
||||
x = self.avgpool(x)
|
||||
return x
|
||||
|
||||
x = x.type(self.conv1.weight.dtype)
|
||||
x = stem(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
x = self.attnpool(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
"""Subclass torch's LayerNorm to handle fp16."""
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
orig_type = x.dtype
|
||||
ret = super().forward(x.type(torch.float32))
|
||||
return ret.type(orig_type)
|
||||
|
||||
|
||||
class QuickGELU(nn.Module):
|
||||
def forward(self, x: torch.Tensor):
|
||||
return x * torch.sigmoid(1.702 * x)
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
|
||||
self.attn = nn.MultiheadAttention(d_model, n_head)
|
||||
self.ln_1 = LayerNorm(d_model)
|
||||
self.mlp = nn.Sequential(OrderedDict([
|
||||
("c_fc", nn.Linear(d_model, d_model * 4)),
|
||||
("gelu", QuickGELU()),
|
||||
("c_proj", nn.Linear(d_model * 4, d_model))
|
||||
]))
|
||||
self.ln_2 = LayerNorm(d_model)
|
||||
self.attn_mask = attn_mask
|
||||
|
||||
def attention(self, x: torch.Tensor):
|
||||
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
|
||||
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = x + self.attention(self.ln_1(x))
|
||||
x = x + self.mlp(self.ln_2(x))
|
||||
return x
|
||||
|
||||
|
||||
class Transformer(nn.Module):
|
||||
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
|
||||
super().__init__()
|
||||
self.width = width
|
||||
self.layers = layers
|
||||
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
return self.resblocks(x)
|
||||
|
||||
|
||||
class VisionTransformer(nn.Module):
|
||||
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
|
||||
super().__init__()
|
||||
self.input_resolution = input_resolution
|
||||
self.output_dim = output_dim
|
||||
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
|
||||
|
||||
scale = width ** -0.5
|
||||
self.class_embedding = nn.Parameter(scale * torch.randn(width))
|
||||
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
|
||||
self.ln_pre = LayerNorm(width)
|
||||
|
||||
self.transformer = Transformer(width, layers, heads)
|
||||
|
||||
self.ln_post = LayerNorm(width)
|
||||
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
x = self.conv1(x) # shape = [*, width, grid, grid]
|
||||
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
|
||||
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
|
||||
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
|
||||
x = x + self.positional_embedding.to(x.dtype)
|
||||
x = self.ln_pre(x)
|
||||
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
|
||||
x = self.ln_post(x[:, 0, :])
|
||||
|
||||
if self.proj is not None:
|
||||
x = x @ self.proj
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CLIP(nn.Module):
|
||||
def __init__(self,
|
||||
embed_dim: int,
|
||||
# vision
|
||||
image_resolution: int,
|
||||
vision_layers: Union[Tuple[int, int, int, int], int],
|
||||
vision_width: int,
|
||||
vision_patch_size: int,
|
||||
# text
|
||||
context_length: int,
|
||||
vocab_size: int,
|
||||
transformer_width: int,
|
||||
transformer_heads: int,
|
||||
transformer_layers: int
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.context_length = context_length
|
||||
|
||||
if isinstance(vision_layers, (tuple, list)):
|
||||
vision_heads = vision_width * 32 // 64
|
||||
self.visual = ModifiedResNet(
|
||||
layers=vision_layers,
|
||||
output_dim=embed_dim,
|
||||
heads=vision_heads,
|
||||
input_resolution=image_resolution,
|
||||
width=vision_width
|
||||
)
|
||||
else:
|
||||
vision_heads = vision_width // 64
|
||||
self.visual = VisionTransformer(
|
||||
input_resolution=image_resolution,
|
||||
patch_size=vision_patch_size,
|
||||
width=vision_width,
|
||||
layers=vision_layers,
|
||||
heads=vision_heads,
|
||||
output_dim=embed_dim
|
||||
)
|
||||
|
||||
self.transformer = Transformer(
|
||||
width=transformer_width,
|
||||
layers=transformer_layers,
|
||||
heads=transformer_heads,
|
||||
attn_mask=self.build_attention_mask()
|
||||
)
|
||||
|
||||
self.vocab_size = vocab_size
|
||||
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
|
||||
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
|
||||
self.ln_final = LayerNorm(transformer_width)
|
||||
|
||||
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
|
||||
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
|
||||
|
||||
self.initialize_parameters()
|
||||
|
||||
def initialize_parameters(self):
|
||||
nn.init.normal_(self.token_embedding.weight, std=0.02)
|
||||
nn.init.normal_(self.positional_embedding, std=0.01)
|
||||
|
||||
if isinstance(self.visual, ModifiedResNet):
|
||||
if self.visual.attnpool is not None:
|
||||
std = self.visual.attnpool.c_proj.in_features ** -0.5
|
||||
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
|
||||
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
|
||||
|
||||
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
|
||||
for name, param in resnet_block.named_parameters():
|
||||
if name.endswith("bn3.weight"):
|
||||
nn.init.zeros_(param)
|
||||
|
||||
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
|
||||
attn_std = self.transformer.width ** -0.5
|
||||
fc_std = (2 * self.transformer.width) ** -0.5
|
||||
for block in self.transformer.resblocks:
|
||||
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
|
||||
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
|
||||
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
|
||||
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
|
||||
|
||||
if self.text_projection is not None:
|
||||
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
|
||||
|
||||
def build_attention_mask(self):
|
||||
# lazily create causal attention mask, with full attention between the vision tokens
|
||||
# pytorch uses additive attention mask; fill with -inf
|
||||
mask = torch.empty(self.context_length, self.context_length)
|
||||
mask.fill_(float("-inf"))
|
||||
mask.triu_(1) # zero out the lower diagonal
|
||||
return mask
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.visual.conv1.weight.dtype
|
||||
|
||||
def encode_image(self, image):
|
||||
return self.visual(image.type(self.dtype))
|
||||
|
||||
def encode_text(self, text):
|
||||
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
|
||||
|
||||
x = x + self.positional_embedding.type(self.dtype)
|
||||
x = x.permute(1, 0, 2) # NLD -> LND
|
||||
x = self.transformer(x)
|
||||
x = x.permute(1, 0, 2) # LND -> NLD
|
||||
x = self.ln_final(x).type(self.dtype)
|
||||
|
||||
# x.shape = [batch_size, n_ctx, transformer.width]
|
||||
# take features from the eot embedding (eot_token is the highest number in each sequence)
|
||||
x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
|
||||
|
||||
return x
|
||||
|
||||
def forward(self, image, text):
|
||||
image_features = self.encode_image(image)
|
||||
text_features = self.encode_text(text)
|
||||
|
||||
# normalized features
|
||||
image_features = image_features / image_features.norm(dim=1, keepdim=True)
|
||||
text_features = text_features / text_features.norm(dim=1, keepdim=True)
|
||||
|
||||
# cosine similarity as logits
|
||||
logit_scale = self.logit_scale.exp()
|
||||
logits_per_image = logit_scale * image_features @ text_features.t()
|
||||
logits_per_text = logits_per_image.t()
|
||||
|
||||
# shape = [global_batch_size, global_batch_size]
|
||||
return logits_per_image, logits_per_text
|
||||
|
||||
|
||||
def convert_weights(model: nn.Module):
|
||||
"""Convert applicable model parameters to fp16"""
|
||||
|
||||
def _convert_weights_to_fp16(l):
|
||||
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
|
||||
l.weight.data = l.weight.data.half()
|
||||
if l.bias is not None:
|
||||
l.bias.data = l.bias.data.half()
|
||||
|
||||
if isinstance(l, nn.MultiheadAttention):
|
||||
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
|
||||
tensor = getattr(l, attr)
|
||||
if tensor is not None:
|
||||
tensor.data = tensor.data.half()
|
||||
|
||||
for name in ["text_projection", "proj"]:
|
||||
if hasattr(l, name):
|
||||
attr = getattr(l, name)
|
||||
if attr is not None:
|
||||
attr.data = attr.data.half()
|
||||
|
||||
model.apply(_convert_weights_to_fp16)
|
||||
|
||||
|
||||
def build_model(state_dict: dict):
|
||||
vit = "visual.proj" in state_dict
|
||||
|
||||
if vit:
|
||||
vision_width = state_dict["visual.conv1.weight"].shape[0]
|
||||
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
|
||||
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
|
||||
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
image_resolution = vision_patch_size * grid_size
|
||||
else:
|
||||
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
|
||||
vision_layers = tuple(counts)
|
||||
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
|
||||
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
|
||||
vision_patch_size = None
|
||||
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
|
||||
image_resolution = output_width * 32
|
||||
|
||||
embed_dim = state_dict["text_projection"].shape[1]
|
||||
context_length = state_dict["positional_embedding"].shape[0]
|
||||
vocab_size = state_dict["token_embedding.weight"].shape[0]
|
||||
transformer_width = state_dict["ln_final.weight"].shape[0]
|
||||
transformer_heads = transformer_width // 64
|
||||
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith("transformer.resblocks")))
|
||||
|
||||
model = CLIP(
|
||||
embed_dim,
|
||||
image_resolution, vision_layers, vision_width, vision_patch_size,
|
||||
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
|
||||
)
|
||||
|
||||
for key in ["input_resolution", "context_length", "vocab_size"]:
|
||||
if key in state_dict:
|
||||
del state_dict[key]
|
||||
|
||||
convert_weights(model)
|
||||
model.load_state_dict(state_dict)
|
||||
return model.eval()
|
||||
@@ -0,0 +1,132 @@
|
||||
import gzip
|
||||
import html
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
import ftfy
|
||||
import regex as re
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def default_bpe():
|
||||
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
|
||||
|
||||
|
||||
@lru_cache()
|
||||
def bytes_to_unicode():
|
||||
"""
|
||||
Returns list of utf-8 byte and a corresponding list of unicode strings.
|
||||
The reversible bpe codes work on unicode strings.
|
||||
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
|
||||
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
|
||||
This is a signficant percentage of your normal, say, 32K bpe vocab.
|
||||
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
|
||||
And avoids mapping to whitespace/control characters the bpe code barfs on.
|
||||
"""
|
||||
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
|
||||
cs = bs[:]
|
||||
n = 0
|
||||
for b in range(2**8):
|
||||
if b not in bs:
|
||||
bs.append(b)
|
||||
cs.append(2**8+n)
|
||||
n += 1
|
||||
cs = [chr(n) for n in cs]
|
||||
return dict(zip(bs, cs))
|
||||
|
||||
|
||||
def get_pairs(word):
|
||||
"""Return set of symbol pairs in a word.
|
||||
Word is represented as tuple of symbols (symbols being variable-length strings).
|
||||
"""
|
||||
pairs = set()
|
||||
prev_char = word[0]
|
||||
for char in word[1:]:
|
||||
pairs.add((prev_char, char))
|
||||
prev_char = char
|
||||
return pairs
|
||||
|
||||
|
||||
def basic_clean(text):
|
||||
text = ftfy.fix_text(text)
|
||||
text = html.unescape(html.unescape(text))
|
||||
return text.strip()
|
||||
|
||||
|
||||
def whitespace_clean(text):
|
||||
text = re.sub(r'\s+', ' ', text)
|
||||
text = text.strip()
|
||||
return text
|
||||
|
||||
|
||||
class SimpleTokenizer(object):
|
||||
def __init__(self, bpe_path: str = default_bpe()):
|
||||
self.byte_encoder = bytes_to_unicode()
|
||||
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
|
||||
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
|
||||
merges = merges[1:49152-256-2+1]
|
||||
merges = [tuple(merge.split()) for merge in merges]
|
||||
vocab = list(bytes_to_unicode().values())
|
||||
vocab = vocab + [v+'</w>' for v in vocab]
|
||||
for merge in merges:
|
||||
vocab.append(''.join(merge))
|
||||
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
|
||||
self.encoder = dict(zip(vocab, range(len(vocab))))
|
||||
self.decoder = {v: k for k, v in self.encoder.items()}
|
||||
self.bpe_ranks = dict(zip(merges, range(len(merges))))
|
||||
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
|
||||
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
|
||||
|
||||
def bpe(self, token):
|
||||
if token in self.cache:
|
||||
return self.cache[token]
|
||||
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
|
||||
pairs = get_pairs(word)
|
||||
|
||||
if not pairs:
|
||||
return token+'</w>'
|
||||
|
||||
while True:
|
||||
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
first, second = bigram
|
||||
new_word = []
|
||||
i = 0
|
||||
while i < len(word):
|
||||
try:
|
||||
j = word.index(first, i)
|
||||
new_word.extend(word[i:j])
|
||||
i = j
|
||||
except:
|
||||
new_word.extend(word[i:])
|
||||
break
|
||||
|
||||
if word[i] == first and i < len(word)-1 and word[i+1] == second:
|
||||
new_word.append(first+second)
|
||||
i += 2
|
||||
else:
|
||||
new_word.append(word[i])
|
||||
i += 1
|
||||
new_word = tuple(new_word)
|
||||
word = new_word
|
||||
if len(word) == 1:
|
||||
break
|
||||
else:
|
||||
pairs = get_pairs(word)
|
||||
word = ' '.join(word)
|
||||
self.cache[token] = word
|
||||
return word
|
||||
|
||||
def encode(self, text):
|
||||
bpe_tokens = []
|
||||
text = whitespace_clean(basic_clean(text)).lower()
|
||||
for token in re.findall(self.pat, text):
|
||||
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
|
||||
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
|
||||
return bpe_tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
text = ''.join([self.decoder[token] for token in tokens])
|
||||
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
|
||||
return text
|
||||
551
modelscope/models/cv/image_to_3d/ldm/modules/encoders/modules.py
Normal file
551
modelscope/models/cv/image_to_3d/ldm/modules/encoders/modules.py
Normal file
@@ -0,0 +1,551 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import numpy as np
|
||||
from functools import partial
|
||||
import kornia
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import default
|
||||
# import clip
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.encoders import clip
|
||||
|
||||
|
||||
class AbstractEncoder(nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def encode(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
class IdentityEncoder(AbstractEncoder):
|
||||
|
||||
def encode(self, x):
|
||||
return x
|
||||
|
||||
class FaceClipEncoder(AbstractEncoder):
|
||||
def __init__(self, augment=True, retreival_key=None):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
self.augment = augment
|
||||
self.retreival_key = retreival_key
|
||||
|
||||
def forward(self, img):
|
||||
encodings = []
|
||||
with torch.no_grad():
|
||||
x_offset = 125
|
||||
if self.retreival_key:
|
||||
# Assumes retrieved image are packed into the second half of channels
|
||||
face = img[:,3:,190:440,x_offset:(512-x_offset)]
|
||||
other = img[:,:3,...].clone()
|
||||
else:
|
||||
face = img[:,:,190:440,x_offset:(512-x_offset)]
|
||||
other = img.clone()
|
||||
|
||||
if self.augment:
|
||||
face = K.RandomHorizontalFlip()(face)
|
||||
|
||||
other[:,:,190:440,x_offset:(512-x_offset)] *= 0
|
||||
encodings = [
|
||||
self.encoder.encode(face),
|
||||
self.encoder.encode(other),
|
||||
]
|
||||
|
||||
return torch.cat(encodings, dim=1)
|
||||
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
|
||||
|
||||
return self(img)
|
||||
|
||||
class FaceIdClipEncoder(AbstractEncoder):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.encoder = FrozenCLIPImageEmbedder()
|
||||
for p in self.encoder.parameters():
|
||||
p.requires_grad = False
|
||||
self.id = FrozenFaceEncoder("/home/jpinkney/code/stable-diffusion/model_ir_se50.pth", augment=True)
|
||||
|
||||
def forward(self, img):
|
||||
encodings = []
|
||||
with torch.no_grad():
|
||||
face = kornia.geometry.resize(img, (256, 256),
|
||||
interpolation='bilinear', align_corners=True)
|
||||
|
||||
other = img.clone()
|
||||
other[:,:,184:452,122:396] *= 0
|
||||
encodings = [
|
||||
self.id.encode(face),
|
||||
self.encoder.encode(other),
|
||||
]
|
||||
|
||||
return torch.cat(encodings, dim=1)
|
||||
|
||||
def encode(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros((1, 2, 768), device=self.encoder.model.visual.conv1.weight.device)
|
||||
|
||||
return self(img)
|
||||
|
||||
class ClassEmbedder(nn.Module):
|
||||
def __init__(self, embed_dim, n_classes=1000, key='class'):
|
||||
super().__init__()
|
||||
self.key = key
|
||||
self.embedding = nn.Embedding(n_classes, embed_dim)
|
||||
|
||||
def forward(self, batch, key=None):
|
||||
if key is None:
|
||||
key = self.key
|
||||
# this is for use in crossattn
|
||||
c = batch[key][:, None]
|
||||
c = self.embedding(c)
|
||||
return c
|
||||
|
||||
|
||||
class TransformerEmbedder(AbstractEncoder):
|
||||
"""Some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"):
|
||||
super().__init__()
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer))
|
||||
|
||||
def forward(self, tokens):
|
||||
tokens = tokens.to(self.device) # meh
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
return z
|
||||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
class BERTTokenizer(AbstractEncoder):
|
||||
""" Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
|
||||
def __init__(self, device="cuda", vq_interface=True, max_length=77):
|
||||
super().__init__()
|
||||
from transformers import BertTokenizerFast # TODO: add to reuquirements
|
||||
self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
|
||||
self.device = device
|
||||
self.vq_interface = vq_interface
|
||||
self.max_length = max_length
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
return tokens
|
||||
|
||||
@torch.no_grad()
|
||||
def encode(self, text):
|
||||
tokens = self(text)
|
||||
if not self.vq_interface:
|
||||
return tokens
|
||||
return None, None, [None, None, tokens]
|
||||
|
||||
def decode(self, text):
|
||||
return text
|
||||
|
||||
|
||||
class BERTEmbedder(AbstractEncoder):
|
||||
"""Uses the BERT tokenizr model and add some transformer encoder layers"""
|
||||
def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77,
|
||||
device="cuda",use_tokenizer=True, embedding_dropout=0.0):
|
||||
super().__init__()
|
||||
self.use_tknz_fn = use_tokenizer
|
||||
if self.use_tknz_fn:
|
||||
self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len)
|
||||
self.device = device
|
||||
self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
|
||||
attn_layers=Encoder(dim=n_embed, depth=n_layer),
|
||||
emb_dropout=embedding_dropout)
|
||||
|
||||
def forward(self, text):
|
||||
if self.use_tknz_fn:
|
||||
tokens = self.tknz_fn(text)#.to(self.device)
|
||||
else:
|
||||
tokens = text
|
||||
z = self.transformer(tokens, return_embeddings=True)
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
# output of length 77
|
||||
return self(text)
|
||||
|
||||
|
||||
from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
|
||||
|
||||
def disabled_train(self, mode=True):
|
||||
"""Overwrite model.train with this function to make sure train/eval mode
|
||||
does not change anymore."""
|
||||
return self
|
||||
|
||||
|
||||
class FrozenT5Embedder(AbstractEncoder):
|
||||
"""Uses the T5 transformer encoder for text"""
|
||||
def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
|
||||
super().__init__()
|
||||
self.tokenizer = T5Tokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.transformer = T5EncoderModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.id_loss import IDFeatures
|
||||
import kornia.augmentation as K
|
||||
|
||||
class FrozenFaceEncoder(AbstractEncoder):
|
||||
def __init__(self, model_path, augment=False):
|
||||
super().__init__()
|
||||
self.loss_fn = IDFeatures(model_path)
|
||||
# face encoder is frozen
|
||||
for p in self.loss_fn.parameters():
|
||||
p.requires_grad = False
|
||||
# Mapper is trainable
|
||||
self.mapper = torch.nn.Linear(512, 768)
|
||||
p = 0.25
|
||||
if augment:
|
||||
self.augment = K.AugmentationSequential(
|
||||
K.RandomHorizontalFlip(p=0.5),
|
||||
K.RandomEqualize(p=p),
|
||||
# K.RandomPlanckianJitter(p=p),
|
||||
# K.RandomPlasmaBrightness(p=p),
|
||||
# K.RandomPlasmaContrast(p=p),
|
||||
# K.ColorJiggle(0.02, 0.2, 0.2, p=p),
|
||||
)
|
||||
else:
|
||||
self.augment = False
|
||||
|
||||
def forward(self, img):
|
||||
if isinstance(img, list):
|
||||
# Uncondition
|
||||
return torch.zeros((1, 1, 768), device=self.mapper.weight.device)
|
||||
|
||||
if self.augment is not None:
|
||||
# Transforms require 0-1
|
||||
img = self.augment((img + 1)/2)
|
||||
img = 2*img - 1
|
||||
|
||||
feat = self.loss_fn(img, crop=True)
|
||||
feat = self.mapper(feat.unsqueeze(1))
|
||||
return feat
|
||||
|
||||
def encode(self, img):
|
||||
return self(img)
|
||||
|
||||
class FrozenCLIPEmbedder(AbstractEncoder):
|
||||
"""Uses the CLIP transformer encoder for text (from huggingface)"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.tokenizer = CLIPTokenizer.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.transformer = CLIPTextModel.from_pretrained(version, cache_dir='/apdcephfs/private_rondyliu/projects/huggingface_models')
|
||||
self.device = device
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.freeze()
|
||||
|
||||
def freeze(self):
|
||||
self.transformer = self.transformer.eval()
|
||||
#self.train = disabled_train
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, text):
|
||||
batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
|
||||
return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
|
||||
tokens = batch_encoding["input_ids"].to(self.device)
|
||||
outputs = self.transformer(input_ids=tokens)
|
||||
|
||||
z = outputs.last_hidden_state
|
||||
return z
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
import torch.nn.functional as F
|
||||
from transformers import CLIPVisionModel
|
||||
class ClipImageProjector(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
"""
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.model = CLIPVisionModel.from_pretrained(version)
|
||||
self.model.train()
|
||||
self.max_length = max_length # TODO: typical value?
|
||||
self.antialias = True
|
||||
self.mapper = torch.nn.Linear(1024, 768)
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
null_cond = self.get_null_cond(version, max_length)
|
||||
self.register_buffer('null_cond', null_cond)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_null_cond(self, version, max_length):
|
||||
device = self.mean.device
|
||||
embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
|
||||
null_cond = embedder([""])
|
||||
return null_cond
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
if isinstance(x, list):
|
||||
return self.null_cond
|
||||
# x is assumed to be in range [-1,1]
|
||||
x = self.preprocess(x)
|
||||
outputs = self.model(pixel_values=x)
|
||||
last_hidden_state = outputs.last_hidden_state
|
||||
last_hidden_state = self.mapper(last_hidden_state)
|
||||
return F.pad(last_hidden_state, [0,0, 0,self.max_length-last_hidden_state.shape[1], 0,0])
|
||||
|
||||
def encode(self, im):
|
||||
return self(im)
|
||||
|
||||
class ProjectedFrozenCLIPEmbedder(AbstractEncoder):
|
||||
def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77): # clip-vit-base-patch32
|
||||
super().__init__()
|
||||
self.embedder = FrozenCLIPEmbedder(version=version, device=device, max_length=max_length)
|
||||
self.projection = torch.nn.Linear(768, 768)
|
||||
|
||||
def forward(self, text):
|
||||
z = self.embedder(text)
|
||||
return self.projection(z)
|
||||
|
||||
def encode(self, text):
|
||||
return self(text)
|
||||
|
||||
class FrozenCLIPImageEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model='ViT-L/14',
|
||||
jit=False,
|
||||
device='cpu',
|
||||
antialias=False,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
# We don't use the text part so delete it
|
||||
del self.model.transformer
|
||||
self.antialias = antialias
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
|
||||
def preprocess(self, x):
|
||||
# Expects inputs in the range -1, 1
|
||||
x = kornia.geometry.resize(x, (224, 224),
|
||||
interpolation='bicubic',align_corners=True,
|
||||
antialias=self.antialias)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
# x is assumed to be in range [-1,1]
|
||||
if isinstance(x, list):
|
||||
# [""] denotes condition dropout for ucg
|
||||
device = self.model.visual.conv1.weight.device
|
||||
return torch.zeros(1, 768, device=device)
|
||||
return self.model.encode_image(self.preprocess(x)).float()
|
||||
|
||||
def encode(self, im):
|
||||
return self(im).unsqueeze(1)
|
||||
|
||||
from torchvision import transforms
|
||||
import random
|
||||
|
||||
class FrozenCLIPImageMutliEmbedder(AbstractEncoder):
|
||||
"""
|
||||
Uses the CLIP image encoder.
|
||||
Not actually frozen... If you want that set cond_stage_trainable=False in cfg
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
model='ViT-L/14',
|
||||
jit=False,
|
||||
device='cpu',
|
||||
antialias=True,
|
||||
max_crops=5,
|
||||
):
|
||||
super().__init__()
|
||||
self.model, _ = clip.load(name=model, device=device, jit=jit)
|
||||
# We don't use the text part so delete it
|
||||
del self.model.transformer
|
||||
self.antialias = antialias
|
||||
self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
|
||||
self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
|
||||
self.max_crops = max_crops
|
||||
|
||||
def preprocess(self, x):
|
||||
|
||||
# Expects inputs in the range -1, 1
|
||||
randcrop = transforms.RandomResizedCrop(224, scale=(0.085, 1.0), ratio=(1,1))
|
||||
max_crops = self.max_crops
|
||||
patches = []
|
||||
crops = [randcrop(x) for _ in range(max_crops)]
|
||||
patches.extend(crops)
|
||||
x = torch.cat(patches, dim=0)
|
||||
x = (x + 1.) / 2.
|
||||
# renormalize according to clip
|
||||
x = kornia.enhance.normalize(x, self.mean, self.std)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
# x is assumed to be in range [-1,1]
|
||||
if isinstance(x, list):
|
||||
# [""] denotes condition dropout for ucg
|
||||
device = self.model.visual.conv1.weight.device
|
||||
return torch.zeros(1, self.max_crops, 768, device=device)
|
||||
batch_tokens = []
|
||||
for im in x:
|
||||
patches = self.preprocess(im.unsqueeze(0))
|
||||
tokens = self.model.encode_image(patches).float()
|
||||
for t in tokens:
|
||||
if random.random() < 0.1:
|
||||
t *= 0
|
||||
batch_tokens.append(tokens.unsqueeze(0))
|
||||
|
||||
return torch.cat(batch_tokens, dim=0)
|
||||
|
||||
def encode(self, im):
|
||||
return self(im)
|
||||
|
||||
class SpatialRescaler(nn.Module):
|
||||
def __init__(self,
|
||||
n_stages=1,
|
||||
method='bilinear',
|
||||
multiplier=0.5,
|
||||
in_channels=3,
|
||||
out_channels=None,
|
||||
bias=False):
|
||||
super().__init__()
|
||||
self.n_stages = n_stages
|
||||
assert self.n_stages >= 0
|
||||
assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
|
||||
self.multiplier = multiplier
|
||||
self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
|
||||
self.remap_output = out_channels is not None
|
||||
if self.remap_output:
|
||||
print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
|
||||
self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
|
||||
|
||||
def forward(self,x):
|
||||
for stage in range(self.n_stages):
|
||||
x = self.interpolator(x, scale_factor=self.multiplier)
|
||||
|
||||
|
||||
if self.remap_output:
|
||||
x = self.channel_mapper(x)
|
||||
return x
|
||||
|
||||
def encode(self, x):
|
||||
return self(x)
|
||||
|
||||
|
||||
from modelscope.models.cv.image_to_3d.ldm.util import instantiate_from_config
|
||||
from modelscope.models.cv.image_to_3d.ldm.modules.diffusionmodules.util import make_beta_schedule, extract_into_tensor, noise_like
|
||||
|
||||
|
||||
class LowScaleEncoder(nn.Module):
|
||||
def __init__(self, model_config, linear_start, linear_end, timesteps=1000, max_noise_level=250, output_size=64,
|
||||
scale_factor=1.0):
|
||||
super().__init__()
|
||||
self.max_noise_level = max_noise_level
|
||||
self.model = instantiate_from_config(model_config)
|
||||
self.augmentation_schedule = self.register_schedule(timesteps=timesteps, linear_start=linear_start,
|
||||
linear_end=linear_end)
|
||||
self.out_size = output_size
|
||||
self.scale_factor = scale_factor
|
||||
|
||||
def register_schedule(self, beta_schedule="linear", timesteps=1000,
|
||||
linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
|
||||
betas = make_beta_schedule(beta_schedule, timesteps, linear_start=linear_start, linear_end=linear_end,
|
||||
cosine_s=cosine_s)
|
||||
alphas = 1. - betas
|
||||
alphas_cumprod = np.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
|
||||
|
||||
timesteps, = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
|
||||
|
||||
to_torch = partial(torch.tensor, dtype=torch.float32)
|
||||
|
||||
self.register_buffer('betas', to_torch(betas))
|
||||
self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
|
||||
self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
|
||||
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
|
||||
self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
|
||||
self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
|
||||
extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
|
||||
|
||||
def forward(self, x):
|
||||
z = self.model.encode(x).sample()
|
||||
z = z * self.scale_factor
|
||||
noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
|
||||
z = self.q_sample(z, noise_level)
|
||||
if self.out_size is not None:
|
||||
z = torch.nn.functional.interpolate(z, size=self.out_size, mode="nearest") # TODO: experiment with mode
|
||||
# z = z.repeat_interleave(2, -2).repeat_interleave(2, -1)
|
||||
return z, noise_level
|
||||
|
||||
def decode(self, z):
|
||||
z = z / self.scale_factor
|
||||
return self.model.decode(z)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from ldm.util import count_params
|
||||
sentences = ["a hedgehog drinking a whiskey", "der mond ist aufgegangen", "Ein Satz mit vielen Sonderzeichen: äöü ß ?! : 'xx-y/@s'"]
|
||||
model = FrozenT5Embedder(version="google/t5-v1_1-xl").cuda()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
|
||||
model = FrozenCLIPEmbedder().cuda()
|
||||
count_params(model, True)
|
||||
z = model(sentences)
|
||||
print(z.shape)
|
||||
|
||||
print("done.")
|
||||
641
modelscope/models/cv/image_to_3d/ldm/modules/x_transformer.py
Normal file
641
modelscope/models/cv/image_to_3d/ldm/modules/x_transformer.py
Normal file
@@ -0,0 +1,641 @@
|
||||
"""shout-out to https://github.com/lucidrains/x-transformers/tree/main/x_transformers"""
|
||||
import torch
|
||||
from torch import nn, einsum
|
||||
import torch.nn.functional as F
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
from collections import namedtuple
|
||||
from einops import rearrange, repeat, reduce
|
||||
|
||||
# constants
|
||||
|
||||
DEFAULT_DIM_HEAD = 64
|
||||
|
||||
Intermediates = namedtuple('Intermediates', [
|
||||
'pre_softmax_attn',
|
||||
'post_softmax_attn'
|
||||
])
|
||||
|
||||
LayerIntermediates = namedtuple('Intermediates', [
|
||||
'hiddens',
|
||||
'attn_intermediates'
|
||||
])
|
||||
|
||||
|
||||
class AbsolutePositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim, max_seq_len):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(max_seq_len, dim)
|
||||
self.init_()
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.emb.weight, std=0.02)
|
||||
|
||||
def forward(self, x):
|
||||
n = torch.arange(x.shape[1], device=x.device)
|
||||
return self.emb(n)[None, :, :]
|
||||
|
||||
|
||||
class FixedPositionalEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
inv_freq = 1. / (10000 ** (torch.arange(0, dim, 2).float() / dim))
|
||||
self.register_buffer('inv_freq', inv_freq)
|
||||
|
||||
def forward(self, x, seq_dim=1, offset=0):
|
||||
t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq) + offset
|
||||
sinusoid_inp = torch.einsum('i , j -> i j', t, self.inv_freq)
|
||||
emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
|
||||
return emb[None, :, :]
|
||||
|
||||
|
||||
# helpers
|
||||
|
||||
def exists(val):
|
||||
return val is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def always(val):
|
||||
def inner(*args, **kwargs):
|
||||
return val
|
||||
return inner
|
||||
|
||||
|
||||
def not_equals(val):
|
||||
def inner(x):
|
||||
return x != val
|
||||
return inner
|
||||
|
||||
|
||||
def equals(val):
|
||||
def inner(x):
|
||||
return x == val
|
||||
return inner
|
||||
|
||||
|
||||
def max_neg_value(tensor):
|
||||
return -torch.finfo(tensor.dtype).max
|
||||
|
||||
|
||||
# keyword argument helpers
|
||||
|
||||
def pick_and_pop(keys, d):
|
||||
values = list(map(lambda key: d.pop(key), keys))
|
||||
return dict(zip(keys, values))
|
||||
|
||||
|
||||
def group_dict_by_key(cond, d):
|
||||
return_val = [dict(), dict()]
|
||||
for key in d.keys():
|
||||
match = bool(cond(key))
|
||||
ind = int(not match)
|
||||
return_val[ind][key] = d[key]
|
||||
return (*return_val,)
|
||||
|
||||
|
||||
def string_begins_with(prefix, str):
|
||||
return str.startswith(prefix)
|
||||
|
||||
|
||||
def group_by_key_prefix(prefix, d):
|
||||
return group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
|
||||
|
||||
def groupby_prefix_and_trim(prefix, d):
|
||||
kwargs_with_prefix, kwargs = group_dict_by_key(partial(string_begins_with, prefix), d)
|
||||
kwargs_without_prefix = dict(map(lambda x: (x[0][len(prefix):], x[1]), tuple(kwargs_with_prefix.items())))
|
||||
return kwargs_without_prefix, kwargs
|
||||
|
||||
|
||||
# classes
|
||||
class Scale(nn.Module):
|
||||
def __init__(self, value, fn):
|
||||
super().__init__()
|
||||
self.value = value
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.value, *rest)
|
||||
|
||||
|
||||
class Rezero(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.g = nn.Parameter(torch.zeros(1))
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
x, *rest = self.fn(x, **kwargs)
|
||||
return (x * self.g, *rest)
|
||||
|
||||
|
||||
class ScaleNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-8):
|
||||
super().__init__()
|
||||
self.scale = dim ** -0.5
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(dim))
|
||||
|
||||
def forward(self, x):
|
||||
norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
|
||||
return x / norm.clamp(min=self.eps) * self.g
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def forward(self, x, residual):
|
||||
return x + residual
|
||||
|
||||
|
||||
class GRUGating(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gru = nn.GRUCell(dim, dim)
|
||||
|
||||
def forward(self, x, residual):
|
||||
gated_output = self.gru(
|
||||
rearrange(x, 'b n d -> (b n) d'),
|
||||
rearrange(residual, 'b n d -> (b n) d')
|
||||
)
|
||||
|
||||
return gated_output.reshape_as(x)
|
||||
|
||||
|
||||
# feedforward
|
||||
|
||||
class GEGLU(nn.Module):
|
||||
def __init__(self, dim_in, dim_out):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out * 2)
|
||||
|
||||
def forward(self, x):
|
||||
x, gate = self.proj(x).chunk(2, dim=-1)
|
||||
return x * F.gelu(gate)
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = default(dim_out, dim)
|
||||
project_in = nn.Sequential(
|
||||
nn.Linear(dim, inner_dim),
|
||||
nn.GELU()
|
||||
) if not glu else GEGLU(dim, inner_dim)
|
||||
|
||||
self.net = nn.Sequential(
|
||||
project_in,
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(inner_dim, dim_out)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# attention.
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
dim_head=DEFAULT_DIM_HEAD,
|
||||
heads=8,
|
||||
causal=False,
|
||||
mask=None,
|
||||
talking_heads=False,
|
||||
sparse_topk=None,
|
||||
use_entmax15=False,
|
||||
num_mem_kv=0,
|
||||
dropout=0.,
|
||||
on_attn=False
|
||||
):
|
||||
super().__init__()
|
||||
if use_entmax15:
|
||||
raise NotImplementedError("Check out entmax activation instead of softmax activation!")
|
||||
self.scale = dim_head ** -0.5
|
||||
self.heads = heads
|
||||
self.causal = causal
|
||||
self.mask = mask
|
||||
|
||||
inner_dim = dim_head * heads
|
||||
|
||||
self.to_q = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(dim, inner_dim, bias=False)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
# talking heads
|
||||
self.talking_heads = talking_heads
|
||||
if talking_heads:
|
||||
self.pre_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
self.post_softmax_proj = nn.Parameter(torch.randn(heads, heads))
|
||||
|
||||
# explicit topk sparse attention
|
||||
self.sparse_topk = sparse_topk
|
||||
|
||||
# entmax
|
||||
#self.attn_fn = entmax15 if use_entmax15 else F.softmax
|
||||
self.attn_fn = F.softmax
|
||||
|
||||
# add memory key / values
|
||||
self.num_mem_kv = num_mem_kv
|
||||
if num_mem_kv > 0:
|
||||
self.mem_k = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
self.mem_v = nn.Parameter(torch.randn(heads, num_mem_kv, dim_head))
|
||||
|
||||
# attention on attention
|
||||
self.attn_on_attn = on_attn
|
||||
self.to_out = nn.Sequential(nn.Linear(inner_dim, dim * 2), nn.GLU()) if on_attn else nn.Linear(inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
rel_pos=None,
|
||||
sinusoidal_emb=None,
|
||||
prev_attn=None,
|
||||
mem=None
|
||||
):
|
||||
b, n, _, h, talking_heads, device = *x.shape, self.heads, self.talking_heads, x.device
|
||||
kv_input = default(context, x)
|
||||
|
||||
q_input = x
|
||||
k_input = kv_input
|
||||
v_input = kv_input
|
||||
|
||||
if exists(mem):
|
||||
k_input = torch.cat((mem, k_input), dim=-2)
|
||||
v_input = torch.cat((mem, v_input), dim=-2)
|
||||
|
||||
if exists(sinusoidal_emb):
|
||||
# in shortformer, the query would start at a position offset depending on the past cached memory
|
||||
offset = k_input.shape[-2] - q_input.shape[-2]
|
||||
q_input = q_input + sinusoidal_emb(q_input, offset=offset)
|
||||
k_input = k_input + sinusoidal_emb(k_input)
|
||||
|
||||
q = self.to_q(q_input)
|
||||
k = self.to_k(k_input)
|
||||
v = self.to_v(v_input)
|
||||
|
||||
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
|
||||
|
||||
input_mask = None
|
||||
if any(map(exists, (mask, context_mask))):
|
||||
q_mask = default(mask, lambda: torch.ones((b, n), device=device).bool())
|
||||
k_mask = q_mask if not exists(context) else context_mask
|
||||
k_mask = default(k_mask, lambda: torch.ones((b, k.shape[-2]), device=device).bool())
|
||||
q_mask = rearrange(q_mask, 'b i -> b () i ()')
|
||||
k_mask = rearrange(k_mask, 'b j -> b () () j')
|
||||
input_mask = q_mask * k_mask
|
||||
|
||||
if self.num_mem_kv > 0:
|
||||
mem_k, mem_v = map(lambda t: repeat(t, 'h n d -> b h n d', b=b), (self.mem_k, self.mem_v))
|
||||
k = torch.cat((mem_k, k), dim=-2)
|
||||
v = torch.cat((mem_v, v), dim=-2)
|
||||
if exists(input_mask):
|
||||
input_mask = F.pad(input_mask, (self.num_mem_kv, 0), value=True)
|
||||
|
||||
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
|
||||
mask_value = max_neg_value(dots)
|
||||
|
||||
if exists(prev_attn):
|
||||
dots = dots + prev_attn
|
||||
|
||||
pre_softmax_attn = dots
|
||||
|
||||
if talking_heads:
|
||||
dots = einsum('b h i j, h k -> b k i j', dots, self.pre_softmax_proj).contiguous()
|
||||
|
||||
if exists(rel_pos):
|
||||
dots = rel_pos(dots)
|
||||
|
||||
if exists(input_mask):
|
||||
dots.masked_fill_(~input_mask, mask_value)
|
||||
del input_mask
|
||||
|
||||
if self.causal:
|
||||
i, j = dots.shape[-2:]
|
||||
r = torch.arange(i, device=device)
|
||||
mask = rearrange(r, 'i -> () () i ()') < rearrange(r, 'j -> () () () j')
|
||||
mask = F.pad(mask, (j - i, 0), value=False)
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
if exists(self.sparse_topk) and self.sparse_topk < dots.shape[-1]:
|
||||
top, _ = dots.topk(self.sparse_topk, dim=-1)
|
||||
vk = top[..., -1].unsqueeze(-1).expand_as(dots)
|
||||
mask = dots < vk
|
||||
dots.masked_fill_(mask, mask_value)
|
||||
del mask
|
||||
|
||||
attn = self.attn_fn(dots, dim=-1)
|
||||
post_softmax_attn = attn
|
||||
|
||||
attn = self.dropout(attn)
|
||||
|
||||
if talking_heads:
|
||||
attn = einsum('b h i j, h k -> b k i j', attn, self.post_softmax_proj).contiguous()
|
||||
|
||||
out = einsum('b h i j, b h j d -> b h i d', attn, v)
|
||||
out = rearrange(out, 'b h n d -> b n (h d)')
|
||||
|
||||
intermediates = Intermediates(
|
||||
pre_softmax_attn=pre_softmax_attn,
|
||||
post_softmax_attn=post_softmax_attn
|
||||
)
|
||||
|
||||
return self.to_out(out), intermediates
|
||||
|
||||
|
||||
class AttentionLayers(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim,
|
||||
depth,
|
||||
heads=8,
|
||||
causal=False,
|
||||
cross_attend=False,
|
||||
only_cross=False,
|
||||
use_scalenorm=False,
|
||||
use_rmsnorm=False,
|
||||
use_rezero=False,
|
||||
rel_pos_num_buckets=32,
|
||||
rel_pos_max_distance=128,
|
||||
position_infused_attn=False,
|
||||
custom_layers=None,
|
||||
sandwich_coef=None,
|
||||
par_ratio=None,
|
||||
residual_attn=False,
|
||||
cross_residual_attn=False,
|
||||
macaron=False,
|
||||
pre_norm=True,
|
||||
gate_residual=False,
|
||||
**kwargs
|
||||
):
|
||||
super().__init__()
|
||||
ff_kwargs, kwargs = groupby_prefix_and_trim('ff_', kwargs)
|
||||
attn_kwargs, _ = groupby_prefix_and_trim('attn_', kwargs)
|
||||
|
||||
dim_head = attn_kwargs.get('dim_head', DEFAULT_DIM_HEAD)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
self.layers = nn.ModuleList([])
|
||||
|
||||
self.has_pos_emb = position_infused_attn
|
||||
self.pia_pos_emb = FixedPositionalEmbedding(dim) if position_infused_attn else None
|
||||
self.rotary_pos_emb = always(None)
|
||||
|
||||
assert rel_pos_num_buckets <= rel_pos_max_distance, 'number of relative position buckets must be less than the relative position max distance'
|
||||
self.rel_pos = None
|
||||
|
||||
self.pre_norm = pre_norm
|
||||
|
||||
self.residual_attn = residual_attn
|
||||
self.cross_residual_attn = cross_residual_attn
|
||||
|
||||
norm_class = ScaleNorm if use_scalenorm else nn.LayerNorm
|
||||
norm_class = RMSNorm if use_rmsnorm else norm_class
|
||||
norm_fn = partial(norm_class, dim)
|
||||
|
||||
norm_fn = nn.Identity if use_rezero else norm_fn
|
||||
branch_fn = Rezero if use_rezero else None
|
||||
|
||||
if cross_attend and not only_cross:
|
||||
default_block = ('a', 'c', 'f')
|
||||
elif cross_attend and only_cross:
|
||||
default_block = ('c', 'f')
|
||||
else:
|
||||
default_block = ('a', 'f')
|
||||
|
||||
if macaron:
|
||||
default_block = ('f',) + default_block
|
||||
|
||||
if exists(custom_layers):
|
||||
layer_types = custom_layers
|
||||
elif exists(par_ratio):
|
||||
par_depth = depth * len(default_block)
|
||||
assert 1 < par_ratio <= par_depth, 'par ratio out of range'
|
||||
default_block = tuple(filter(not_equals('f'), default_block))
|
||||
par_attn = par_depth // par_ratio
|
||||
depth_cut = par_depth * 2 // 3 # 2 / 3 attention layer cutoff suggested by PAR paper
|
||||
par_width = (depth_cut + depth_cut // par_attn) // par_attn
|
||||
assert len(default_block) <= par_width, 'default block is too large for par_ratio'
|
||||
par_block = default_block + ('f',) * (par_width - len(default_block))
|
||||
par_head = par_block * par_attn
|
||||
layer_types = par_head + ('f',) * (par_depth - len(par_head))
|
||||
elif exists(sandwich_coef):
|
||||
assert sandwich_coef > 0 and sandwich_coef <= depth, 'sandwich coefficient should be less than the depth'
|
||||
layer_types = ('a',) * sandwich_coef + default_block * (depth - sandwich_coef) + ('f',) * sandwich_coef
|
||||
else:
|
||||
layer_types = default_block * depth
|
||||
|
||||
self.layer_types = layer_types
|
||||
self.num_attn_layers = len(list(filter(equals('a'), layer_types)))
|
||||
|
||||
for layer_type in self.layer_types:
|
||||
if layer_type == 'a':
|
||||
layer = Attention(dim, heads=heads, causal=causal, **attn_kwargs)
|
||||
elif layer_type == 'c':
|
||||
layer = Attention(dim, heads=heads, **attn_kwargs)
|
||||
elif layer_type == 'f':
|
||||
layer = FeedForward(dim, **ff_kwargs)
|
||||
layer = layer if not macaron else Scale(0.5, layer)
|
||||
else:
|
||||
raise Exception(f'invalid layer type {layer_type}')
|
||||
|
||||
if isinstance(layer, Attention) and exists(branch_fn):
|
||||
layer = branch_fn(layer)
|
||||
|
||||
if gate_residual:
|
||||
residual_fn = GRUGating(dim)
|
||||
else:
|
||||
residual_fn = Residual()
|
||||
|
||||
self.layers.append(nn.ModuleList([
|
||||
norm_fn(),
|
||||
layer,
|
||||
residual_fn
|
||||
]))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
context=None,
|
||||
mask=None,
|
||||
context_mask=None,
|
||||
mems=None,
|
||||
return_hiddens=False
|
||||
):
|
||||
hiddens = []
|
||||
intermediates = []
|
||||
prev_attn = None
|
||||
prev_cross_attn = None
|
||||
|
||||
mems = mems.copy() if exists(mems) else [None] * self.num_attn_layers
|
||||
|
||||
for ind, (layer_type, (norm, block, residual_fn)) in enumerate(zip(self.layer_types, self.layers)):
|
||||
is_last = ind == (len(self.layers) - 1)
|
||||
|
||||
if layer_type == 'a':
|
||||
hiddens.append(x)
|
||||
layer_mem = mems.pop(0)
|
||||
|
||||
residual = x
|
||||
|
||||
if self.pre_norm:
|
||||
x = norm(x)
|
||||
|
||||
if layer_type == 'a':
|
||||
out, inter = block(x, mask=mask, sinusoidal_emb=self.pia_pos_emb, rel_pos=self.rel_pos,
|
||||
prev_attn=prev_attn, mem=layer_mem)
|
||||
elif layer_type == 'c':
|
||||
out, inter = block(x, context=context, mask=mask, context_mask=context_mask, prev_attn=prev_cross_attn)
|
||||
elif layer_type == 'f':
|
||||
out = block(x)
|
||||
|
||||
x = residual_fn(out, residual)
|
||||
|
||||
if layer_type in ('a', 'c'):
|
||||
intermediates.append(inter)
|
||||
|
||||
if layer_type == 'a' and self.residual_attn:
|
||||
prev_attn = inter.pre_softmax_attn
|
||||
elif layer_type == 'c' and self.cross_residual_attn:
|
||||
prev_cross_attn = inter.pre_softmax_attn
|
||||
|
||||
if not self.pre_norm and not is_last:
|
||||
x = norm(x)
|
||||
|
||||
if return_hiddens:
|
||||
intermediates = LayerIntermediates(
|
||||
hiddens=hiddens,
|
||||
attn_intermediates=intermediates
|
||||
)
|
||||
|
||||
return x, intermediates
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(AttentionLayers):
|
||||
def __init__(self, **kwargs):
|
||||
assert 'causal' not in kwargs, 'cannot set causality on encoder'
|
||||
super().__init__(causal=False, **kwargs)
|
||||
|
||||
|
||||
|
||||
class TransformerWrapper(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
num_tokens,
|
||||
max_seq_len,
|
||||
attn_layers,
|
||||
emb_dim=None,
|
||||
max_mem_len=0.,
|
||||
emb_dropout=0.,
|
||||
num_memory_tokens=None,
|
||||
tie_embedding=False,
|
||||
use_pos_emb=True
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(attn_layers, AttentionLayers), 'attention layers must be one of Encoder or Decoder'
|
||||
|
||||
dim = attn_layers.dim
|
||||
emb_dim = default(emb_dim, dim)
|
||||
|
||||
self.max_seq_len = max_seq_len
|
||||
self.max_mem_len = max_mem_len
|
||||
self.num_tokens = num_tokens
|
||||
|
||||
self.token_emb = nn.Embedding(num_tokens, emb_dim)
|
||||
self.pos_emb = AbsolutePositionalEmbedding(emb_dim, max_seq_len) if (
|
||||
use_pos_emb and not attn_layers.has_pos_emb) else always(0)
|
||||
self.emb_dropout = nn.Dropout(emb_dropout)
|
||||
|
||||
self.project_emb = nn.Linear(emb_dim, dim) if emb_dim != dim else nn.Identity()
|
||||
self.attn_layers = attn_layers
|
||||
self.norm = nn.LayerNorm(dim)
|
||||
|
||||
self.init_()
|
||||
|
||||
self.to_logits = nn.Linear(dim, num_tokens) if not tie_embedding else lambda t: t @ self.token_emb.weight.t()
|
||||
|
||||
# memory tokens (like [cls]) from Memory Transformers paper
|
||||
num_memory_tokens = default(num_memory_tokens, 0)
|
||||
self.num_memory_tokens = num_memory_tokens
|
||||
if num_memory_tokens > 0:
|
||||
self.memory_tokens = nn.Parameter(torch.randn(num_memory_tokens, dim))
|
||||
|
||||
# let funnel encoder know number of memory tokens, if specified
|
||||
if hasattr(attn_layers, 'num_memory_tokens'):
|
||||
attn_layers.num_memory_tokens = num_memory_tokens
|
||||
|
||||
def init_(self):
|
||||
nn.init.normal_(self.token_emb.weight, std=0.02)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x,
|
||||
return_embeddings=False,
|
||||
mask=None,
|
||||
return_mems=False,
|
||||
return_attn=False,
|
||||
mems=None,
|
||||
**kwargs
|
||||
):
|
||||
b, n, device, num_mem = *x.shape, x.device, self.num_memory_tokens
|
||||
x = self.token_emb(x)
|
||||
x += self.pos_emb(x)
|
||||
x = self.emb_dropout(x)
|
||||
|
||||
x = self.project_emb(x)
|
||||
|
||||
if num_mem > 0:
|
||||
mem = repeat(self.memory_tokens, 'n d -> b n d', b=b)
|
||||
x = torch.cat((mem, x), dim=1)
|
||||
|
||||
# auto-handle masking after appending memory tokens
|
||||
if exists(mask):
|
||||
mask = F.pad(mask, (num_mem, 0), value=True)
|
||||
|
||||
x, intermediates = self.attn_layers(x, mask=mask, mems=mems, return_hiddens=True, **kwargs)
|
||||
x = self.norm(x)
|
||||
|
||||
mem, x = x[:, :num_mem], x[:, num_mem:]
|
||||
|
||||
out = self.to_logits(x) if not return_embeddings else x
|
||||
|
||||
if return_mems:
|
||||
hiddens = intermediates.hiddens
|
||||
new_mems = list(map(lambda pair: torch.cat(pair, dim=-2), zip(mems, hiddens))) if exists(mems) else hiddens
|
||||
new_mems = list(map(lambda t: t[..., -self.max_mem_len:, :].detach(), new_mems))
|
||||
return out, new_mems
|
||||
|
||||
if return_attn:
|
||||
attn_maps = list(map(lambda t: t.post_softmax_attn, intermediates.attn_intermediates))
|
||||
return out, attn_maps
|
||||
|
||||
return out
|
||||
|
||||
121
modelscope/models/cv/image_to_3d/ldm/thirdp/psp/helpers.py
Normal file
121
modelscope/models/cv/image_to_3d/ldm/thirdp/psp/helpers.py
Normal file
@@ -0,0 +1,121 @@
|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
|
||||
from collections import namedtuple
|
||||
import torch
|
||||
from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module
|
||||
|
||||
"""
|
||||
ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Flatten(Module):
|
||||
def forward(self, input):
|
||||
return input.view(input.size(0), -1)
|
||||
|
||||
|
||||
def l2_norm(input, axis=1):
|
||||
norm = torch.norm(input, 2, axis, True)
|
||||
output = torch.div(input, norm)
|
||||
return output
|
||||
|
||||
|
||||
class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])):
|
||||
""" A named tuple describing a ResNet block. """
|
||||
|
||||
|
||||
def get_block(in_channel, depth, num_units, stride=2):
|
||||
return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)]
|
||||
|
||||
|
||||
def get_blocks(num_layers):
|
||||
if num_layers == 50:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=4),
|
||||
get_block(in_channel=128, depth=256, num_units=14),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
elif num_layers == 100:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=13),
|
||||
get_block(in_channel=128, depth=256, num_units=30),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
elif num_layers == 152:
|
||||
blocks = [
|
||||
get_block(in_channel=64, depth=64, num_units=3),
|
||||
get_block(in_channel=64, depth=128, num_units=8),
|
||||
get_block(in_channel=128, depth=256, num_units=36),
|
||||
get_block(in_channel=256, depth=512, num_units=3)
|
||||
]
|
||||
else:
|
||||
raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers))
|
||||
return blocks
|
||||
|
||||
|
||||
class SEModule(Module):
|
||||
def __init__(self, channels, reduction):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = AdaptiveAvgPool2d(1)
|
||||
self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False)
|
||||
self.relu = ReLU(inplace=True)
|
||||
self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False)
|
||||
self.sigmoid = Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return module_input * x
|
||||
|
||||
|
||||
class bottleneck_IR(Module):
|
||||
def __init__(self, in_channel, depth, stride):
|
||||
super(bottleneck_IR, self).__init__()
|
||||
if in_channel == depth:
|
||||
self.shortcut_layer = MaxPool2d(1, stride)
|
||||
else:
|
||||
self.shortcut_layer = Sequential(
|
||||
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
||||
BatchNorm2d(depth)
|
||||
)
|
||||
self.res_layer = Sequential(
|
||||
BatchNorm2d(in_channel),
|
||||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth),
|
||||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut_layer(x)
|
||||
res = self.res_layer(x)
|
||||
return res + shortcut
|
||||
|
||||
|
||||
class bottleneck_IR_SE(Module):
|
||||
def __init__(self, in_channel, depth, stride):
|
||||
super(bottleneck_IR_SE, self).__init__()
|
||||
if in_channel == depth:
|
||||
self.shortcut_layer = MaxPool2d(1, stride)
|
||||
else:
|
||||
self.shortcut_layer = Sequential(
|
||||
Conv2d(in_channel, depth, (1, 1), stride, bias=False),
|
||||
BatchNorm2d(depth)
|
||||
)
|
||||
self.res_layer = Sequential(
|
||||
BatchNorm2d(in_channel),
|
||||
Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False),
|
||||
PReLU(depth),
|
||||
Conv2d(depth, depth, (3, 3), stride, 1, bias=False),
|
||||
BatchNorm2d(depth),
|
||||
SEModule(depth, 16)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
shortcut = self.shortcut_layer(x)
|
||||
res = self.res_layer(x)
|
||||
return res + shortcut
|
||||
23
modelscope/models/cv/image_to_3d/ldm/thirdp/psp/id_loss.py
Normal file
23
modelscope/models/cv/image_to_3d/ldm/thirdp/psp/id_loss.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
import torch
|
||||
from torch import nn
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.model_irse import Backbone
|
||||
|
||||
|
||||
class IDFeatures(nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super(IDFeatures, self).__init__()
|
||||
print('Loading ResNet ArcFace')
|
||||
self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se')
|
||||
self.facenet.load_state_dict(torch.load(model_path, map_location="cpu"))
|
||||
self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112))
|
||||
self.facenet.eval()
|
||||
|
||||
def forward(self, x, crop=False):
|
||||
# Not sure of the image range here
|
||||
if crop:
|
||||
x = torch.nn.functional.interpolate(x, (256, 256), mode="area")
|
||||
x = x[:, :, 35:223, 32:220]
|
||||
x = self.face_pool(x)
|
||||
x_feats = self.facenet(x)
|
||||
return x_feats
|
||||
@@ -0,0 +1,86 @@
|
||||
# https://github.com/eladrich/pixel2style2pixel
|
||||
|
||||
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module
|
||||
from modelscope.models.cv.image_to_3d.ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm
|
||||
|
||||
"""
|
||||
Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch)
|
||||
"""
|
||||
|
||||
|
||||
class Backbone(Module):
|
||||
def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True):
|
||||
super(Backbone, self).__init__()
|
||||
assert input_size in [112, 224], "input_size should be 112 or 224"
|
||||
assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152"
|
||||
assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se"
|
||||
blocks = get_blocks(num_layers)
|
||||
if mode == 'ir':
|
||||
unit_module = bottleneck_IR
|
||||
elif mode == 'ir_se':
|
||||
unit_module = bottleneck_IR_SE
|
||||
self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False),
|
||||
BatchNorm2d(64),
|
||||
PReLU(64))
|
||||
if input_size == 112:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 7 * 7, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
else:
|
||||
self.output_layer = Sequential(BatchNorm2d(512),
|
||||
Dropout(drop_ratio),
|
||||
Flatten(),
|
||||
Linear(512 * 14 * 14, 512),
|
||||
BatchNorm1d(512, affine=affine))
|
||||
|
||||
modules = []
|
||||
for block in blocks:
|
||||
for bottleneck in block:
|
||||
modules.append(unit_module(bottleneck.in_channel,
|
||||
bottleneck.depth,
|
||||
bottleneck.stride))
|
||||
self.body = Sequential(*modules)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.input_layer(x)
|
||||
x = self.body(x)
|
||||
x = self.output_layer(x)
|
||||
return l2_norm(x)
|
||||
|
||||
|
||||
def IR_50(input_size):
|
||||
"""Constructs a ir-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_101(input_size):
|
||||
"""Constructs a ir-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_152(input_size):
|
||||
"""Constructs a ir-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_50(input_size):
|
||||
"""Constructs a ir_se-50 model."""
|
||||
model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_101(input_size):
|
||||
"""Constructs a ir_se-101 model."""
|
||||
model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
|
||||
|
||||
def IR_SE_152(input_size):
|
||||
"""Constructs a ir_se-152 model."""
|
||||
model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False)
|
||||
return model
|
||||
276
modelscope/models/cv/image_to_3d/ldm/util.py
Normal file
276
modelscope/models/cv/image_to_3d/ldm/util.py
Normal file
@@ -0,0 +1,276 @@
|
||||
import importlib
|
||||
|
||||
import torchvision
|
||||
import torch
|
||||
from torch import optim
|
||||
import numpy as np
|
||||
|
||||
from inspect import isfunction
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from PIL import Image
|
||||
import torch
|
||||
import time
|
||||
import cv2
|
||||
import PIL
|
||||
|
||||
def pil_rectangle_crop(im):
|
||||
width, height = im.size # Get dimensions
|
||||
|
||||
if width <= height:
|
||||
left = 0
|
||||
right = width
|
||||
top = (height - width)/2
|
||||
bottom = (height + width)/2
|
||||
else:
|
||||
|
||||
top = 0
|
||||
bottom = height
|
||||
left = (width - height) / 2
|
||||
bottom = (width + height) / 2
|
||||
|
||||
# Crop the center of the image
|
||||
im = im.crop((left, top, right, bottom))
|
||||
return im
|
||||
|
||||
def add_margin(pil_img, color=0, size=256):
|
||||
width, height = pil_img.size
|
||||
result = Image.new(pil_img.mode, (size, size), color)
|
||||
result.paste(pil_img, ((size - width) // 2, (size - height) // 2))
|
||||
return result
|
||||
|
||||
|
||||
def create_carvekit_interface():
|
||||
from carvekit.api.high import HiInterface
|
||||
# Check doc strings for more information
|
||||
interface = HiInterface(object_type="object", # Can be "object" or "hairs-like".
|
||||
batch_size_seg=5,
|
||||
batch_size_matting=1,
|
||||
device='cuda' if torch.cuda.is_available() else 'cpu',
|
||||
seg_mask_size=640, # Use 640 for Tracer B7 and 320 for U2Net
|
||||
matting_mask_size=2048,
|
||||
trimap_prob_threshold=231,
|
||||
trimap_dilation=30,
|
||||
trimap_erosion_iters=5,
|
||||
fp16=False)
|
||||
|
||||
return interface
|
||||
|
||||
|
||||
def load_and_preprocess(interface, input_im):
|
||||
'''
|
||||
:param input_im (PIL Image).
|
||||
:return image (H, W, 3) array in [0, 1].
|
||||
'''
|
||||
# See https://github.com/Ir1d/image-background-remove-tool
|
||||
image = input_im.convert('RGB')
|
||||
|
||||
image_without_background = interface([image])[0]
|
||||
image_without_background = np.array(image_without_background)
|
||||
est_seg = image_without_background > 127
|
||||
image = np.array(image)
|
||||
foreground = est_seg[:, : , -1].astype(np.bool_)
|
||||
image[~foreground] = [255., 255., 255.]
|
||||
x, y, w, h = cv2.boundingRect(foreground.astype(np.uint8))
|
||||
image = image[y:y+h, x:x+w, :]
|
||||
image = PIL.Image.fromarray(np.array(image))
|
||||
|
||||
# resize image such that long edge is 512
|
||||
image.thumbnail([200, 200], Image.LANCZOS)
|
||||
image = add_margin(image, (255, 255, 255), size=256)
|
||||
image = np.array(image)
|
||||
|
||||
return image
|
||||
|
||||
|
||||
def log_txt_as_img(wh, xc, size=10):
|
||||
# wh a tuple of (width, height)
|
||||
# xc a list of captions to plot
|
||||
b = len(xc)
|
||||
txts = list()
|
||||
for bi in range(b):
|
||||
txt = Image.new("RGB", wh, color="white")
|
||||
draw = ImageDraw.Draw(txt)
|
||||
font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
|
||||
nc = int(40 * (wh[0] / 256))
|
||||
lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
|
||||
|
||||
try:
|
||||
draw.text((0, 0), lines, fill="black", font=font)
|
||||
except UnicodeEncodeError:
|
||||
print("Cant encode string for logging. Skipping.")
|
||||
|
||||
txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
|
||||
txts.append(txt)
|
||||
txts = np.stack(txts)
|
||||
txts = torch.tensor(txts)
|
||||
return txts
|
||||
|
||||
|
||||
def ismap(x):
|
||||
if not isinstance(x, torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] > 3)
|
||||
|
||||
|
||||
def isimage(x):
|
||||
if not isinstance(x,torch.Tensor):
|
||||
return False
|
||||
return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def mean_flat(tensor):
|
||||
"""
|
||||
https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
|
||||
Take the mean over all non-batch dimensions.
|
||||
"""
|
||||
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
||||
|
||||
|
||||
def count_params(model, verbose=False):
|
||||
total_params = sum(p.numel() for p in model.parameters())
|
||||
if verbose:
|
||||
print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.")
|
||||
return total_params
|
||||
|
||||
|
||||
def instantiate_from_config(config):
|
||||
if not "target" in config:
|
||||
if config == '__is_first_stage__':
|
||||
return None
|
||||
elif config == "__is_unconditional__":
|
||||
return None
|
||||
raise KeyError("Expected key `target` to instantiate.")
|
||||
return get_obj_from_str(config["target"])(**config.get("params", dict()))
|
||||
|
||||
|
||||
def get_obj_from_str(string, reload=False):
|
||||
module, cls = string.rsplit(".", 1)
|
||||
print(module)
|
||||
if reload:
|
||||
module_imp = importlib.import_module(module)
|
||||
importlib.reload(module_imp)
|
||||
return getattr(importlib.import_module(module, package=None), cls)
|
||||
|
||||
|
||||
class AdamWwithEMAandWings(optim.Optimizer):
|
||||
# credit to https://gist.github.com/crowsonkb/65f7265353f403714fce3b2595e0b298
|
||||
def __init__(self, params, lr=1.e-3, betas=(0.9, 0.999), eps=1.e-8, # TODO: check hyperparameters before using
|
||||
weight_decay=1.e-2, amsgrad=False, ema_decay=0.9999, # ema decay to match previous code
|
||||
ema_power=1., param_names=()):
|
||||
"""AdamW that saves EMA versions of the parameters."""
|
||||
if not 0.0 <= lr:
|
||||
raise ValueError("Invalid learning rate: {}".format(lr))
|
||||
if not 0.0 <= eps:
|
||||
raise ValueError("Invalid epsilon value: {}".format(eps))
|
||||
if not 0.0 <= betas[0] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
|
||||
if not 0.0 <= betas[1] < 1.0:
|
||||
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
|
||||
if not 0.0 <= weight_decay:
|
||||
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
|
||||
if not 0.0 <= ema_decay <= 1.0:
|
||||
raise ValueError("Invalid ema_decay value: {}".format(ema_decay))
|
||||
defaults = dict(lr=lr, betas=betas, eps=eps,
|
||||
weight_decay=weight_decay, amsgrad=amsgrad, ema_decay=ema_decay,
|
||||
ema_power=ema_power, param_names=param_names)
|
||||
super().__init__(params, defaults)
|
||||
|
||||
def __setstate__(self, state):
|
||||
super().__setstate__(state)
|
||||
for group in self.param_groups:
|
||||
group.setdefault('amsgrad', False)
|
||||
|
||||
@torch.no_grad()
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
Args:
|
||||
closure (callable, optional): A closure that reevaluates the model
|
||||
and returns the loss.
|
||||
"""
|
||||
loss = None
|
||||
if closure is not None:
|
||||
with torch.enable_grad():
|
||||
loss = closure()
|
||||
|
||||
for group in self.param_groups:
|
||||
params_with_grad = []
|
||||
grads = []
|
||||
exp_avgs = []
|
||||
exp_avg_sqs = []
|
||||
ema_params_with_grad = []
|
||||
state_sums = []
|
||||
max_exp_avg_sqs = []
|
||||
state_steps = []
|
||||
amsgrad = group['amsgrad']
|
||||
beta1, beta2 = group['betas']
|
||||
ema_decay = group['ema_decay']
|
||||
ema_power = group['ema_power']
|
||||
|
||||
for p in group['params']:
|
||||
if p.grad is None:
|
||||
continue
|
||||
params_with_grad.append(p)
|
||||
if p.grad.is_sparse:
|
||||
raise RuntimeError('AdamW does not support sparse gradients')
|
||||
grads.append(p.grad)
|
||||
|
||||
state = self.state[p]
|
||||
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
state['step'] = 0
|
||||
# Exponential moving average of gradient values
|
||||
state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of squared gradient values
|
||||
state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
if amsgrad:
|
||||
# Maintains max of all exp. moving avg. of sq. grad. values
|
||||
state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
# Exponential moving average of parameter values
|
||||
state['param_exp_avg'] = p.detach().float().clone()
|
||||
|
||||
exp_avgs.append(state['exp_avg'])
|
||||
exp_avg_sqs.append(state['exp_avg_sq'])
|
||||
ema_params_with_grad.append(state['param_exp_avg'])
|
||||
|
||||
if amsgrad:
|
||||
max_exp_avg_sqs.append(state['max_exp_avg_sq'])
|
||||
|
||||
# update the steps for each param group update
|
||||
state['step'] += 1
|
||||
# record the step after step update
|
||||
state_steps.append(state['step'])
|
||||
|
||||
optim._functional.adamw(params_with_grad,
|
||||
grads,
|
||||
exp_avgs,
|
||||
exp_avg_sqs,
|
||||
max_exp_avg_sqs,
|
||||
state_steps,
|
||||
amsgrad=amsgrad,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
lr=group['lr'],
|
||||
weight_decay=group['weight_decay'],
|
||||
eps=group['eps'],
|
||||
maximize=False)
|
||||
|
||||
cur_ema_decay = min(ema_decay, 1 - state['step'] ** -ema_power)
|
||||
for param, ema_param in zip(params_with_grad, ema_params_with_grad):
|
||||
ema_param.mul_(cur_ema_decay).add_(param.float(), alpha=1 - cur_ema_decay)
|
||||
|
||||
return loss
|
||||
@@ -28,6 +28,8 @@ class DocumentGroundedDialogRetrievalModel(TorchModel):
|
||||
map_location='cpu')
|
||||
compatible_position_ids(state_dict,
|
||||
'ctx_encoder.encoder.embeddings.position_ids')
|
||||
compatible_position_ids(state_dict,
|
||||
'qry_encoder.encoder.embeddings.position_ids')
|
||||
self.model.load_state_dict(state_dict)
|
||||
|
||||
def forward(self, input: Dict[str, Tensor], gck_segment=32):
|
||||
|
||||
@@ -49,6 +49,7 @@ class MsModelMixin:
|
||||
The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
|
||||
"""
|
||||
model_dir = kwargs.pop('model_dir', None)
|
||||
device = kwargs.pop('device', None)
|
||||
if model_dir is None:
|
||||
config = LlamaConfig(**kwargs)
|
||||
model = cls(config)
|
||||
@@ -56,7 +57,8 @@ class MsModelMixin:
|
||||
model = super(MsModelMixin, cls).from_pretrained(
|
||||
pretrained_model_name_or_path=model_dir, **kwargs)
|
||||
model.model_dir = model_dir
|
||||
return model
|
||||
return model if 'device_map' in kwargs \
|
||||
or device is None else model.to(device)
|
||||
|
||||
|
||||
class LlamaPreTrainedModel(MsModelMixin, LlamaPreTrainedModelHF, TorchModel):
|
||||
|
||||
@@ -375,6 +375,8 @@ class ProtoNet(nn.Module):
|
||||
input_ids = torch.IntTensor(input_ids)
|
||||
if not isinstance(input_mask, Tensor):
|
||||
input_mask = torch.IntTensor(input_mask)
|
||||
input_ids = input_ids.to(self.bert.device)
|
||||
input_mask = input_mask.to(self.bert.device)
|
||||
rst = self.bert(input_ids, input_mask)
|
||||
last_hidden_states = rst.last_hidden_state
|
||||
if len(input_mask.shape) == 2:
|
||||
|
||||
@@ -69,6 +69,7 @@ class OutputKeys(object):
|
||||
PCD12 = 'pcd12'
|
||||
PCD12_ALIGN = 'pcd12_align'
|
||||
TBOUNDS = 'tbounds'
|
||||
MV_IMGS = 'MViews'
|
||||
|
||||
|
||||
OutputTypes = {
|
||||
@@ -132,6 +133,7 @@ OutputTypes = {
|
||||
OutputKeys.PCD12: np.ndarray,
|
||||
OutputKeys.PCD12_ALIGN: np.ndarray,
|
||||
OutputKeys.TBOUNDS: Dict,
|
||||
OutputKeys.MV_IMGS: List[np.ndarray],
|
||||
}
|
||||
|
||||
OutputTypeSchema = {
|
||||
@@ -426,6 +428,15 @@ OutputTypeSchema = {
|
||||
OutputKeys.TBOUNDS: {
|
||||
'type': 'object'
|
||||
},
|
||||
OutputKeys.MV_IMGS: {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'array',
|
||||
'items': {
|
||||
'type': 'number'
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
TASK_OUTPUTS = {
|
||||
@@ -1632,6 +1643,7 @@ TASK_OUTPUTS = {
|
||||
# "output_imgs": np.ndarray list with shape [[height, width, 3], ...]
|
||||
# }
|
||||
Tasks.image_view_transform: [OutputKeys.OUTPUT_IMGS],
|
||||
Tasks.image_to_3d: [OutputKeys.MV_IMGS]
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -247,8 +247,10 @@ TASK_INPUTS = {
|
||||
InputType.VIDEO,
|
||||
|
||||
# image generation task result for a single image
|
||||
Tasks.image_to_image_generation:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_to_image_generation: [
|
||||
InputType.IMAGE,
|
||||
(InputType.IMAGE, InputType.IMAGE, InputType.IMAGE, InputType.IMAGE)
|
||||
],
|
||||
Tasks.image_to_image_translation:
|
||||
InputType.IMAGE,
|
||||
Tasks.image_style_transfer: {
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user