mirror of
https://github.com/modelscope/modelscope.git
synced 2026-05-18 05:05:00 +02:00
[to #42322933] add onnx/torchscript exporter for token classification models
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11895085
This commit is contained in:
@@ -9,3 +9,5 @@ if is_torch_available():
|
||||
SbertForSequenceClassificationExporter
|
||||
from .sbert_for_zero_shot_classification_exporter import \
|
||||
SbertForZeroShotClassificationExporter
|
||||
from .model_for_token_classification_exporter import \
|
||||
ModelForSequenceClassificationExporter
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Mapping
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.torch_model_exporter import TorchModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.outputs import ModelOutputBase
|
||||
from modelscope.preprocessors import Preprocessor
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
|
||||
numpify_tensor_nested)
|
||||
|
||||
|
||||
@EXPORTERS.register_module(Tasks.transformer_crf, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(Tasks.token_classification, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.named_entity_recognition, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(Tasks.part_of_speech, module_name=Models.tcrf)
|
||||
@EXPORTERS.register_module(Tasks.word_segmentation, module_name=Models.tcrf)
|
||||
class ModelForSequenceClassificationExporter(TorchModelExporter):
|
||||
|
||||
def generate_dummy_inputs(self, **kwargs) -> Dict[str, Any]:
|
||||
"""Generate dummy inputs for model exportation to onnx or other formats by tracing.
|
||||
|
||||
Args:
|
||||
shape: A tuple of input shape which should have at most two dimensions.
|
||||
shape = (1, ) batch_size=1, sequence_length will be taken from the preprocessor.
|
||||
shape = (8, 128) batch_size=1, sequence_length=128, which will cover the config of the preprocessor.
|
||||
pair(bool, `optional`): Whether to generate sentence pairs or single sentences.
|
||||
|
||||
Returns:
|
||||
Dummy inputs.
|
||||
"""
|
||||
|
||||
assert hasattr(
|
||||
self.model, 'model_dir'
|
||||
), 'model_dir attribute is required to build the preprocessor'
|
||||
|
||||
preprocessor = Preprocessor.from_pretrained(
|
||||
self.model.model_dir, return_text=False)
|
||||
return preprocessor('2023')
|
||||
|
||||
@property
|
||||
def inputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
dynamic_axis = {0: 'batch', 1: 'sequence'}
|
||||
return OrderedDict([
|
||||
('input_ids', dynamic_axis),
|
||||
('attention_mask', dynamic_axis),
|
||||
('offset_mapping', dynamic_axis),
|
||||
('label_mask', dynamic_axis),
|
||||
])
|
||||
|
||||
@property
|
||||
def outputs(self) -> Mapping[str, Mapping[int, str]]:
|
||||
dynamic_axis = {0: 'batch', 1: 'sequence'}
|
||||
return OrderedDict([
|
||||
('predictions', dynamic_axis),
|
||||
])
|
||||
|
||||
def _validate_onnx_model(self,
|
||||
dummy_inputs,
|
||||
model,
|
||||
output,
|
||||
onnx_outputs,
|
||||
rtol: float = None,
|
||||
atol: float = None):
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(
|
||||
*self._decide_input_format(model, dummy_inputs))
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
numpify_tensor_nested(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(numpify_tensor_nested(outputs_origin))
|
||||
|
||||
outputs_origin = [outputs_origin[0]
|
||||
] # keeo `predictions`, drop other outputs
|
||||
|
||||
np_dummy_inputs = numpify_tensor_nested(dummy_inputs)
|
||||
np_dummy_inputs['label_mask'] = np_dummy_inputs['label_mask'].astype(
|
||||
bool)
|
||||
outputs = ort_session.run(onnx_outputs, np_dummy_inputs)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
print(outputs)
|
||||
print(outputs_origin)
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
@@ -213,45 +213,58 @@ class TorchModelExporter(Exporter):
|
||||
)
|
||||
|
||||
if validation:
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(
|
||||
*self._decide_input_format(model, dummy_inputs))
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
numpify_tensor_nested(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(numpify_tensor_nested(outputs_origin))
|
||||
outputs = ort_session.run(
|
||||
onnx_outputs,
|
||||
numpify_tensor_nested(dummy_inputs),
|
||||
)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
self._validate_onnx_model(dummy_inputs, model, output,
|
||||
onnx_outputs, rtol, atol)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
def _validate_onnx_model(self,
|
||||
dummy_inputs,
|
||||
model,
|
||||
output,
|
||||
onnx_outputs,
|
||||
rtol: float = None,
|
||||
atol: float = None):
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(
|
||||
*self._decide_input_format(model, dummy_inputs))
|
||||
if isinstance(outputs_origin, (Mapping, ModelOutputBase)):
|
||||
outputs_origin = list(
|
||||
numpify_tensor_nested(outputs_origin).values())
|
||||
elif isinstance(outputs_origin, (tuple, list)):
|
||||
outputs_origin = list(numpify_tensor_nested(outputs_origin))
|
||||
|
||||
outputs = ort_session.run(
|
||||
onnx_outputs,
|
||||
numpify_tensor_nested(dummy_inputs),
|
||||
)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
elif isinstance(outputs, tuple):
|
||||
outputs = list(outputs)
|
||||
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
print(outputs)
|
||||
print(outputs_origin)
|
||||
if not compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export onnx failed because of validation error.')
|
||||
|
||||
def _torch_export_torch_script(self,
|
||||
model: nn.Module,
|
||||
@@ -307,28 +320,33 @@ class TorchModelExporter(Exporter):
|
||||
torch.jit.save(traced_model, output)
|
||||
|
||||
if validation:
|
||||
ts_model = torch.jit.load(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
ts_model.eval()
|
||||
outputs = ts_model.forward(*dummy_inputs)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
outputs_origin = model.forward(*dummy_inputs)
|
||||
outputs_origin = numpify_tensor_nested(outputs_origin)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
if isinstance(outputs_origin, dict):
|
||||
outputs_origin = list(outputs_origin.values())
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested(
|
||||
'Torch script model output match failed', outputs,
|
||||
outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export torch script failed because of validation error.')
|
||||
self._validate_torch_script_model(dummy_inputs, model, output,
|
||||
rtol, atol)
|
||||
|
||||
def _validate_torch_script_model(self, dummy_inputs, model, output, rtol,
|
||||
atol):
|
||||
ts_model = torch.jit.load(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
ts_model.eval()
|
||||
outputs = ts_model.forward(*dummy_inputs)
|
||||
outputs = numpify_tensor_nested(outputs)
|
||||
outputs_origin = model.forward(*dummy_inputs)
|
||||
outputs_origin = numpify_tensor_nested(outputs_origin)
|
||||
if isinstance(outputs, dict):
|
||||
outputs = list(outputs.values())
|
||||
if isinstance(outputs_origin, dict):
|
||||
outputs_origin = list(outputs_origin.values())
|
||||
tols = {}
|
||||
if rtol is not None:
|
||||
tols['rtol'] = rtol
|
||||
if atol is not None:
|
||||
tols['atol'] = atol
|
||||
if not compare_arguments_nested(
|
||||
'Torch script model output match failed', outputs,
|
||||
outputs_origin, **tols):
|
||||
raise RuntimeError(
|
||||
'export torch script failed because of validation error.')
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -97,7 +97,7 @@ class TransformersCRFHead(TorchHead):
|
||||
mask = label_mask
|
||||
masked_lengths = mask.sum(-1).long()
|
||||
masked_logits = torch.zeros_like(logits)
|
||||
for i in range(len(mask)):
|
||||
for i in range(mask.shape[0]):
|
||||
masked_logits[
|
||||
i, :masked_lengths[i], :] = logits[i].masked_select(
|
||||
mask[i].unsqueeze(-1)).view(masked_lengths[i], -1)
|
||||
|
||||
@@ -57,16 +57,15 @@ class WordSegmentationBlankSetToLabelPreprocessor(Preprocessor):
|
||||
|
||||
class TokenClassificationPreprocessorBase(Preprocessor):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_dir: str = None,
|
||||
first_sequence: str = None,
|
||||
label: str = 'label',
|
||||
label2id: Dict = None,
|
||||
label_all_tokens: bool = False,
|
||||
mode: str = ModeKeys.INFERENCE,
|
||||
keep_original_columns: List[str] = None,
|
||||
):
|
||||
def __init__(self,
|
||||
model_dir: str = None,
|
||||
first_sequence: str = None,
|
||||
label: str = 'label',
|
||||
label2id: Dict = None,
|
||||
label_all_tokens: bool = False,
|
||||
mode: str = ModeKeys.INFERENCE,
|
||||
keep_original_columns: List[str] = None,
|
||||
return_text: bool = True):
|
||||
"""The base class for all the token-classification tasks.
|
||||
|
||||
Args:
|
||||
@@ -82,6 +81,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
|
||||
mode: The preprocessor mode.
|
||||
keep_original_columns(List[str], `optional`): The original columns to keep,
|
||||
only available when the input is a `dict`, default None
|
||||
return_text: Whether to return `text` field in inference mode, default: True.
|
||||
"""
|
||||
super().__init__(mode)
|
||||
self.model_dir = model_dir
|
||||
@@ -90,6 +90,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
|
||||
self.label2id = label2id
|
||||
self.label_all_tokens = label_all_tokens
|
||||
self.keep_original_columns = keep_original_columns
|
||||
self.return_text = return_text
|
||||
if self.label2id is None and self.model_dir is not None:
|
||||
self.label2id = parse_label_mapping(self.model_dir)
|
||||
|
||||
@@ -164,7 +165,7 @@ class TokenClassificationPreprocessorBase(Preprocessor):
|
||||
if self.keep_original_columns and isinstance(data, dict):
|
||||
for column in self.keep_original_columns:
|
||||
outputs[column] = data[column]
|
||||
if self.mode == ModeKeys.INFERENCE:
|
||||
if self.mode == ModeKeys.INFERENCE and self.return_text:
|
||||
outputs['text'] = text
|
||||
return outputs
|
||||
|
||||
@@ -208,6 +209,7 @@ class TokenClassificationTransformersPreprocessor(
|
||||
max_length=None,
|
||||
use_fast=None,
|
||||
keep_original_columns=None,
|
||||
return_text=True,
|
||||
**kwargs):
|
||||
"""
|
||||
|
||||
@@ -218,7 +220,8 @@ class TokenClassificationTransformersPreprocessor(
|
||||
**kwargs: Extra args input into the tokenizer's __call__ method.
|
||||
"""
|
||||
super().__init__(model_dir, first_sequence, label, label2id,
|
||||
label_all_tokens, mode, keep_original_columns)
|
||||
label_all_tokens, mode, keep_original_columns,
|
||||
return_text)
|
||||
self.is_lstm_model = 'lstm' in model_dir
|
||||
model_type = None
|
||||
if self.is_lstm_model:
|
||||
|
||||
41
tests/export/test_export_token_classification.py
Normal file
41
tests/export/test_export_token_classification.py
Normal file
@@ -0,0 +1,41 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from collections import OrderedDict
|
||||
|
||||
from modelscope.exporters import Exporter, TorchModelExporter
|
||||
from modelscope.models import Model
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
|
||||
class TestExportTokenClassification(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
self.model_id = 'damo/nlp_raner_named-entity-recognition_chinese-base-news'
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_export_token_classification(self):
|
||||
model = Model.from_pretrained(self.model_id)
|
||||
with self.subTest(format='onnx'):
|
||||
print(
|
||||
Exporter.from_model(model).export_onnx(
|
||||
output_dir=self.tmp_dir))
|
||||
with self.subTest(format='torchscript'):
|
||||
print(
|
||||
Exporter.from_model(model).export_torch_script(
|
||||
output_dir=self.tmp_dir))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user