[to #42322933] add onnx/torchscript exporter for token classification models

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11895085
This commit is contained in:
pangda
2023-03-07 14:12:23 +08:00
committed by yuze.zyz
parent 99fa2fe909
commit 798aa93cba
6 changed files with 251 additions and 73 deletions

View File

@@ -9,3 +9,5 @@ if is_torch_available():
SbertForSequenceClassificationExporter
from .sbert_for_zero_shot_classification_exporter import \
SbertForZeroShotClassificationExporter
from .model_for_token_classification_exporter import \
ModelForSequenceClassificationExporter

View File

@@ -0,0 +1,114 @@
from collections import OrderedDict
from typing import Any, Dict, Mapping
import torch
from torch import nn
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.outputs import ModelOutputBase
from modelscope.preprocessors import Preprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
numpify_tensor_nested)
@EXPORTERS.register_module(Tasks.transformer_crf, module_name=Models.tcrf)
@EXPORTERS.register_module(Tasks.token_classification, module_name=Models.tcrf)
@EXPORTERS.register_module(
    Tasks.named_entity_recognition, module_name=Models.tcrf)
@EXPORTERS.register_module(Tasks.part_of_speech, module_name=Models.tcrf)
@EXPORTERS.register_module(Tasks.word_segmentation, module_name=Models.tcrf)
class ModelForSequenceClassificationExporter(TorchModelExporter):
    """ONNX / TorchScript exporter registered for the `Models.tcrf` model.

    NOTE(review): despite the "SequenceClassification" name, every
    registration above is a token-level task. The name is kept unchanged
    because the package `__init__` imports this class by that name.
    """

    def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
        """Generate dummy inputs for model exportation by tracing.

        The inputs are produced by the model's own preprocessor, loaded from
        `self.model.model_dir` with `return_text=False`, applied to a short
        fixed sample text.

        Args:
            **kwargs: Accepted but not used by this implementation.

        Returns:
            The preprocessor output for the fixed sample text.
        """
        assert hasattr(
            self.model, 'model_dir'
        ), 'model_dir attribute is required to build the preprocessor'
        preprocessor = Preprocessor.from_pretrained(
            self.model.model_dir, return_text=False)
        return preprocessor('2023')

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Input names mapped to their dynamic axes (axis index -> axis name)."""
        dynamic_axis = {0: 'batch', 1: 'sequence'}
        return OrderedDict([
            ('input_ids', dynamic_axis),
            ('attention_mask', dynamic_axis),
            ('offset_mapping', dynamic_axis),
            ('label_mask', dynamic_axis),
        ])

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Output names mapped to their dynamic axes (axis index -> axis name)."""
        dynamic_axis = {0: 'batch', 1: 'sequence'}
        return OrderedDict([
            ('predictions', dynamic_axis),
        ])

    def _validate_onnx_model(self,
                             dummy_inputs,
                             model,
                             output,
                             onnx_outputs,
                             rtol: float = None,
                             atol: float = None):
        """Compare the exported ONNX model against the original torch model.

        Only the first output of the torch model (`predictions`) is compared
        with the ONNX Runtime result; the remaining outputs are dropped.

        Args:
            dummy_inputs: The inputs used for tracing, re-used for validation.
            model: The original torch model.
            output: Path of the exported ONNX file.
            onnx_outputs: Passed as the first argument of `InferenceSession.run`
                (presumably the requested output names -- confirm with caller).
            rtol: Optional relative tolerance forwarded to the comparison.
            atol: Optional absolute tolerance forwarded to the comparison.

        Raises:
            RuntimeError: If the ONNX outputs do not match the torch outputs.
        """
        # This module does not define a module-level `logger`, so obtain one
        # locally; previously the ImportError path below raised NameError.
        from modelscope.utils.logger import get_logger
        logger = get_logger()

        try:
            import onnx
            import onnxruntime as ort
        except ImportError:
            logger.warning(
                'Cannot validate the exported onnx file, because '
                'the installation of onnx or onnxruntime cannot be found')
            return

        onnx_model = onnx.load(output)
        onnx.checker.check_model(onnx_model)
        ort_session = ort.InferenceSession(output)

        with torch.no_grad():
            model.eval()
            outputs_origin = model.forward(
                *self._decide_input_format(model, dummy_inputs))
            if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
                outputs_origin = list(
                    numpify_tensor_nested(outputs_origin).values())
            elif isinstance(outputs_origin, (tuple, list)):
                outputs_origin = list(numpify_tensor_nested(outputs_origin))
        # Keep `predictions`, drop the other outputs.
        outputs_origin = [outputs_origin[0]]

        np_dummy_inputs = numpify_tensor_nested(dummy_inputs)
        # The mask is fed to ONNX Runtime as a boolean array (presumably the
        # exported graph declares `label_mask` as bool -- confirm).
        np_dummy_inputs['label_mask'] = np_dummy_inputs['label_mask'].astype(
            bool)
        outputs = ort_session.run(onnx_outputs, np_dummy_inputs)
        outputs = numpify_tensor_nested(outputs)
        if isinstance(outputs, dict):
            outputs = list(outputs.values())
        elif isinstance(outputs, tuple):
            outputs = list(outputs)

        tols = {}
        if rtol is not None:
            tols['rtol'] = rtol
        if atol is not None:
            tols['atol'] = atol

        # Diagnostic output goes through the logger instead of bare print().
        logger.debug('onnx outputs: %s', outputs)
        logger.debug('torch outputs: %s', outputs_origin)
        if not compare_arguments_nested('Onnx model output match failed',
                                        outputs, outputs_origin, **tols):
            raise RuntimeError(
                'export onnx failed because of validation error.')

