support onnx export for SCRFD model

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11826666
This commit is contained in:
yuxiang.tyx
2023-03-01 10:02:49 +08:00
committed by wenmeng.zwm
parent 8c39eefeff
commit 09a178d171
8 changed files with 580 additions and 6 deletions

View File

@@ -11,3 +11,4 @@ if is_tf_available():
if is_torch_available():
from .nlp import SbertForSequenceClassificationExporter, SbertForZeroShotClassificationExporter
from .torch_model_exporter import TorchModelExporter
from .cv import FaceDetectionSCRFDExporter

View File

@@ -1,6 +1,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from modelscope.utils.import_utils import is_tf_available
from modelscope.utils.import_utils import is_tf_available, is_torch_available
if is_tf_available():
from .cartoon_translation_exporter import CartoonTranslationExporter
if is_torch_available():
from .face_detection_scrfd_exporter import FaceDetectionSCRFDExporter

View File

@@ -0,0 +1,101 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping
import numpy as np
import onnx
import torch
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
def convert_ndarray_to_list(input_dict):
for key, value in input_dict.items():
if isinstance(value, np.ndarray):
input_dict[key] = value.tolist()
elif isinstance(value, dict):
convert_ndarray_to_list(value)
return input_dict
@EXPORTERS.register_module(Tasks.face_detection, module_name=Models.scrfd)
class FaceDetectionSCRFDExporter(TorchModelExporter):
def export_onnx(self,
output_dir: str,
opset=9,
simplify=True,
dynamic=False,
**kwargs):
"""Export the model as onnx format files.
Args:
output_dir: The output dir.
opset: The version of the ONNX operator set to use.
simplify: simplify the onnx model
dynamic: use dynamic input size
Returns:
A dict containing the model key - model file path pairs.
"""
from mmdet.core.export import preprocess_example_input
input_shape = (1, 3, 640, 640)
input_config = {
'input_shape': input_shape,
'input_path': 'data/test/images/face_detection2.jpeg',
'normalize_cfg': {
'mean': [127.5, 127.5, 127.5],
'std': [128.0, 128.0, 128.0]
}
}
model = self.model.detector.module if 'model' not in kwargs else kwargs.pop(
'model')
model = model.cpu().eval()
output_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
if simplify or dynamic:
ori_output_file = output_file.split('.onnx')[0] + '_ori.onnx'
else:
ori_output_file = output_file
one_img, one_meta = preprocess_example_input(input_config)
tensor_data = [one_img]
if 'show_img' in one_meta:
del one_meta['show_img']
one_meta = convert_ndarray_to_list(one_meta)
model.forward = partial(
model.forward, img_metas=[[one_meta]], return_loss=False)
torch.onnx.export(
model,
tensor_data,
ori_output_file,
keep_initializers_as_inputs=False,
verbose=False,
opset_version=opset)
if simplify or dynamic:
model = onnx.load(ori_output_file)
if dynamic:
model.graph.input[0].type.tensor_type.shape.dim[
2].dim_param = '?'
model.graph.input[0].type.tensor_type.shape.dim[
3].dim_param = '?'
if simplify:
from onnxsim import simplify
if dynamic:
input_shapes = {
model.graph.input[0].name: list(input_shape)
}
model, check = simplify(
model, overwrite_input_shapes=input_shapes)
else:
model, check = simplify(model)
assert check, 'Simplified ONNX model could not be validated'
onnx.save(model, output_file)
os.remove(ori_output_file)
print(f'Successfully exported ONNX model: {output_file}')
return {'model': output_file}

View File

