mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-25 12:39:25 +01:00
Merge branch 'master' into release/1.17
This commit is contained in:
10
README.md
10
README.md
@@ -299,3 +299,13 @@ We provide additional documentations including:
|
||||
# License
|
||||
|
||||
This project is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
|
||||
|
||||
# Citation
|
||||
```
|
||||
@Misc{modelscope,
|
||||
title = {ModelScope: bring the notion of Model-as-a-Service to life.},
|
||||
author = {The ModelScope Team},
|
||||
howpublished = {\url{https://github.com/modelscope/modelscope}},
|
||||
year = {2023}
|
||||
}
|
||||
```
|
||||
|
||||
147
README_zh.md
147
README_zh.md
@@ -26,26 +26,24 @@
|
||||
<h4 align="center">
|
||||
<p>
|
||||
<a href="https://github.com/modelscope/modelscope/blob/master/README.md">English</a> |
|
||||
<b>中文</b> |
|
||||
<a href="https://github.com/modelscope/modelscope/blob/master/README_ja.md">日本語</a>
|
||||
<b> 中文 </b> |
|
||||
<a href="https://github.com/modelscope/modelscope/blob/master/README_ja.md"> 日本語 </a>
|
||||
<p>
|
||||
</h4>
|
||||
|
||||
|
||||
</div>
|
||||
|
||||
# 简介
|
||||
|
||||
[ModelScope]( https://www.modelscope.cn) 是一个“模型即服务”(MaaS)平台,旨在汇集来自AI社区的最先进的机器学习模型,并简化在实际应用中使用AI模型的流程。ModelScope库使开发人员能够通过丰富的API设计执行推理、训练和评估,从而促进跨不同AI领域的最先进模型的统一体验。
|
||||
|
||||
ModelScope Library为模型贡献者提供了必要的分层API,以便将来自 CV、NLP、语音、多模态以及科学计算的模型集成到ModelScope生态系统中。所有这些不同模型的实现都以一种简单统一访问的方式进行封装,用户只需几行代码即可完成模型推理、微调和评估。同时,灵活的模块化设计使得在必要时也可以自定义模型训练推理过程中的不同组件。
|
||||
|
||||
除了包含各种模型的实现之外,ModelScope Library还支持与ModelScope后端服务进行必要的交互,特别是与Model-Hub和Dataset-Hub的交互。这种交互促进了模型和数据集的管理在后台无缝执行,包括模型数据集查询、版本控制、缓存管理等。
|
||||
[ModelScope](https://www.modelscope.cn) 是一个 “模型即服务”(MaaS) 平台,旨在汇集来自 AI 社区的最先进的机器学习模型,并简化在实际应用中使用 AI 模型的流程。ModelScope 库使开发人员能够通过丰富的 API 设计执行推理、训练和评估,从而促进跨不同 AI 领域的最先进模型的统一体验。
|
||||
|
||||
ModelScope Library 为模型贡献者提供了必要的分层 API,以便将来自 CV、NLP、语音、多模态以及科学计算的模型集成到 ModelScope 生态系统中。所有这些不同模型的实现都以一种简单统一访问的方式进行封装,用户只需几行代码即可完成模型推理、微调和评估。同时,灵活的模块化设计使得在必要时也可以自定义模型训练推理过程中的不同组件。
|
||||
|
||||
除了包含各种模型的实现之外,ModelScope Library 还支持与 ModelScope 后端服务进行必要的交互,特别是与 Model-Hub 和 Dataset-Hub 的交互。这种交互促进了模型和数据集的管理在后台无缝执行,包括模型数据集查询、版本控制、缓存管理等。
|
||||
|
||||
# 部分模型和在线体验
|
||||
ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计算机视觉、语音、多模态、科学计算等,其中包含数百个SOTA模型。用户可以进入ModelScope网站([modelscope.cn](http://www.modelscope.cn))的模型中心零门槛在线体验,或者Notebook方式体验模型。
|
||||
|
||||
ModelScope 开源了数百个 (当前 700+) 模型,涵盖自然语言处理、计算机视觉、语音、多模态、科学计算等,其中包含数百个 SOTA 模型。用户可以进入 ModelScope 网站 ([modelscope.cn](http://www.modelscope.cn)) 的模型中心零门槛在线体验,或者 Notebook 方式体验模型。
|
||||
|
||||
<p align="center">
|
||||
<br>
|
||||
@@ -69,7 +67,6 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
|
||||
|
||||
* [Phi-3-mini-128k-instruct](https://modelscope.cn/models/LLM-Research/Phi-3-mini-128k-instruct/summary)
|
||||
|
||||
|
||||
多模态:
|
||||
|
||||
* [Qwen-VL-Chat](https://modelscope.cn/models/qwen/Qwen-VL-Chat/summary)
|
||||
@@ -88,37 +85,33 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
|
||||
|
||||
计算机视觉:
|
||||
|
||||
* [DamoFD人脸检测关键点模型-0.5G](https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd/summary)
|
||||
* [DamoFD 人脸检测关键点模型-0.5G](https://modelscope.cn/models/damo/cv_ddsar_face-detection_iclr23-damofd/summary)
|
||||
|
||||
* [BSHM人像抠图](https://modelscope.cn/models/damo/cv_unet_image-matting/summary)
|
||||
* [BSHM 人像抠图](https://modelscope.cn/models/damo/cv_unet_image-matting/summary)
|
||||
|
||||
* [DCT-Net人像卡通化-3D](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary)
|
||||
* [DCT-Net 人像卡通化-3D](https://modelscope.cn/models/damo/cv_unet_person-image-cartoon-3d_compound-models/summary)
|
||||
|
||||
* [DCT-Net人像卡通化模型-3D](https://modelscope.cn/models/damo/face_chain_control_model/summary)
|
||||
* [DCT-Net 人像卡通化模型-3D](https://modelscope.cn/models/damo/face_chain_control_model/summary)
|
||||
|
||||
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_convnextTiny_ocr-recognition-general_damo/summary)
|
||||
|
||||
* [读光-文字识别-行识别模型-中英-通用领域](https://modelscope.cn/models/damo/cv_resnet18_ocr-detection-line-level_damo/summary)
|
||||
|
||||
* [LaMa图像填充](https://modelscope.cn/models/damo/cv_fft_inpainting_lama/summary)
|
||||
|
||||
|
||||
|
||||
* [LaMa 图像填充](https://modelscope.cn/models/damo/cv_fft_inpainting_lama/summary)
|
||||
|
||||
语音:
|
||||
|
||||
* [Paraformer语音识别-中文-通用-16k-离线-大型-长音频版本](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
|
||||
* [Paraformer 语音识别-中文-通用-16k-离线-大型-长音频版本](https://modelscope.cn/models/damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary)
|
||||
|
||||
* [FSMN声音端点检测-中文-通用-16k-onnx](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx/summary)
|
||||
* [FSMN 声音端点检测-中文-通用-16k-onnx](https://modelscope.cn/models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx/summary)
|
||||
|
||||
* [Monotonic-Aligner语音时间戳预测-16k-离线](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary)
|
||||
* [Monotonic-Aligner 语音时间戳预测-16k-离线](https://modelscope.cn/models/damo/speech_timestamp_prediction-v1-16k-offline/summary)
|
||||
|
||||
* [CT-Transformer标点-中文-通用-onnx](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx/summary)
|
||||
* [CT-Transformer 标点-中文-通用-onnx](https://modelscope.cn/models/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx/summary)
|
||||
|
||||
* [语音合成-中文-多情绪领域-16k-多发言人](https://modelscope.cn/models/damo/speech_sambert-hifigan_tts_zh-cn_16k/summary)
|
||||
|
||||
* [CAM++说话人验证-中文-通用-200k发言人](https://modelscope.cn/models/damo/speech_campplus_sv_zh-cn_16k-common/summary)
|
||||
|
||||
* [CAM++ 说话人验证-中文-通用-200k-发言人](https://modelscope.cn/models/damo/speech_campplus_sv_zh-cn_16k-common/summary)
|
||||
|
||||
科学计算:
|
||||
|
||||
@@ -128,14 +121,15 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
|
||||
|
||||
# 快速上手
|
||||
|
||||
我们针对不同任务提供了统一的使用接口, 使用`pipeline`进行模型推理、使用`Trainer`进行微调和评估。
|
||||
我们针对不同任务提供了统一的使用接口, 使用 `pipeline` 进行模型推理、使用 `Trainer` 进行微调和评估。
|
||||
|
||||
对于任意类型输入(图像、文本、音频、视频...)的任何任务,只需 3 行代码即可加载模型并获得推理结果,如下所示:
|
||||
|
||||
对于任意类型输入(图像、文本、音频、视频...)的任何任务,只需3行代码即可加载模型并获得推理结果,如下所示:
|
||||
```python
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
>>> word_segmentation = pipeline('word-segmentation',model='damo/nlp_structbert_word-segmentation_chinese-base')
|
||||
>>> word_segmentation('今天天气不错,适合出去游玩')
|
||||
{'output': '今天 天气 不错 , 适合 出去 游玩'}
|
||||
>>> word_segmentation = pipeline ('word-segmentation',model='damo/nlp_structbert_word-segmentation_chinese-base')
|
||||
>>> word_segmentation (' 今天天气不错,适合出去游玩 ')
|
||||
{'output': ' 今天 天气 不错 , 适合 出去 游玩 '}
|
||||
```
|
||||
|
||||
给定一张图片,你可以使用如下代码进行人像抠图.
|
||||
@@ -146,42 +140,44 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
|
||||
>>> import cv2
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
|
||||
>>> portrait_matting = pipeline('portrait-matting')
|
||||
>>> result = portrait_matting('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matting.png')
|
||||
>>> cv2.imwrite('result.png', result['output_img'])
|
||||
>>> portrait_matting = pipeline ('portrait-matting')
|
||||
>>> result = portrait_matting ('https://modelscope.oss-cn-beijing.aliyuncs.com/test/images/image_matting.png')
|
||||
>>> cv2.imwrite ('result.png', result ['output_img'])
|
||||
```
|
||||
|
||||
输出图像如下
|
||||

|
||||
|
||||
对于微调和评估模型, 你需要通过十多行代码构建dataset和trainer,调用`trainer.train()`和`trainer.evaluate()`即可。
|
||||
对于微调和评估模型, 你需要通过十多行代码构建 dataset 和 trainer,调用 `trainer.train ()` 和 `trainer.evaluate ()` 即可。
|
||||
|
||||
例如我们利用 gpt3 1.3B 的模型,加载是诗歌数据集进行 finetune,可以完成古诗生成模型的训练。
|
||||
|
||||
例如我们利用gpt3 1.3B的模型,加载是诗歌数据集进行finetune,可以完成古诗生成模型的训练。
|
||||
```python
|
||||
>>> from modelscope.metainfo import Trainers
|
||||
>>> from modelscope.msdatasets import MsDataset
|
||||
>>> from modelscope.trainers import build_trainer
|
||||
|
||||
>>> train_dataset = MsDataset.load('chinese-poetry-collection', split='train'). remap_columns({'text1': 'src_txt'})
|
||||
>>> eval_dataset = MsDataset.load('chinese-poetry-collection', split='test').remap_columns({'text1': 'src_txt'})
|
||||
>>> train_dataset = MsDataset.load ('chinese-poetry-collection', split='train'). remap_columns ({'text1': 'src_txt'})
|
||||
>>> eval_dataset = MsDataset.load ('chinese-poetry-collection', split='test').remap_columns ({'text1': 'src_txt'})
|
||||
>>> max_epochs = 10
|
||||
>>> tmp_dir = './gpt3_poetry'
|
||||
|
||||
>>> kwargs = dict(
|
||||
>>> kwargs = dict (
|
||||
model='damo/nlp_gpt3_text-generation_1.3B',
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
max_epochs=max_epochs,
|
||||
work_dir=tmp_dir)
|
||||
|
||||
>>> trainer = build_trainer(name=Trainers.gpt3_trainer, default_args=kwargs)
|
||||
>>> trainer.train()
|
||||
>>> trainer = build_trainer (name=Trainers.gpt3_trainer, default_args=kwargs)
|
||||
>>> trainer.train ()
|
||||
```
|
||||
|
||||
# 为什么要用ModelScope Library
|
||||
# 为什么要用 ModelScope Library
|
||||
|
||||
1. 针对不同任务、不同模型抽象了统一简洁的用户接口,3行代码完成推理,10行代码完成模型训练,方便用户使用ModelScope社区中多个领域的不同模型,开箱即用,便于AI入门和教学。
|
||||
1. 针对不同任务、不同模型抽象了统一简洁的用户接口,3 行代码完成推理,10 行代码完成模型训练,方便用户使用 ModelScope 社区中多个领域的不同模型,开箱即用,便于 AI 入门和教学。
|
||||
|
||||
2. 构造以模型为中心的开发应用体验,支持模型训练、推理、导出部署,方便用户基于ModelScope Library构建自己的MLOps.
|
||||
2. 构造以模型为中心的开发应用体验,支持模型训练、推理、导出部署,方便用户基于 ModelScope Library 构建自己的 MLOps.
|
||||
|
||||
3. 针对模型推理、训练流程,进行了模块化的设计,并提供了丰富的功能模块实现,方便用户定制化开发来自定义自己的推理、训练等过程。
|
||||
|
||||
@@ -190,11 +186,13 @@ ModelScope开源了数百个(当前700+)模型,涵盖自然语言处理、计
|
||||
# 安装
|
||||
|
||||
## 镜像
|
||||
ModelScope Library目前支持tensorflow,pytorch深度学习框架进行模型训练、推理, 在Python 3.7+, Pytorch 1.8+, Tensorflow1.15/Tensorflow2.0+测试可运行。
|
||||
|
||||
为了让大家能直接用上ModelScope平台上的所有模型,无需配置环境,ModelScope提供了官方镜像,方便有需要的开发者获取。地址如下:
|
||||
ModelScope Library 目前支持 tensorflow,pytorch 深度学习框架进行模型训练、推理, 在 Python 3.7+, Pytorch 1.8+, Tensorflow1.15/Tensorflow2.0 + 测试可运行。
|
||||
|
||||
为了让大家能直接用上 ModelScope 平台上的所有模型,无需配置环境,ModelScope 提供了官方镜像,方便有需要的开发者获取。地址如下:
|
||||
|
||||
CPU 镜像
|
||||
|
||||
CPU镜像
|
||||
```shell
|
||||
# py37
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
@@ -203,7 +201,8 @@ registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py37-to
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
GPU镜像
|
||||
GPU 镜像
|
||||
|
||||
```shell
|
||||
# py37
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.3.0-py37-torch1.11.0-tf1.15.5-1.6.1
|
||||
@@ -212,81 +211,91 @@ registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.
|
||||
registry.cn-hangzhou.aliyuncs.com/modelscope-repo/modelscope:ubuntu20.04-cuda11.8.0-py38-torch2.0.1-tf2.13.0-1.9.5
|
||||
```
|
||||
|
||||
## 搭建本地Python环境
|
||||
## 搭建本地 Python 环境
|
||||
|
||||
你也可以使用 pip 和 conda 搭建本地 python 环境,ModelScope 支持 python3.7 + 以上环境,我们推荐使用 [Anaconda](https://docs.anaconda.com/anaconda/install/),安装完成后,执行如下命令为 modelscope library 创建对应的 python 环境:
|
||||
|
||||
你也可以使用pip和conda搭建本地python环境,ModelScope支持python3.7+以上环境,我们推荐使用[Anaconda](https://docs.anaconda.com/anaconda/install/),安装完成后,执行如下命令为modelscope library创建对应的python环境:
|
||||
```shell
|
||||
conda create -n modelscope python=3.8
|
||||
conda activate modelscope
|
||||
```
|
||||
|
||||
接下来根据所需使用的模型依赖安装底层计算框架
|
||||
* 安装Pytorch [文档链接](https://pytorch.org/get-started/locally/)
|
||||
* 安装tensorflow [文档链接](https://www.tensorflow.org/install/pip)
|
||||
|
||||
* 安装 Pytorch [文档链接](https://pytorch.org/get-started/locally/)
|
||||
* 安装 tensorflow [文档链接](https://www.tensorflow.org/install/pip)
|
||||
|
||||
安装完前置依赖,你可以按照如下方式安装ModelScope Library。
|
||||
安装完前置依赖,你可以按照如下方式安装 ModelScope Library。
|
||||
|
||||
ModelScope Libarary 由核心框架,以及不同领域模型的对接组件组成。如果只需要 ModelScope 模型和数据集访问等基础能力,可以只安装 ModelScope 的核心框架:
|
||||
|
||||
ModelScope Libarary由核心框架,以及不同领域模型的对接组件组成。如果只需要ModelScope模型和数据集访问等基础能力,可以只安装ModelScope的核心框架:
|
||||
```shell
|
||||
pip install modelscope
|
||||
```
|
||||
|
||||
如仅需体验多模态领域的模型,可执行如下命令安装领域依赖:
|
||||
|
||||
```shell
|
||||
pip install modelscope[multi-modal]
|
||||
pip install modelscope [multi-modal]
|
||||
```
|
||||
|
||||
如仅需体验NLP领域模型,可执行如下命令安装领域依赖(因部分依赖由ModelScope独立host,所以需要使用"-f"参数):
|
||||
如仅需体验 NLP 领域模型,可执行如下命令安装领域依赖(因部分依赖由 ModelScope 独立 host,所以需要使用 "-f" 参数):
|
||||
|
||||
```shell
|
||||
pip install modelscope[nlp] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install modelscope [nlp] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
```
|
||||
|
||||
如仅需体验计算机视觉领域的模型,可执行如下命令安装领域依赖(因部分依赖由ModelScope独立host,所以需要使用"-f"参数):
|
||||
如仅需体验计算机视觉领域的模型,可执行如下命令安装领域依赖(因部分依赖由 ModelScope 独立 host,所以需要使用 "-f" 参数):
|
||||
|
||||
```shell
|
||||
pip install modelscope[cv] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install modelscope [cv] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
```
|
||||
|
||||
如仅需体验语音领域模型,可执行如下命令安装领域依赖(因部分依赖由ModelScope独立host,所以需要使用"-f"参数):
|
||||
如仅需体验语音领域模型,可执行如下命令安装领域依赖(因部分依赖由 ModelScope 独立 host,所以需要使用 "-f" 参数):
|
||||
|
||||
```shell
|
||||
pip install modelscope[audio] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install modelscope [audio] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
```
|
||||
|
||||
`注意`:当前大部分语音模型需要在Linux环境上使用,并且推荐使用python3.7 + tensorflow 1.x的组合。
|
||||
`注意`:当前大部分语音模型需要在 Linux 环境上使用,并且推荐使用 python3.7 + tensorflow 1.x 的组合。
|
||||
|
||||
如仅需体验科学计算领域模型,可执行如下命令安装领域依赖(因部分依赖由 ModelScope 独立 host,所以需要使用 "-f" 参数):
|
||||
|
||||
如仅需体验科学计算领域模型,可执行如下命令安装领域依赖(因部分依赖由ModelScope独立host,所以需要使用"-f"参数):
|
||||
```shell
|
||||
pip install modelscope[science] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
pip install modelscope [science] -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html
|
||||
```
|
||||
|
||||
`注`:
|
||||
1. 目前部分语音相关的模型仅支持 python3.7,tensorflow1.15.4的Linux环境使用。 其他绝大部分模型可以在windows、mac(x86)上安装使用。.
|
||||
`注意`:
|
||||
|
||||
1. 目前部分语音相关的模型仅支持 python3.7,tensorflow1.15.4 的 Linux 环境使用。 其他绝大部分模型可以在 windows、mac(x86)上安装使用。
|
||||
|
||||
2. 语音领域中一部分模型使用了三方库 SoundFile 进行 wav 文件处理,在 Linux 系统上用户需要手动安装 SoundFile 的底层依赖库 libsndfile,在 Windows 和 MacOS 上会自动安装不需要用户操作。详细信息可参考 [SoundFile 官网](https://github.com/bastibe/python-soundfile#installation)。以 Ubuntu 系统为例,用户需要执行如下命令:
|
||||
|
||||
2. 语音领域中一部分模型使用了三方库SoundFile进行wav文件处理,在Linux系统上用户需要手动安装SoundFile的底层依赖库libsndfile,在Windows和MacOS上会自动安装不需要用户操作。详细信息可参考[SoundFile 官网](https://github.com/bastibe/python-soundfile#installation)。以Ubuntu系统为例,用户需要执行如下命令:
|
||||
```shell
|
||||
sudo apt-get update
|
||||
sudo apt-get install libsndfile1
|
||||
```
|
||||
|
||||
3. CV领域的少数模型,需要安装mmcv-full, 如果运行过程中提示缺少mmcv,请参考mmcv[安装手册](https://github.com/open-mmlab/mmcv#installation)进行安装。 这里提供一个最简版的mmcv-full安装步骤,但是要达到最优的mmcv-full的安装效果(包括对于cuda版本的兼容),请根据自己的实际机器环境,以mmcv官方安装手册为准。
|
||||
3. CV 领域的少数模型,需要安装 mmcv-full, 如果运行过程中提示缺少 mmcv,请参考 mmcv [安装手册](https://github.com/open-mmlab/mmcv#installation) 进行安装。 这里提供一个最简版的 mmcv-full 安装步骤,但是要达到最优的 mmcv-full 的安装效果(包括对于 cuda 版本的兼容),请根据自己的实际机器环境,以 mmcv 官方安装手册为准。
|
||||
|
||||
```shell
|
||||
pip uninstall mmcv # if you have installed mmcv, uninstall it
|
||||
pip install -U openmim
|
||||
mim install mmcv-full
|
||||
```
|
||||
|
||||
|
||||
# 更多教程
|
||||
|
||||
除了上述内容,我们还提供如下信息:
|
||||
|
||||
* [更加详细的安装文档](https://modelscope.cn/docs/%E7%8E%AF%E5%A2%83%E5%AE%89%E8%A3%85)
|
||||
* [任务的介绍](https://modelscope.cn/docs/%E4%BB%BB%E5%8A%A1%E7%9A%84%E4%BB%8B%E7%BB%8D)
|
||||
* [模型推理](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E6%8E%A8%E7%90%86Pipeline)
|
||||
* [模型微调](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AE%AD%E7%BB%83Train)
|
||||
* [数据预处理](https://modelscope.cn/docs/%E6%95%B0%E6%8D%AE%E7%9A%84%E9%A2%84%E5%A4%84%E7%90%86)
|
||||
* [模型评估](https://modelscope.cn/docs/%E6%A8%A1%E5%9E%8B%E7%9A%84%E8%AF%84%E4%BC%B0)
|
||||
* [贡献模型到ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88)
|
||||
* [贡献模型到 ModelScope](https://modelscope.cn/docs/ModelScope%E6%A8%A1%E5%9E%8B%E6%8E%A5%E5%85%A5%E6%B5%81%E7%A8%8B%E6%A6%82%E8%A7%88)
|
||||
|
||||
# License
|
||||
|
||||
本项目使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
|
||||
本项目使用 [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE).
|
||||
|
||||
Submodule data/test updated: 7a7f6b8d05...dedb3ce447
@@ -196,10 +196,10 @@ def _repo_file_download(
|
||||
if cookies is None:
|
||||
cookies = ModelScopeConfig.get_cookies()
|
||||
repo_files = []
|
||||
file_to_download_meta = None
|
||||
if repo_type == REPO_TYPE_MODEL:
|
||||
revision = _api.get_valid_revision(
|
||||
repo_id, revision=revision, cookies=cookies)
|
||||
file_to_download_meta = None
|
||||
# we need to confirm the version is up-to-date
|
||||
# we need to get the file list to check if the latest version is cached, if so return, otherwise download
|
||||
repo_files = _api.get_model_files(
|
||||
@@ -207,38 +207,60 @@ def _repo_file_download(
|
||||
revision=revision,
|
||||
recursive=True,
|
||||
use_cookies=False if cookies is None else cookies)
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
if repo_file['Path'] == file_path:
|
||||
if cache.exists(repo_file):
|
||||
logger.debug(
|
||||
f'File {repo_file["Name"]} already in cache, skip downloading!'
|
||||
)
|
||||
return cache.get_file_by_info(repo_file)
|
||||
else:
|
||||
file_to_download_meta = repo_file
|
||||
break
|
||||
elif repo_type == REPO_TYPE_DATASET:
|
||||
group_or_owner, name = model_id_to_group_owner_name(repo_id)
|
||||
if not revision:
|
||||
revision = DEFAULT_DATASET_REVISION
|
||||
files_list_tree = _api.list_repo_tree(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision,
|
||||
root_path='/',
|
||||
recursive=True)
|
||||
if not ('Code' in files_list_tree and files_list_tree['Code'] == 200):
|
||||
print(
|
||||
'Get dataset: %s file list failed, request_id: %s, message: %s'
|
||||
% (repo_id, files_list_tree['RequestId'],
|
||||
files_list_tree['Message']))
|
||||
return None
|
||||
repo_files = files_list_tree['Data']['Files']
|
||||
page_number = 1
|
||||
page_size = 100
|
||||
while True:
|
||||
files_list_tree = _api.list_repo_tree(
|
||||
dataset_name=name,
|
||||
namespace=group_or_owner,
|
||||
revision=revision,
|
||||
root_path='/',
|
||||
recursive=True,
|
||||
page_number=page_number,
|
||||
page_size=page_size)
|
||||
if not ('Code' in files_list_tree
|
||||
and files_list_tree['Code'] == 200):
|
||||
print(
|
||||
'Get dataset: %s file list failed, request_id: %s, message: %s'
|
||||
% (repo_id, files_list_tree['RequestId'],
|
||||
files_list_tree['Message']))
|
||||
return None
|
||||
repo_files = files_list_tree['Data']['Files']
|
||||
is_exist = False
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
file_to_download_meta = None
|
||||
for repo_file in repo_files:
|
||||
if repo_file['Type'] == 'tree':
|
||||
continue
|
||||
|
||||
if repo_file['Path'] == file_path:
|
||||
if cache.exists(repo_file):
|
||||
logger.debug(
|
||||
f'File {repo_file["Name"]} already in cache, skip downloading!'
|
||||
)
|
||||
return cache.get_file_by_info(repo_file)
|
||||
else:
|
||||
file_to_download_meta = repo_file
|
||||
break
|
||||
if repo_file['Path'] == file_path:
|
||||
if cache.exists(repo_file):
|
||||
logger.debug(
|
||||
f'File {repo_file["Name"]} already in cache, skip downloading!'
|
||||
)
|
||||
return cache.get_file_by_info(repo_file)
|
||||
else:
|
||||
file_to_download_meta = repo_file
|
||||
is_exist = True
|
||||
break
|
||||
if len(repo_files) < page_size or is_exist:
|
||||
break
|
||||
page_number += 1
|
||||
|
||||
if file_to_download_meta is None:
|
||||
raise NotExistError('The file path: %s not exist in: %s' %
|
||||
|
||||
@@ -81,6 +81,10 @@ def snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
if allow_patterns:
|
||||
allow_file_pattern = allow_patterns
|
||||
if ignore_patterns:
|
||||
ignore_file_pattern = ignore_patterns
|
||||
return _snapshot_download(
|
||||
model_id,
|
||||
repo_type=REPO_TYPE_MODEL,
|
||||
@@ -155,6 +159,10 @@ def dataset_snapshot_download(
|
||||
- [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
|
||||
if some parameter value is invalid
|
||||
"""
|
||||
if allow_patterns:
|
||||
allow_file_pattern = allow_patterns
|
||||
if ignore_patterns:
|
||||
ignore_file_pattern = ignore_patterns
|
||||
return _snapshot_download(
|
||||
dataset_id,
|
||||
repo_type=REPO_TYPE_DATASET,
|
||||
|
||||
@@ -58,6 +58,7 @@ class Models(object):
|
||||
s2net_depth_estimation = 's2net-depth-estimation'
|
||||
dro_resnet18_depth_estimation = 'dro-resnet18-depth-estimation'
|
||||
raft_dense_optical_flow_estimation = 'raft-dense-optical-flow-estimation'
|
||||
human_normal_estimation = 'human-normal-estimation'
|
||||
resnet50_bert = 'resnet50-bert'
|
||||
referring_video_object_segmentation = 'swinT-referring-video-object-segmentation'
|
||||
fer = 'fer'
|
||||
@@ -480,6 +481,7 @@ class Pipelines(object):
|
||||
anydoor = 'anydoor'
|
||||
image_to_3d = 'image-to-3d'
|
||||
self_supervised_depth_completion = 'self-supervised-depth-completion'
|
||||
human_normal_estimation = 'human-normal-estimation'
|
||||
|
||||
# nlp tasks
|
||||
automatic_post_editing = 'automatic-post-editing'
|
||||
@@ -814,6 +816,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.image_normal_estimation:
|
||||
(Pipelines.image_normal_estimation,
|
||||
'Damo_XR_Lab/cv_omnidata_image-normal-estimation_normal'),
|
||||
Tasks.human_normal_estimation:
|
||||
(Pipelines.human_normal_estimation,
|
||||
'Damo_XR_Lab/cv_human_monocular-normal-estimation'),
|
||||
Tasks.indoor_layout_estimation:
|
||||
(Pipelines.indoor_layout_estimation,
|
||||
'damo/cv_panovit_indoor-layout-estimation'),
|
||||
@@ -846,9 +851,9 @@ DEFAULT_MODEL_FOR_PIPELINE = {
|
||||
Tasks.image_to_image_generation:
|
||||
(Pipelines.image_to_image_generation,
|
||||
'damo/cv_latent_diffusion_image2image_generate'),
|
||||
Tasks.image_classification:
|
||||
(Pipelines.daily_image_classification,
|
||||
'damo/cv_vit-base_image-classification_Dailylife-labels'),
|
||||
Tasks.image_classification: (
|
||||
Pipelines.daily_image_classification,
|
||||
'damo/cv_vit-base_image-classification_Dailylife-labels'),
|
||||
Tasks.image_object_detection: (
|
||||
Pipelines.image_object_detection_auto,
|
||||
'damo/cv_yolox_image-object-detection-auto'),
|
||||
|
||||
22
modelscope/models/cv/human_normal_estimation/__init__.py
Normal file
22
modelscope/models/cv/human_normal_estimation/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .human_nnet import HumanNormalEstimation
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'human_nnet': ['HumanNormalEstimation'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
80
modelscope/models/cv/human_normal_estimation/human_nnet.py
Normal file
80
modelscope/models/cv/human_normal_estimation/human_nnet.py
Normal file
@@ -0,0 +1,80 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchvision.transforms as T
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base.base_torch_model import TorchModel
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.models.cv.human_normal_estimation.networks import config, nnet
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.utils.constant import Tasks
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.human_normal_estimation, module_name=Models.human_normal_estimation)
|
||||
class HumanNormalEstimation(TorchModel):
|
||||
|
||||
def __init__(self, model_dir: str, **kwargs):
|
||||
super().__init__(model_dir, **kwargs)
|
||||
config_file = os.path.join(model_dir, 'config.txt')
|
||||
args = config.get_args(txt_file=config_file)
|
||||
args.encoder_path = os.path.join(model_dir, args.encoder_path)
|
||||
|
||||
self.device = torch.device(
|
||||
'cuda:0') if torch.cuda.is_available() else torch.device('cpu')
|
||||
self.nnet = nnet.NormalNet(args=args).to(self.device)
|
||||
self.nnet_path = os.path.join(model_dir, 'ckpt/best_nnet.pt')
|
||||
if os.path.exists(self.nnet_path):
|
||||
ckpt = torch.load(
|
||||
self.nnet_path, map_location=self.device)['model']
|
||||
load_dict = {}
|
||||
for k, v in ckpt.items():
|
||||
if k.startswith('module.'):
|
||||
k_ = k.replace('module.', '')
|
||||
load_dict[k_] = v
|
||||
else:
|
||||
load_dict[k] = v
|
||||
self.nnet.load_state_dict(load_dict)
|
||||
self.nnet.eval()
|
||||
|
||||
self.normalize = T.Normalize(
|
||||
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
||||
|
||||
def forward(self, inputs):
|
||||
img = inputs['img'].astype(np.float32) / 255.0
|
||||
msk = inputs['msk'].astype(np.float32) / 255.0
|
||||
bbox = inputs['bbox']
|
||||
|
||||
img_h, img_w = img.shape[0:2]
|
||||
img = torch.from_numpy(img).permute(2, 0,
|
||||
1).unsqueeze(0).to(self.device)
|
||||
img = self.normalize(img)
|
||||
|
||||
fx = fy = (max(img_h, img_h) / 2.0) / np.tan(np.deg2rad(60.0 / 2.0))
|
||||
cx = (img_h / 2.0) - 0.5
|
||||
cy = (img_w / 2.0) - 0.5
|
||||
|
||||
intrins = torch.tensor(
|
||||
[[fx, 0, cx + 0.5], [0, fy, cy + 0.5], [0, 0, 1]],
|
||||
dtype=torch.float32,
|
||||
device=self.device).unsqueeze(0)
|
||||
|
||||
pred_norm = self.nnet(img, intrins=intrins)[-1]
|
||||
pred_norm = pred_norm.detach().cpu().permute(0, 2, 3, 1).numpy()
|
||||
pred_norm = pred_norm[0, ...]
|
||||
pred_norm = pred_norm * msk[..., None]
|
||||
pred_norm = pred_norm[bbox[1]:bbox[3], bbox[0]:bbox[2]]
|
||||
results = pred_norm
|
||||
return results
|
||||
|
||||
def postprocess(self, inputs):
|
||||
normal_result = inputs
|
||||
results = {OutputKeys.NORMALS: normal_result}
|
||||
return results
|
||||
|
||||
def inference(self, data):
|
||||
results = self.forward(data)
|
||||
return results
|
||||
@@ -0,0 +1,40 @@
|
||||
import argparse
|
||||
|
||||
|
||||
def convert_arg_line_to_args(arg_line):
|
||||
for arg in arg_line.split():
|
||||
if not arg.strip():
|
||||
continue
|
||||
yield str(arg)
|
||||
|
||||
|
||||
def get_args(txt_file=None):
|
||||
parser = argparse.ArgumentParser(
|
||||
fromfile_prefix_chars='@', conflict_handler='resolve')
|
||||
parser.convert_arg_line_to_args = convert_arg_line_to_args
|
||||
|
||||
# checkpoint (only needed when testing the model)
|
||||
parser.add_argument('--ckpt_path', type=str, default=None)
|
||||
parser.add_argument('--encoder_path', type=str, default=None)
|
||||
|
||||
# ↓↓↓↓
|
||||
# NOTE: project-specific args
|
||||
parser.add_argument('--output_dim', type=int, default=3, help='{3, 4}')
|
||||
parser.add_argument('--output_type', type=str, default='R', help='{R, G}')
|
||||
parser.add_argument('--feature_dim', type=int, default=64)
|
||||
parser.add_argument('--hidden_dim', type=int, default=64)
|
||||
|
||||
parser.add_argument('--encoder_B', type=int, default=5)
|
||||
|
||||
parser.add_argument('--decoder_NF', type=int, default=2048)
|
||||
parser.add_argument('--decoder_BN', default=False, action='store_true')
|
||||
parser.add_argument('--decoder_down', type=int, default=2)
|
||||
parser.add_argument(
|
||||
'--learned_upsampling', default=False, action='store_true')
|
||||
|
||||
# read arguments from txt file
|
||||
if txt_file:
|
||||
config_filename = '@' + txt_file
|
||||
|
||||
args = parser.parse_args([config_filename])
|
||||
return args
|
||||
125
modelscope/models/cv/human_normal_estimation/networks/nnet.py
Normal file
125
modelscope/models/cv/human_normal_estimation/networks/nnet.py
Normal file
@@ -0,0 +1,125 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .submodules import (Encoder, UpSampleBN, UpSampleGN, get_pixel_coords,
|
||||
get_prediction_head, normal_activation,
|
||||
upsample_via_bilinear, upsample_via_mask)
|
||||
|
||||
PROJECT_DIR = os.path.split(os.path.dirname(os.path.realpath(__file__)))[0]
|
||||
sys.path.append(PROJECT_DIR)
|
||||
|
||||
|
||||
class NormalNet(nn.Module):
|
||||
|
||||
def __init__(self, args):
|
||||
super(NormalNet, self).__init__()
|
||||
B = args.encoder_B
|
||||
NF = args.decoder_NF
|
||||
BN = args.decoder_BN
|
||||
learned_upsampling = args.learned_upsampling
|
||||
|
||||
self.encoder = Encoder(B=B, pretrained=False, ckpt=args.encoder_path)
|
||||
self.decoder = Decoder(
|
||||
num_classes=args.output_dim,
|
||||
B=B,
|
||||
NF=NF,
|
||||
BN=BN,
|
||||
learned_upsampling=learned_upsampling)
|
||||
|
||||
def forward(self, x, **kwargs):
|
||||
return self.decoder(self.encoder(x), **kwargs)
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
num_classes=3,
|
||||
B=5,
|
||||
NF=2048,
|
||||
BN=False,
|
||||
learned_upsampling=True):
|
||||
super(Decoder, self).__init__()
|
||||
input_channels = [2048, 176, 64, 40, 24]
|
||||
|
||||
UpSample = UpSampleBN if BN else UpSampleGN
|
||||
features = NF
|
||||
|
||||
self.conv2 = nn.Conv2d(
|
||||
input_channels[0] + 2,
|
||||
features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0)
|
||||
self.up1 = UpSample(
|
||||
skip_input=features // 1 + input_channels[1] + 2,
|
||||
output_features=features // 2,
|
||||
align_corners=False)
|
||||
self.up2 = UpSample(
|
||||
skip_input=features // 2 + input_channels[2] + 2,
|
||||
output_features=features // 4,
|
||||
align_corners=False)
|
||||
self.up3 = UpSample(
|
||||
skip_input=features // 4 + input_channels[3] + 2,
|
||||
output_features=features // 8,
|
||||
align_corners=False)
|
||||
self.up4 = UpSample(
|
||||
skip_input=features // 8 + input_channels[4] + 2,
|
||||
output_features=features // 16,
|
||||
align_corners=False)
|
||||
i_dim = features // 16
|
||||
|
||||
self.downsample_ratio = 2
|
||||
self.output_dim = num_classes
|
||||
|
||||
self.pred_head = get_prediction_head(i_dim + 2, 128, num_classes)
|
||||
if learned_upsampling:
|
||||
self.mask_head = get_prediction_head(
|
||||
i_dim + 2, 128,
|
||||
9 * self.downsample_ratio * self.downsample_ratio)
|
||||
self.upsample_fn = upsample_via_mask
|
||||
else:
|
||||
self.mask_head = lambda a: None
|
||||
self.upsample_fn = upsample_via_bilinear
|
||||
|
||||
self.pixel_coords = get_pixel_coords(h=1024, w=1024).to(0)
|
||||
|
||||
def ray_embedding(self, x, intrins, orig_H, orig_W):
|
||||
B, _, H, W = x.shape
|
||||
fu = intrins[:, 0, 0].unsqueeze(-1).unsqueeze(-1) * (W / orig_W)
|
||||
cu = intrins[:, 0, 2].unsqueeze(-1).unsqueeze(-1) * (W / orig_W)
|
||||
fv = intrins[:, 1, 1].unsqueeze(-1).unsqueeze(-1) * (H / orig_H)
|
||||
cv = intrins[:, 1, 2].unsqueeze(-1).unsqueeze(-1) * (H / orig_H)
|
||||
|
||||
uv = self.pixel_coords[:, :2, :H, :W].repeat(B, 1, 1, 1)
|
||||
uv[:, 0, :, :] = (uv[:, 0, :, :] - cu) / fu
|
||||
uv[:, 1, :, :] = (uv[:, 1, :, :] - cv) / fv
|
||||
return torch.cat([x, uv], dim=1)
|
||||
|
||||
def forward(self, features, intrins):
|
||||
x_block0, x_block1, x_block2, x_block3, x_block4 = features[4], features[5], features[6], \
|
||||
features[8], features[11]
|
||||
_, _, orig_H, orig_W = features[0].shape
|
||||
|
||||
x_d0 = self.conv2(
|
||||
self.ray_embedding(x_block4, intrins, orig_H, orig_W))
|
||||
x_d1 = self.up1(x_d0,
|
||||
self.ray_embedding(x_block3, intrins, orig_H, orig_W))
|
||||
x_d2 = self.up2(x_d1,
|
||||
self.ray_embedding(x_block2, intrins, orig_H, orig_W))
|
||||
x_d3 = self.up3(x_d2,
|
||||
self.ray_embedding(x_block1, intrins, orig_H, orig_W))
|
||||
x_feat = self.up4(
|
||||
x_d3, self.ray_embedding(x_block0, intrins, orig_H, orig_W))
|
||||
|
||||
out = self.pred_head(
|
||||
self.ray_embedding(x_feat, intrins, orig_H, orig_W))
|
||||
out = normal_activation(out, elu_kappa=True)
|
||||
mask = self.mask_head(
|
||||
self.ray_embedding(x_feat, intrins, orig_H, orig_W))
|
||||
up_out = self.upsample_fn(
|
||||
out, up_mask=mask, downsample_ratio=self.downsample_ratio)
|
||||
up_out = normal_activation(up_out, elu_kappa=False)
|
||||
return [up_out]
|
||||
@@ -0,0 +1,214 @@
|
||||
import geffnet
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
INPUT_CHANNELS_DICT = {
|
||||
0: [1280, 112, 40, 24, 16],
|
||||
1: [1280, 112, 40, 24, 16],
|
||||
2: [1408, 120, 48, 24, 16],
|
||||
3: [1536, 136, 48, 32, 24],
|
||||
4: [1792, 160, 56, 32, 24],
|
||||
5: [2048, 176, 64, 40, 24],
|
||||
6: [2304, 200, 72, 40, 32],
|
||||
7: [2560, 224, 80, 48, 32]
|
||||
}
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
|
||||
def __init__(self, B=5, pretrained=True, ckpt=None):
|
||||
super(Encoder, self).__init__()
|
||||
if ckpt:
|
||||
basemodel = geffnet.create_model(
|
||||
'tf_efficientnet_b%s_ap' % B,
|
||||
pretrained=pretrained,
|
||||
checkpoint_path=ckpt)
|
||||
else:
|
||||
basemodel = geffnet.create_model(
|
||||
'tf_efficientnet_b%s_ap' % B, pretrained=pretrained)
|
||||
|
||||
basemodel.global_pool = nn.Identity()
|
||||
basemodel.classifier = nn.Identity()
|
||||
self.original_model = basemodel
|
||||
|
||||
def forward(self, x):
|
||||
features = [x]
|
||||
for k, v in self.original_model._modules.items():
|
||||
if k == 'blocks':
|
||||
for ki, vi in v._modules.items():
|
||||
features.append(vi(features[-1]))
|
||||
else:
|
||||
features.append(v(features[-1]))
|
||||
return features
|
||||
|
||||
|
||||
class ConvGRU(nn.Module):
|
||||
|
||||
def __init__(self, hidden_dim, input_dim, ks=3):
|
||||
super().__init__()
|
||||
p = (ks - 1) // 2
|
||||
self.convz = nn.Conv2d(
|
||||
hidden_dim + input_dim, hidden_dim, ks, padding=p)
|
||||
self.convr = nn.Conv2d(
|
||||
hidden_dim + input_dim, hidden_dim, ks, padding=p)
|
||||
self.convq = nn.Conv2d(
|
||||
hidden_dim + input_dim, hidden_dim, ks, padding=p)
|
||||
|
||||
def forward(self, h, x):
|
||||
hx = torch.cat([h, x], dim=1)
|
||||
z = torch.sigmoid(self.convz(hx))
|
||||
r = torch.sigmoid(self.convr(hx))
|
||||
q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1)))
|
||||
h = (1 - z) * h + z * q
|
||||
return h
|
||||
|
||||
|
||||
class UpSampleBN(nn.Module):
|
||||
|
||||
def __init__(self, skip_input, output_features, align_corners=True):
|
||||
super(UpSampleBN, self).__init__()
|
||||
self._net = nn.Sequential(
|
||||
nn.Conv2d(
|
||||
skip_input,
|
||||
output_features,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU(),
|
||||
nn.Conv2d(
|
||||
output_features,
|
||||
output_features,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1), nn.BatchNorm2d(output_features), nn.LeakyReLU())
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x, concat_with):
|
||||
up_x = F.interpolate(
|
||||
x,
|
||||
size=[concat_with.size(2),
|
||||
concat_with.size(3)],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
f = torch.cat([up_x, concat_with], dim=1)
|
||||
return self._net(f)
|
||||
|
||||
|
||||
class Conv2d_WS(nn.Conv2d):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True):
|
||||
super(Conv2d_WS,
|
||||
self).__init__(in_channels, out_channels, kernel_size, stride,
|
||||
padding, dilation, groups, bias)
|
||||
|
||||
def forward(self, x):
|
||||
weight = self.weight
|
||||
weight_mean = weight.mean(
|
||||
dim=1, keepdim=True).mean(
|
||||
dim=2, keepdim=True).mean(
|
||||
dim=3, keepdim=True)
|
||||
weight = weight - weight_mean
|
||||
std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1,
|
||||
1) + 1e-5
|
||||
weight = weight / std.expand_as(weight)
|
||||
return F.conv2d(x, weight, self.bias, self.stride, self.padding,
|
||||
self.dilation, self.groups)
|
||||
|
||||
|
||||
class UpSampleGN(nn.Module):
|
||||
|
||||
def __init__(self, skip_input, output_features, align_corners=True):
|
||||
super(UpSampleGN, self).__init__()
|
||||
self._net = nn.Sequential(
|
||||
Conv2d_WS(
|
||||
skip_input,
|
||||
output_features,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU(),
|
||||
Conv2d_WS(
|
||||
output_features,
|
||||
output_features,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1), nn.GroupNorm(8, output_features), nn.LeakyReLU())
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, x, concat_with):
|
||||
up_x = F.interpolate(
|
||||
x,
|
||||
size=[concat_with.size(2),
|
||||
concat_with.size(3)],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
f = torch.cat([up_x, concat_with], dim=1)
|
||||
return self._net(f)
|
||||
|
||||
|
||||
def upsample_via_bilinear(out, up_mask=None, downsample_ratio=None):
|
||||
return F.interpolate(
|
||||
out,
|
||||
scale_factor=downsample_ratio,
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
|
||||
|
||||
def upsample_via_mask(out, up_mask, downsample_ratio, padding='zero'):
|
||||
"""
|
||||
convex upsampling
|
||||
"""
|
||||
# out: low-resolution output (B, o_dim, H, W)
|
||||
# up_mask: (B, 9*k*k, H, W)
|
||||
k = downsample_ratio
|
||||
|
||||
B, C, H, W = out.shape
|
||||
up_mask = up_mask.view(B, 1, 9, k, k, H, W)
|
||||
up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W)
|
||||
|
||||
if padding == 'zero':
|
||||
up_out = F.unfold(out, [3, 3], padding=1)
|
||||
elif padding == 'replicate':
|
||||
out = F.pad(out, pad=(1, 1, 1, 1), mode='replicate')
|
||||
up_out = F.unfold(out, [3, 3], padding=0)
|
||||
else:
|
||||
raise Exception('invalid padding for convex upsampling')
|
||||
|
||||
up_out = up_out.view(B, C, 9, 1, 1, H, W)
|
||||
|
||||
up_out = torch.sum(up_mask * up_out, dim=2)
|
||||
up_out = up_out.permute(0, 1, 4, 2, 5, 3)
|
||||
return up_out.reshape(B, C, k * H, k * W)
|
||||
|
||||
|
||||
def get_prediction_head(input_dim, hidden_dim, output_dim):
|
||||
return nn.Sequential(
|
||||
nn.Conv2d(input_dim, hidden_dim, 3, padding=1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(hidden_dim, hidden_dim, 1), nn.ReLU(inplace=True),
|
||||
nn.Conv2d(hidden_dim, output_dim, 1))
|
||||
|
||||
|
||||
# submodules copy from DSINE
|
||||
def get_pixel_coords(h, w):
|
||||
pixel_coords = np.ones((3, h, w)).astype(np.float32)
|
||||
x_range = np.concatenate([np.arange(w).reshape(1, w)] * h, axis=0)
|
||||
y_range = np.concatenate([np.arange(h).reshape(h, 1)] * w, axis=1)
|
||||
pixel_coords[0, :, :] = x_range + 0.5
|
||||
pixel_coords[1, :, :] = y_range + 0.5
|
||||
return torch.from_numpy(pixel_coords).unsqueeze(0)
|
||||
|
||||
|
||||
def normal_activation(out, elu_kappa=True):
|
||||
normal, kappa = out[:, :3, :, :], out[:, 3:, :, :]
|
||||
normal = F.normalize(normal, p=2, dim=1)
|
||||
if elu_kappa:
|
||||
kappa = F.elu(kappa) + 1.0
|
||||
return torch.cat([normal, kappa], dim=1)
|
||||
@@ -123,6 +123,7 @@ if TYPE_CHECKING:
|
||||
from .anydoor_pipeline import AnydoorPipeline
|
||||
from .image_depth_estimation_marigold_pipeline import ImageDepthEstimationMarigoldPipeline
|
||||
from .self_supervised_depth_completion_pipeline import SelfSupervisedDepthCompletionPipeline
|
||||
from .human_normal_estimation_pipeline import HumanNormalEstimationPipeline
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
@@ -312,6 +313,7 @@ else:
|
||||
'self_supervised_depth_completion_pipeline': [
|
||||
'SelfSupervisedDepthCompletionPipeline'
|
||||
],
|
||||
'human_normal_estimation_pipeline': ['HumanNormalEstimationPipeline'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
95
modelscope/pipelines/cv/human_normal_estimation_pipeline.py
Normal file
95
modelscope/pipelines/cv/human_normal_estimation_pipeline.py
Normal file
@@ -0,0 +1,95 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import Any, Dict
|
||||
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
from modelscope.metainfo import Pipelines
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines.base import Input, Pipeline
|
||||
from modelscope.pipelines.builder import PIPELINES
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@PIPELINES.register_module(
|
||||
Tasks.human_normal_estimation,
|
||||
module_name=Pipelines.human_normal_estimation)
|
||||
class HumanNormalEstimationPipeline(Pipeline):
|
||||
r""" Human Normal Estimation Pipeline.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> from modelscope.pipelines import pipeline
|
||||
|
||||
>>> estimator = pipeline(
|
||||
>>> Tasks.human_normal_estimation, model='Damo_XR_Lab/cv_human_monocular-normal-estimation')
|
||||
>>> estimator(f"{model_dir}/tests/image_normal_estimation.jpg")
|
||||
"""
|
||||
|
||||
def __init__(self, model: str, **kwargs):
|
||||
"""
|
||||
use `model` to create a image normal estimation pipeline for prediction
|
||||
Args:
|
||||
model: model id on modelscope hub.
|
||||
"""
|
||||
super().__init__(model=model, **kwargs)
|
||||
logger.info('normal estimation model, pipeline init')
|
||||
|
||||
def preprocess(self, input: Input) -> Dict[str, Any]:
|
||||
"""
|
||||
|
||||
Args:
|
||||
input: string or ndarray or Image.Image
|
||||
|
||||
Returns:
|
||||
data: dict including inference inputs
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
img = np.array(Image.open(input))
|
||||
if isinstance(input, Image.Image):
|
||||
img = np.array(input)
|
||||
|
||||
img_h, img_w, img_ch = img.shape[0:3]
|
||||
|
||||
if img_ch == 3:
|
||||
msk = np.full((img_h, img_w, 1), 255, dtype=np.uint8)
|
||||
img = np.concatenate((img, msk), axis=-1)
|
||||
|
||||
H, W = 1024, 1024
|
||||
scale_factor = min(W / img_w, H / img_h)
|
||||
img = Image.fromarray(img)
|
||||
img = img.resize(
|
||||
(int(img_w * scale_factor), int(img_h * scale_factor)),
|
||||
Image.LANCZOS)
|
||||
|
||||
new_img = Image.new('RGBA', (W, H), color=(0, 0, 0, 0))
|
||||
paste_pos_w = (W - img.width) // 2
|
||||
paste_pos_h = (H - img.height) // 2
|
||||
new_img.paste(img, (paste_pos_w, paste_pos_h))
|
||||
|
||||
bbox = (paste_pos_w, paste_pos_h, paste_pos_w + img.width,
|
||||
paste_pos_h + img.height)
|
||||
img = np.array(new_img)
|
||||
|
||||
data = {'img': img[:, :, 0:3], 'msk': img[:, :, -1], 'bbox': bbox}
|
||||
|
||||
return data
|
||||
|
||||
def forward(self, input: Dict[str, Any]) -> Dict[str, Any]:
|
||||
results = self.model.inference(input)
|
||||
return results
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
results = self.model.postprocess(inputs)
|
||||
normals = results[OutputKeys.NORMALS]
|
||||
|
||||
normals_vis = (((normals + 1) * 0.5) * 255).astype(np.uint8)
|
||||
normals_vis = normals_vis[..., [2, 1, 0]]
|
||||
outputs = {
|
||||
OutputKeys.NORMALS: normals,
|
||||
OutputKeys.NORMALS_COLOR: normals_vis
|
||||
}
|
||||
return outputs
|
||||
@@ -78,6 +78,8 @@ class CVTasks(object):
|
||||
image_local_feature_matching = 'image-local-feature-matching'
|
||||
image_quality_assessment_degradation = 'image-quality-assessment-degradation'
|
||||
|
||||
human_normal_estimation = 'human-normal-estimation'
|
||||
|
||||
crowd_counting = 'crowd-counting'
|
||||
|
||||
# image editing
|
||||
|
||||
@@ -1179,6 +1179,13 @@
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
"human-normal-estimation": {
|
||||
"input": {},
|
||||
"parameters": {},
|
||||
"output": {
|
||||
"type": "object"
|
||||
}
|
||||
},
|
||||
"image-driving-perception": {
|
||||
"input": {
|
||||
"type": "object",
|
||||
|
||||
37
tests/pipelines/test_human_normal_estimation.py
Normal file
37
tests/pipelines/test_human_normal_estimation.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import os.path
|
||||
import unittest
|
||||
|
||||
import cv2
|
||||
|
||||
from modelscope.outputs import OutputKeys
|
||||
from modelscope.pipelines import pipeline
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class HumanNormalEstimationTest(unittest.TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.task = 'human-normal-estimation'
|
||||
self.model_id = 'Damo_XR_Lab/cv_human_monocular-normal-estimation'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_image_normal_estimation(self):
|
||||
cur_dir = os.path.dirname(
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
input_location = f'{cur_dir}/data/test/images/human_normal_estimation.png'
|
||||
estimator = pipeline(
|
||||
Tasks.human_normal_estimation, model=self.model_id)
|
||||
result = estimator(input_location)
|
||||
normals_vis = result[OutputKeys.NORMALS_COLOR]
|
||||
|
||||
input_img = cv2.imread(input_location)
|
||||
normals_vis = cv2.resize(
|
||||
normals_vis, dsize=(input_img.shape[1], input_img.shape[0]))
|
||||
cv2.imwrite('result.jpg', normals_vis)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user