View File

@@ -213,45 +213,58 @@ class TorchModelExporter(Exporter):
)
if validation:
try:
import onnx
import onnxruntime as ort
except ImportError:
logger.warning(
'Cannot validate the exported onnx file, because '
'the installation of onnx or onnxruntime cannot be found')
return
onnx_model = onnx.load(output)
onnx.checker.check_model(onnx_model)
ort_session = ort.InferenceSession(output)
with torch.no_grad():
model.eval()
outputs_origin = model.forward(
*self._decide_input_format(model, dummy_inputs))
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
outputs_origin = list(
numpify_tensor_nested(outputs_origin).values())
elif isinstance(outputs_origin, (tuple, list)):
outputs_origin = list(numpify_tensor_nested(outputs_origin))
outputs = ort_session.run(
onnx_outputs,
numpify_tensor_nested(dummy_inputs),
)
outputs = numpify_tensor_nested(outputs)
if isinstance(outputs, dict):
outputs = list(outputs.values())
elif isinstance(outputs, tuple):
outputs = list(outputs)
self._validate_onnx_model(dummy_inputs, model, output,
onnx_outputs, rtol, atol)
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested('Onnx model output match failed',
outputs, outputs_origin, **tols):
raise RuntimeError(
'export onnx failed because of validation error.')
def _validate_onnx_model(self,
                         dummy_inputs,
                         model,
                         output,
                         onnx_outputs,
                         rtol: float = None,
                         atol: float = None):
    """Compare the exported ONNX model against the original torch model.

    Runs the torch model and an ONNX Runtime session on the same dummy
    inputs and checks that the results match. Validation is skipped with a
    warning when onnx or onnxruntime is not installed.

    Args:
        dummy_inputs: The inputs used for tracing, re-used for validation.
        model: The original torch model.
        output: Path of the exported ONNX file.
        onnx_outputs: Passed as the first argument of `InferenceSession.run`
            (presumably the requested output names -- confirm with caller).
        rtol: Optional relative tolerance forwarded to the comparison.
        atol: Optional absolute tolerance forwarded to the comparison.

    Raises:
        RuntimeError: If the ONNX outputs do not match the torch outputs.
    """
    try:
        import onnx
        import onnxruntime as ort
    except ImportError:
        logger.warning(
            'Cannot validate the exported onnx file, because '
            'the installation of onnx or onnxruntime cannot be found')
        return

    onnx_model = onnx.load(output)
    onnx.checker.check_model(onnx_model)
    ort_session = ort.InferenceSession(output)

    with torch.no_grad():
        model.eval()
        outputs_origin = model.forward(
            *self._decide_input_format(model, dummy_inputs))
        if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
            outputs_origin = list(
                numpify_tensor_nested(outputs_origin).values())
        elif isinstance(outputs_origin, (tuple, list)):
            outputs_origin = list(numpify_tensor_nested(outputs_origin))

    outputs = ort_session.run(
        onnx_outputs,
        numpify_tensor_nested(dummy_inputs),
    )
    outputs = numpify_tensor_nested(outputs)
    if isinstance(outputs, dict):
        outputs = list(outputs.values())
    elif isinstance(outputs, tuple):
        outputs = list(outputs)

    tols = {}
    if rtol is not None:
        tols['rtol'] = rtol
    if atol is not None:
        tols['atol'] = atol

    # Diagnostic output goes through the logger instead of bare print(),
    # so a normal export no longer dumps whole arrays to stdout.
    logger.debug('onnx outputs: %s', outputs)
    logger.debug('torch outputs: %s', outputs_origin)
    if not compare_arguments_nested('Onnx model output match failed',
                                    outputs, outputs_origin, **tols):
        raise RuntimeError(
            'export onnx failed because of validation error.')