@@ -0,0 +1,271 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
https://github.com/deepinsight/insightface/blob/master/detection/scrfd/mmdet/models/detectors/base.py
"""
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
import mmcv
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
from mmcv.runner import auto_fp16
from mmcv.utils import print_log
from mmdet.utils import get_root_logger
class BaseDetector(nn.Module, metaclass=ABCMeta):
"""Base class for detectors."""
def __init__(self):
super(BaseDetector, self).__init__()
self.fp16_enabled = False
@property
def with_neck(self):
"""bool: whether the detector has a neck"""
return hasattr(self, 'neck') and self.neck is not None
# TODO: these properties need to be carefully handled
# for both single stage & two stage detectors
@property
def with_shared_head(self):
"""bool: whether the detector has a shared head in the RoI Head"""
return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
@property
def with_bbox(self):
"""bool: whether the detector has a bbox head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
@property
def with_mask(self):
"""bool: whether the detector has a mask head"""
return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
or (hasattr(self, 'mask_head') and self.mask_head is not None))
@abstractmethod
def extract_feat(self, imgs):
"""Extract features from images."""
pass
def extract_feats(self, imgs):
"""Extract features from multiple images.
Args:
imgs (list[torch.Tensor]): A list of images. The images are
augmented from the same image but in different ways.
Returns:
list[torch.Tensor]: Features of different images
"""
assert isinstance(imgs, list)
return [self.extract_feat(img) for img in imgs]
def forward_train(self, imgs, img_metas, **kwargs):
"""
Args:
img (list[Tensor]): List of tensors of shape (1, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
kwargs (keyword arguments): Specific to concrete implementation.
"""
# NOTE the batched image size information may be useful, e.g.
# in DETR, this is needed for the construction of masks, which is
# then used for the transformer_head.
batch_input_shape = tuple(imgs[0].size()[-2:])
for img_meta in img_metas:
img_meta['batch_input_shape'] = batch_input_shape
async def async_simple_test(self, img, img_metas, **kwargs):
raise NotImplementedError
@abstractmethod
def simple_test(self, img, img_metas, **kwargs):
pass
@abstractmethod
def aug_test(self, imgs, img_metas, **kwargs):
"""Test function with test time augmentation."""
pass
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
if pretrained is not None:
logger = get_root_logger()
print_log(f'load model from: {pretrained}', logger=logger)
async def aforward_test(self, *, img, img_metas, **kwargs):
for var, name in [(img, 'img'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError(f'{name} must be a list, but got {type(var)}')
num_augs = len(img)
if num_augs != len(img_metas):
raise ValueError(f'num of augmentations ({len(img)}) '
f'!= num of image metas ({len(img_metas)})')
# TODO: remove the restriction of samples_per_gpu == 1 when prepared
samples_per_gpu = img[0].size(0)
assert samples_per_gpu == 1
if num_augs == 1:
return await self.async_simple_test(img[0], img_metas[0], **kwargs)
else:
raise NotImplementedError
def forward_test(self, imgs, img_metas, **kwargs):
"""
Args:
imgs (List[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (List[List[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch.
"""
for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
if not isinstance(var, list):
raise TypeError(f'{name} must be a list, but got {type(var)}')
imgs = [imgs[0]]
num_augs = len(imgs)
if num_augs != len(img_metas):
raise ValueError(f'num of augmentations ({len(imgs)}) '
f'!= num of image meta ({len(img_metas)})')
# NOTE the batched image size information may be useful, e.g.
# in DETR, this is needed for the construction of masks, which is
# then used for the transformer_head.
for img, img_meta in zip(imgs, img_metas):
batch_size = len(img_meta)
for img_id in range(batch_size):
img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])
if num_augs == 1:
# proposals (List[List[Tensor]]): the outer list indicates
# test-time augs (multiscale, flip, etc.) and the inner list
# indicates images in a batch.
# The Tensor should have a shape Px4, where P is the number of
# proposals.
if 'proposals' in kwargs:
kwargs['proposals'] = kwargs['proposals'][0]
return self.simple_test(imgs[0], img_metas[0], **kwargs)
else:
assert imgs[0].size(0) == 1, 'aug test does not support ' \
'inference with batch size ' \
f'{imgs[0].size(0)}'
# TODO: support test augmentation for predefined proposals
assert 'proposals' not in kwargs
return self.aug_test(imgs, img_metas, **kwargs)
@auto_fp16(apply_to=('img', ))
def forward(self, img, img_metas, return_loss=True, **kwargs):
"""Calls either :func:`forward_train` or :func:`forward_test` depending
on whether ``return_loss`` is ``True``.
Note this setting will change the expected inputs. When
``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
and List[dict]), and when ``resturn_loss=False``, img and img_meta
should be double nested (i.e. List[Tensor], List[List[dict]]), with
the outer list indicating test time augmentations.
"""
if return_loss:
return self.forward_train(img, img_metas, **kwargs)
else:
return self.forward_test(img, img_metas, **kwargs)
def _parse_losses(self, losses):
"""Parse the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary infomation.
Returns:
tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
which may be a weighted sum of all losses, log_vars contains \
all the variables to be sent to the logger.
"""
log_vars = OrderedDict()
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars[loss_name] = loss_value.mean()
elif isinstance(loss_value, list):
log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')
loss = sum(_value for _key, _value in log_vars.items()
if 'loss' in _key)
log_vars['loss'] = loss
for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()
return loss, log_vars
def train_step(self, data, optimizer):
"""The iteration step during training.
This method defines an iteration step during training, except for the
back propagation and optimizer updating, which are done in an optimizer
hook. Note that in some complicated cases or models, the whole process
including back propagation and optimizer updating is also defined in
this method, such as GAN.
Args:
data (dict): The output of dataloader.
optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
runner is passed to ``train_step()``. This argument is unused
and reserved.
Returns:
dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
``num_samples``.
- ``loss`` is a tensor for back propagation, which can be a \
weighted sum of multiple losses.
- ``log_vars`` contains all the variables to be sent to the
logger.
- ``num_samples`` indicates the batch size (when the model is \
DDP, it means the batch size on each GPU), which is used for \
averaging the logs.
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
return outputs
def val_step(self, data, optimizer):
"""The iteration step during validation.
This method shares the same signature as :func:`train_step`, but used
during val epochs. Note that the evaluation after training epochs is
not implemented with this method, but an evaluation hook.
"""
losses = self(**data)
loss, log_vars = self._parse_losses(losses)
outputs = dict(
loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
return outputs

View File

@@ -4,13 +4,14 @@ https://github.com/deepinsight/insightface/tree/master/detection/scrfd/mmdet/mod
"""
import torch
from mmdet.models.builder import DETECTORS
from mmdet.models.detectors.single_stage import SingleStageDetector
from ....mmdet_patch.core.bbox import bbox2result
from ....mmdet_patch.models.detectors.single_stage import \
CustomSingleStageDetector
@DETECTORS.register_module()
class SCRFD(SingleStageDetector):
class SCRFD(CustomSingleStageDetector):
def __init__(self,
backbone,
@@ -47,7 +48,7 @@ class SCRFD(SingleStageDetector):
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
super(SingleStageDetector, self).forward_train(img, img_metas)
super(CustomSingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_keypointss,
@@ -60,7 +61,7 @@ class SCRFD(SingleStageDetector):
rescale=False,
repeat_head=1,
output_kps_var=0,
output_results=1):
output_results=2):
"""Test function without test time augmentation.
Args:

View File

@@ -0,0 +1,163 @@
"""
The implementation here is modified based on insightface, originally MIT license and publicly avaialbe at
https://github.com/deepinsight/insightface/blob/master/detection/scrfd/mmdet/models/detectors/single_stage.py
"""
import torch
import torch.nn as nn
from mmdet.models.builder import (DETECTORS, build_backbone, build_head,
build_neck)
from ....mmdet_patch.core.bbox import bbox2result
from .base import BaseDetector
@DETECTORS.register_module()
class CustomSingleStageDetector(BaseDetector):
"""Base class for single-stage detectors.
Single-stage detectors directly and densely predict bounding boxes on the
output features of the backbone+neck.
"""
def __init__(self,
backbone,
neck=None,
bbox_head=None,
train_cfg=None,
test_cfg=None,
pretrained=None):
super(CustomSingleStageDetector, self).__init__()
self.backbone = build_backbone(backbone)
if neck is not None:
self.neck = build_neck(neck)
bbox_head.update(train_cfg=train_cfg)
bbox_head.update(test_cfg=test_cfg)
self.bbox_head = build_head(bbox_head)
self.train_cfg = train_cfg
self.test_cfg = test_cfg
self.init_weights(pretrained=pretrained)
def init_weights(self, pretrained=None):
"""Initialize the weights in detector.
Args:
pretrained (str, optional): Path to pre-trained weights.
Defaults to None.
"""
super(CustomSingleStageDetector, self).init_weights(pretrained)
self.backbone.init_weights(pretrained=pretrained)
if self.with_neck:
if isinstance(self.neck, nn.Sequential):
for m in self.neck:
m.init_weights()
else:
self.neck.init_weights()
self.bbox_head.init_weights()
def extract_feat(self, img):
"""Directly extract features from the backbone+neck."""
x = self.backbone(img)
if self.with_neck:
x = self.neck(x)
return x
def forward_dummy(self, img):
"""Used for computing network flops.
See `mmdetection/tools/get_flops.py`
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
return outs
def forward_train(self,
img,
img_metas,
gt_bboxes,
gt_labels,
gt_bboxes_ignore=None):
"""
Args:
img (Tensor): Input images of shape (N, C, H, W).
Typically these should be mean centered and std scaled.
img_metas (list[dict]): A List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys see
:class:`mmdet.datasets.pipelines.Collect`.
gt_bboxes (list[Tensor]): Each item are the truth boxes for each
image in [tl_x, tl_y, br_x, br_y] format.
gt_labels (list[Tensor]): Class indices corresponding to each box
gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
boxes can be ignored when computing the loss.
Returns:
dict[str, Tensor]: A dictionary of loss components.
"""
super(CustomSingleStageDetector, self).forward_train(img, img_metas)
x = self.extract_feat(img)
losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
gt_labels, gt_bboxes_ignore)
return losses
def simple_test(self, img, img_metas, rescale=False):
"""Test function without test time augmentation.
Args:
imgs (list[torch.Tensor]): List of multiple images
img_metas (list[dict]): List of image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
x = self.extract_feat(img)
outs = self.bbox_head(x)
if torch.onnx.is_in_onnx_export():
print('single_stage.py in-onnx-export')
print(outs.__class__)
cls_score, bbox_pred = outs
for c in cls_score:
print(c.shape)
for c in bbox_pred:
print(c.shape)
return outs
bbox_list = self.bbox_head.get_bboxes(
*outs, img_metas, rescale=rescale)
# skip post-processing when exporting to ONNX
if torch.onnx.is_in_onnx_export():
return bbox_list
bbox_results = [
bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
for det_bboxes, det_labels in bbox_list
]
return bbox_results
def aug_test(self, imgs, img_metas, rescale=False):
"""Test function with test time augmentation.
Args:
imgs (list[Tensor]): the outer list indicates test-time
augmentations and inner Tensor should have a shape NxCxHxW,
which contains all images in the batch.
img_metas (list[list[dict]]): the outer list indicates test-time
augs (multiscale, flip, etc.) and the inner list indicates
images in a batch. each dict has image information.
rescale (bool, optional): Whether to rescale the results.
Defaults to False.
Returns:
list[list[np.ndarray]]: BBox results of each image and classes.
The outer list corresponds to each image. The inner list
corresponds to each class.
"""
assert hasattr(self.bbox_head, 'aug_test'), \
f'{self.bbox_head.__class__.__name__}' \
' does not support test-time augmentation'
print('aug-test:', len(imgs))
feats = self.extract_feats(imgs)
return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]

View File

@@ -33,6 +33,8 @@ nerfacc==0.2.2
networkx
numba
omegaconf
onnx
onnx-simplifier
onnxruntime>=1.10
open-clip-torch>=2.7.0
opencv-python

View File

@@ -0,0 +1,34 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
'''
For inference using onnx model, please refer to:
https://github.com/deepinsight/insightface/blob/master/detection/scrfd/tools/scrfd.py
'''
import os
import shutil
import tempfile
import unittest
from collections import OrderedDict
from modelscope.exporters import Exporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TestExportFaceDetectionSCRFD(unittest.TestCase):
def setUp(self):
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.model_id = 'damo/cv_resnet_facedetection_scrfd10gkps'
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_export_face_detection_scrfd(self):
model = Model.from_pretrained(self.model_id)
print(Exporter.from_model(model).export_onnx(output_dir=self.tmp_dir))
if __name__ == '__main__':
unittest.main()