mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-16 16:27:45 +01:00
add onnx exporter for ocr_detection db model
支持ocr_detection db pytorch模型转onnx Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/14117993 * add onnx exporter for ocr_detection db model * add code for onnx convert * fix bug
This commit is contained in:
41
modelscope/exporters/cv/ocr_detection_db_exporter.py
Normal file
41
modelscope/exporters/cv/ocr_detection_db_exporter.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Mapping
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.torch_model_exporter import TorchModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.ocr_detection, module_name=Models.ocr_detection)
|
||||
class OCRDetectionDBExporter(TorchModelExporter):
|
||||
|
||||
def export_onnx(self,
|
||||
output_dir: str,
|
||||
opset=11,
|
||||
input_shape=(1, 3, 800, 800)):
|
||||
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
|
||||
dummy_input = torch.randn(*input_shape)
|
||||
self.model.onnx_export = True
|
||||
self.model.eval()
|
||||
_ = self.model(dummy_input)
|
||||
torch.onnx._export(
|
||||
self.model,
|
||||
dummy_input,
|
||||
onnx_file,
|
||||
input_names=[
|
||||
'images',
|
||||
],
|
||||
output_names=[
|
||||
'pred',
|
||||
],
|
||||
opset_version=opset)
|
||||
|
||||
return {'model', onnx_file}
|
||||
@@ -36,6 +36,7 @@ class OCRDetection(TorchModel):
|
||||
self.return_polygon = cfgs.model.inference_kwargs.return_polygon
|
||||
self.backbone = cfgs.model.backbone
|
||||
self.detector = None
|
||||
self.onnx_export = False
|
||||
if self.backbone == 'resnet50':
|
||||
self.detector = VLPTModel()
|
||||
elif self.backbone == 'resnet18':
|
||||
@@ -62,11 +63,20 @@ class OCRDetection(TorchModel):
|
||||
org_shape (`List`): image original shape,
|
||||
value is [height, width].
|
||||
"""
|
||||
pred = self.detector(input['img'])
|
||||
if type(input) is dict:
|
||||
pred = self.detector(input['img'])
|
||||
else:
|
||||
# for onnx convert
|
||||
input = {'img': input, 'org_shape': [800, 800]}
|
||||
pred = self.detector(input['img'])
|
||||
return {'results': pred, 'org_shape': input['org_shape']}
|
||||
|
||||
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
|
||||
pred = inputs['results'][0]
|
||||
|
||||
if self.onnx_export:
|
||||
return pred
|
||||
|
||||
height, width = inputs['org_shape']
|
||||
segmentation = pred > self.thresh
|
||||
if self.return_polygon:
|
||||
|
||||
@@ -164,15 +164,17 @@ def polygons_from_bitmap(pred, _bitmap, dest_width, dest_height):
|
||||
return boxes, scores
|
||||
|
||||
|
||||
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height):
|
||||
def boxes_from_bitmap(pred, _bitmap, dest_width, dest_height, is_numpy=False):
|
||||
"""
|
||||
_bitmap: single map with shape (1, H, W),
|
||||
whose values are binarized as {0, 1}
|
||||
"""
|
||||
|
||||
assert _bitmap.size(0) == 1
|
||||
bitmap = _bitmap.cpu().numpy()[0]
|
||||
pred = pred.cpu().detach().numpy()[0]
|
||||
if is_numpy:
|
||||
bitmap = _bitmap[0]
|
||||
pred = pred[0]
|
||||
else:
|
||||
bitmap = _bitmap.cpu().numpy()[0]
|
||||
pred = pred.cpu().detach().numpy()[0]
|
||||
height, width = bitmap.shape
|
||||
boxes = []
|
||||
scores = []
|
||||
|
||||
32
tests/export/test_export_ocr_detection_db.py
Normal file
32
tests/export/test_export_ocr_detection_db.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
from modelscope.exporters import Exporter
|
||||
from modelscope.models import Model
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class TestExportOCRDetectionDB(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
self.model_id = 'damo/cv_resnet18_ocr-detection-db-line-level_damo'
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_export_ocr_detection_db(self):
|
||||
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
Exporter.from_model(model).export_onnx(
|
||||
input_shape=(1, 3, 800, 800), output_dir=self.tmp_dir)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user