def _torch_export_torch_script(self,
model: nn.Module,
@@ -307,28 +320,33 @@ class TorchModelExporter(Exporter):
torch.jit.save(traced_model, output)
if validation:
ts_model = torch.jit.load(output)
with torch.no_grad():
model.eval()
ts_model.eval()
outputs = ts_model.forward(*dummy_inputs)
outputs = numpify_tensor_nested(outputs)
outputs_origin = model.forward(*dummy_inputs)
outputs_origin = numpify_tensor_nested(outputs_origin)
if isinstance(outputs, dict):
outputs = list(outputs.values())
if isinstance(outputs_origin, dict):
outputs_origin = list(outputs_origin.values())
tols = {}
if rtol is not None:
tols['rtol'] = rtol
if atol is not None:
tols['atol'] = atol
if not compare_arguments_nested(
'Torch script model output match failed', outputs,
outputs_origin, **tols):
raise RuntimeError(
'export torch script failed because of validation error.')
self._validate_torch_script_model(dummy_inputs, model, output,
rtol, atol)
def _validate_torch_script_model(self, dummy_inputs, model, output, rtol,
                                 atol):
    """Check that the saved TorchScript model reproduces the torch outputs.

    Args:
        dummy_inputs: Positional inputs unpacked into both `forward` calls.
        model: The original torch model.
        output: Path of the saved TorchScript file.
        rtol: Relative tolerance for the comparison, or None to omit it.
        atol: Absolute tolerance for the comparison, or None to omit it.

    Raises:
        RuntimeError: If the two models' outputs do not match.
    """
    scripted = torch.jit.load(output)
    with torch.no_grad():
        model.eval()
        scripted.eval()
        # The scripted model runs first, then the original model.
        outputs = numpify_tensor_nested(scripted.forward(*dummy_inputs))
        outputs_origin = numpify_tensor_nested(model.forward(*dummy_inputs))

    # Dict results are flattened to their values before comparison.
    if isinstance(outputs, dict):
        outputs = list(outputs.values())
    if isinstance(outputs_origin, dict):
        outputs_origin = list(outputs_origin.values())

    # Forward only the tolerances that were actually supplied.
    tols = {
        key: value
        for key, value in (('rtol', rtol), ('atol', atol))
        if value is not None
    }

    matched = compare_arguments_nested(
        'Torch script model output match failed', outputs, outputs_origin,
        **tols)
    if matched:
        return
    raise RuntimeError(
        'export torch script failed because of validation error.')
@contextmanager

View File

@@ -97,7 +97,7 @@ class TransformersCRFHead(TorchHead):
mask = label_mask
masked_lengths = mask.sum(-1).long()
masked_logits = torch.zeros_like(logits)
for i in range(len(mask)):
for i in range(mask.shape[0]):
masked_logits[
i, :masked_lengths[i], :] = logits[i].masked_select(
mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)

View File

@@ -57,16 +57,15 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor):
class TokenClassificationPreprocessorBase(Preprocessor):
def __init__(
self,
model_dir: str = None,
first_sequence: str = None,
label: str = 'label',
label2id: Dict = None,
label_all_tokens: bool = False,
mode: str = ModeKeys.INFERENCE,
keep_original_columns: List[str] = None,
):
def __init__(self,
model_dir: str = None,
first_sequence: str = None,
label: str = 'label',
label2id: Dict = None,
label_all_tokens: bool = False,
mode: str = ModeKeys.INFERENCE,
keep_original_columns: List[str] = None,
return_text: bool = True):
"""The base class for all the token-classification tasks.
Args:
@@ -82,6 +81,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
mode: The preprocessor mode.
keep_original_columns(List[str], `optional`): The original columns to keep,
only available when the input is a `dict`, default None
return_text: Whether to return `text` field in inference mode, default: True.
"""
super().__init__(mode)
self.model_dir = model_dir
@@ -90,6 +90,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
self.label2id = label2id
self.label_all_tokens = label_all_tokens
self.keep_original_columns = keep_original_columns
self.return_text = return_text
if self.label2id is None and self.model_dir is not None:
self.label2id = parse_label_mapping(self.model_dir)
@@ -164,7 +165,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
if self.keep_original_columns and isinstance(data, dict):
for column in self.keep_original_columns:
outputs[column] = data[column]
if self.mode == ModeKeys.INFERENCE:
if self.mode == ModeKeys.INFERENCE and self.return_text:
outputs['text'] = text
return outputs
@@ -208,6 +209,7 @@ class TokenClassificationTransformersPreprocessor(
max_length=None,
use_fast=None,
keep_original_columns=None,
return_text=True,
**kwargs):
"""
@@ -218,7 +220,8 @@ class TokenClassificationTransformersPreprocessor(
**kwargs: Extra args input into the tokenizer's __call__ method.
"""
super().__init__(model_dir, first_sequence, label, label2id,
label_all_tokens, mode, keep_original_columns)
label_all_tokens, mode, keep_original_columns,
return_text)
self.is_lstm_model = 'lstm' in model_dir
model_type = None
if self.is_lstm_model:

View File

@@ -0,0 +1,41 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from collections import OrderedDict
from modelscope.exporters import Exporter, TorchModelExporter
from modelscope.models import Model
from modelscope.utils.constant import Tasks
from modelscope.utils.test_utils import test_level
class TestExportTokenClassification(unittest.TestCase):
    """Export a token classification model to ONNX and TorchScript."""

    def setUp(self):
        super().setUp()
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        # mkdtemp() creates the directory and leaves its lifetime to us.
        # `TemporaryDirectory().name` drops the owning object right away, so
        # the directory is removed on finalisation and had to be recreated.
        self.tmp_dir = tempfile.mkdtemp()
        self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_export_token_classification(self):
        """Run both export formats; each is reported as its own sub-test."""
        model = Model.from_pretrained(self.model_id)
        with self.subTest(format='onnx'):
            print(
                Exporter.from_model(model).export_onnx(
                    output_dir=self.tmp_dir))
        with self.subTest(format='torchscript'):
            print(
                Exporter.from_model(model).export_torch_script(
                    output_dir=self.tmp_dir))
# Run the test cases in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()