[to #42322933] Add ONNX exporter for ANS DFSMN

This commit is contained in:
bin.xue
2023-04-28 10:33:32 +08:00
parent 93f73a26e7
commit 6d68f0ea64
4 changed files with 168 additions and 1 deletions

View File

@@ -0,0 +1,22 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""Lazy-loading package init for the ANS DFSMN ONNX exporter.

At type-check time ``ANSDFSMNExporter`` is imported directly so static
analyzers can resolve it; at runtime the module object is replaced by a
``LazyImportModule`` proxy so the heavy import only happens on first
attribute access.
"""
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
    from .ans_dfsmn_exporter import ANSDFSMNExporter
else:
    # Maps submodule name -> list of public names it provides; the proxy
    # resolves these lazily on attribute access.
    _import_structure = {
        'ans_dfsmn_exporter': ['ANSDFSMNExporter'],
    }
    import sys
    # Swap this module in sys.modules for the lazy proxy.
    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
        _import_structure,
        module_spec=__spec__,
        extra_objects={},
    )

View File

@@ -0,0 +1,62 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import torch
from modelscope.exporters.builder import EXPORTERS
from modelscope.exporters.torch_model_exporter import TorchModelExporter
from modelscope.metainfo import Models
from modelscope.utils.constant import ModelFile, Tasks
# Names assigned to the ONNX graph's input/output tensors; these same names
# key the dynamic_axes mapping passed to torch.onnx.export below.
INPUT_NAME = 'input'
OUTPUT_NAME = 'output'
@EXPORTERS.register_module(
    Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans)
class ANSDFSMNExporter(TorchModelExporter):
    """Exports the DFSMN acoustic-noise-suppression model to ONNX format."""

    def export_onnx(self, output_dir: str, opset: int = 9, **kwargs):
        """Export the model as onnx format files.

        Args:
            output_dir: The output dir. Created if it does not exist.
            opset: The version of the ONNX operator set to use.
            kwargs:
                model: Optional model instance overriding ``self.model``.
                device: The device used to forward, e.g. 'cpu' (default).

        Returns:
            A dict containing the model key - model file path pairs,
            i.e. ``{'model': <path to exported .onnx file>}``.
        """
        model = kwargs.pop('model', self.model)
        device_name = kwargs.pop('device', 'cpu')
        model_bin_file = os.path.join(model.model_dir,
                                      ModelFile.TORCH_MODEL_BIN_FILE)
        # Reload the released checkpoint (when present) so the exported
        # graph carries the published weights, not in-memory state.
        if os.path.exists(model_bin_file):
            checkpoint = torch.load(model_bin_file, map_location='cpu')
            model.load_state_dict(checkpoint)
        # Ensure the destination exists; torch.onnx.export does not create
        # intermediate directories and would fail on a missing output_dir.
        os.makedirs(output_dir, exist_ok=True)
        onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
        with torch.no_grad():
            model.eval()
            device = torch.device(device_name)
            model.to(device)
            # Script (rather than trace) so any data-dependent control flow
            # in the model's forward survives into the exported graph.
            model_script = torch.jit.script(model)
            # Dummy input; axes 0 and 1 are declared dynamic below
            # (batch_size, number_of_frame), the last axis is fixed at 120.
            fbank_input = torch.zeros((1, 3, 120), dtype=torch.float32)
            torch.onnx.export(
                model_script,
                fbank_input,
                onnx_file,
                opset_version=opset,
                input_names=[INPUT_NAME],
                output_names=[OUTPUT_NAME],
                dynamic_axes={
                    INPUT_NAME: {
                        0: 'batch_size',
                        1: 'number_of_frame'
                    },
                    OUTPUT_NAME: {
                        0: 'batch_size',
                        1: 'number_of_frame'
                    }
                })
        return {'model': onnx_file}

View File

@@ -0,0 +1,83 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import pickle
import shutil
import tempfile
import unittest
import torch
from modelscope.exporters import Exporter
from modelscope.models import Model
from modelscope.utils.logger import get_logger
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
numpify_tensor_nested)
from modelscope.utils.test_utils import test_level
# Pickled fbank input tensor used to feed both the torch model and the
# exported ONNX model; path is relative to the repository working dir.
INPUT_PKL = 'data/test/audios/input.pkl'
# Tensor names used when building the input feed / reading ONNX outputs;
# must match the names used by the exporter.
INPUT_NAME = 'input'
OUTPUT_NAME = 'output'
logger = get_logger()
class ExportSpeechSignalProcessTest(unittest.TestCase):
    """Exports the DFSMN ANS model to ONNX and checks that the onnxruntime
    output matches the PyTorch forward output on a reference input."""

    def setUp(self):
        print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
        # Use mkdtemp, not TemporaryDirectory().name: discarding the
        # TemporaryDirectory object lets its finalizer delete the directory
        # while the test is still using it (and raises a ResourceWarning).
        self.tmp_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.tmp_dir)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_export_ans_dfsmn(self):
        model_id = 'damo/speech_dfsmn_ans_psm_48k_causal'
        model = Model.from_pretrained(model_id)
        onnx_info = Exporter.from_model(model).export_onnx(
            output_dir=self.tmp_dir)
        # Reference fbank features; moved to CPU to match the ONNX session.
        with open(os.path.join(os.getcwd(), INPUT_PKL), 'rb') as f:
            fbank_input = pickle.load(f).cpu()
        self.assertTrue(
            self._validate_onnx_model(fbank_input, model, onnx_info['model']),
            'export onnx failed because of validation error.')

    @staticmethod
    def _validate_onnx_model(dummy_inputs, model, output):
        """Run the ONNX model and the torch model on the same input and
        compare the outputs.

        Returns True when the outputs match or when onnx/onnxruntime is
        unavailable (validation skipped), False on mismatch.
        """
        try:
            import onnx
            import onnxruntime as ort
        except ImportError:
            logger.warning(
                'Cannot validate the exported onnx file, because '
                'the installation of onnx or onnxruntime cannot be found')
            # Returning None here (as bare `return` did) would fail the
            # caller's assertTrue even though validation was merely skipped.
            return True
        onnx_model = onnx.load(output)
        onnx.checker.check_model(onnx_model)
        ort_session = ort.InferenceSession(output)
        with torch.no_grad():
            model.eval()
            outputs_origin = model.forward(dummy_inputs)
            outputs_origin = numpify_tensor_nested(outputs_origin)
        input_feed = {INPUT_NAME: dummy_inputs.numpy()}
        outputs = ort_session.run(
            None,
            input_feed,
        )
        # ort returns a list of outputs; this model exposes a single one.
        outputs = numpify_tensor_nested(outputs[0])
        logger.info('onnx output: %s', outputs)
        logger.info('torch output: %s', outputs_origin)
        return compare_arguments_nested('Onnx model output match failed',
                                        outputs, outputs_origin)
# Allow running this test module directly with `python <file>`.
if __name__ == '__main__':
    unittest.main()