add onnx exporter for ocr_detection db model

支持ocr_detection db pytorch模型转onnx (adds support for converting the ocr_detection DB PyTorch model to ONNX)
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14117993
* add onnx exporter for ocr_detection db model

* add code for onnx convert

* fix bug
This commit is contained in:
xixing.tj
2023-09-25 11:34:28 +08:00
committed by wenmeng.zwm
parent 04b24814ca
commit e7e712c5c2
4 changed files with 91 additions and 6 deletions

View File

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from functools import partial
from typing import Mapping
import numpy as np
import onnx
import torch
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
@EXPORTERS.register_module(
    Tasks.ocr_detection, module_name=Models.ocr_detection)
class OCRDetectionDBExporter(TorchModelExporter):
    """Exporter that writes the OCR detection DB torch model as ONNX."""

    def export_onnx(self,
                    output_dir: str,
                    opset=11,
                    input_shape=(1, 3, 800, 800)):
        """Export ``self.model`` to an ONNX file inside ``output_dir``.

        Args:
            output_dir: Directory the ONNX file is written into. The file
                name is ``ModelFile.ONNX_MODEL_FILE``.
            opset: ONNX opset version handed to the torch exporter.
            input_shape: Shape of the random dummy tensor fed to the model
                and to the exporter.

        Returns:
            A dict mapping ``'model'`` to the path of the exported file.
        """
        onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
        dummy_input = torch.randn(*input_shape)

        # NOTE: this flag is left set on the model after the export returns.
        self.model.onnx_export = True
        self.model.eval()

        # Forward pass on the dummy input before exporting, kept from the
        # original flow (presumably a sanity check -- confirm with author).
        self.model(dummy_input)

        # Use the public export entry point instead of the private
        # ``torch.onnx._export`` helper.
        torch.onnx.export(
            self.model,
            dummy_input,
            onnx_file,
            input_names=['images'],
            output_names=['pred'],
            opset_version=opset)

        # The original returned ``{'model', onnx_file}`` -- a set literal,
        # which loses the key/value association. Return a real mapping.
        return {'model': onnx_file}

View File

@@ -36,6 +36,7 @@ class OCRDetection(TorchModel):
self.return_polygon = cfgs.model.inference_kwargs.return_polygon
self.backbone = cfgs.model.backbone
self.detector = None
self.onnx_export = False
if self.backbone == 'resnet50':
self.detector = VLPTModel()
elif self.backbone == 'resnet18':
@@ -62,11 +63,20 @@ class OCRDetection(TorchModel):
org_shape (`List`): image original shape,
value is [height, width].
"""
pred = self.detector(input['img'])
if type(input) is dict:
pred = self.detector(input['img'])
else:
# for onnx convert
input = {'img': input, 'org_shape': [800, 800]}
pred = self.detector(input['img'])
return {'results': pred, 'org_shape': input['org_shape']}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
pred = inputs['results'][0]
if self.onnx_export:
return pred
height, width = inputs['org_shape']
segmentation = pred > self.thresh
if self.return_polygon:

View File

@@ -164,15 +164,17 @@ def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height):
return boxes, scores
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height, is_numpy=False):
"""
_bitmap: single map with shape (1, H, W),
whose values are binarized as {0, 1}
"""
assert _bitmap.size(0) == 1
bitmap = _bitmap.cpu().numpy()[0]
pred = pred.cpu().detach().numpy()[0]
if is_numpy:
bitmap = _bitmap[0]
pred = pred[0]
else:
bitmap = _bitmap.cpu().numpy()[0]
pred = pred.cpu().detach().numpy()[0]
height, width = bitmap.shape
boxes = []
scores = []

View File

@@ -0,0 +1,32 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from collections import OrderedDict
from modelscope.exporters import Exporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TestExportOCRDetectionDB(unittest.TestCase):
    """Checks that the OCR detection DB model can be exported to ONNX."""

    def setUp(self):
        """Create a scratch output directory and select the model id."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
        # mkdtemp() creates the directory and leaves its lifetime to the
        # caller. The previous ``TemporaryDirectory().name`` dropped the
        # owning object at once, so the directory was removed again and had
        # to be re-created by hand, and it was never cleaned up afterwards.
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir, ignore_errors=True)
        self.model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo'

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_export_ocr_detection_db(self):
        """Export the pretrained model and check that output was written."""
        model = Model.from_pretrained(self.model_id)
        Exporter.from_model(model).export_onnx(
            input_shape=(1, 3, 800, 800), output_dir=self.tmp_dir)
        # The export is expected to leave at least one file behind.
        self.assertTrue(
            os.listdir(self.tmp_dir),
            'export_onnx did not write anything to the output directory')


if __name__ == '__main__':
    unittest.main()