From 6d68f0ea64fbbc27d3043d47db9bfbc71bc0028d Mon Sep 17 00:00:00 2001 From: "bin.xue" Date: Fri, 28 Apr 2023 10:33:32 +0800 Subject: [PATCH] [to #42322933] add ONNX exporter for ans dfsmn --- data/test | 2 +- modelscope/exporters/audio/__init__.py | 22 +++++ .../exporters/audio/ans_dfsmn_exporter.py | 62 ++++++++++++++ .../test_export_speech_signal_process.py | 83 +++++++++++++++++++ 4 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 modelscope/exporters/audio/__init__.py create mode 100644 modelscope/exporters/audio/ans_dfsmn_exporter.py create mode 100644 tests/export/test_export_speech_signal_process.py diff --git a/data/test b/data/test index c117008c..0a61e00d 160000 --- a/data/test +++ b/data/test @@ -1 +1 @@ -Subproject commit c117008caa9dc447c208e9ed6bc11310512d4a3a +Subproject commit 0a61e00de4a4b529099b357cbb0b2af83ac2f31e diff --git a/modelscope/exporters/audio/__init__.py b/modelscope/exporters/audio/__init__.py new file mode 100644 index 00000000..883151cd --- /dev/null +++ b/modelscope/exporters/audio/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +from typing import TYPE_CHECKING + +from modelscope.utils.import_utils import LazyImportModule + +if TYPE_CHECKING: + from .ans_dfsmn_exporter import ANSDFSMNExporter +else: + _import_structure = { + 'ans_dfsmn_exporter': ['ANSDFSMNExporter'], + } + + import sys + + sys.modules[__name__] = LazyImportModule( + __name__, + globals()['__file__'], + _import_structure, + module_spec=__spec__, + extra_objects={}, + ) diff --git a/modelscope/exporters/audio/ans_dfsmn_exporter.py b/modelscope/exporters/audio/ans_dfsmn_exporter.py new file mode 100644 index 00000000..976f983f --- /dev/null +++ b/modelscope/exporters/audio/ans_dfsmn_exporter.py @@ -0,0 +1,62 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import os + +import torch + +from modelscope.exporters.builder import EXPORTERS +from modelscope.exporters.torch_model_exporter import TorchModelExporter +from modelscope.metainfo import Models +from modelscope.utils.constant import ModelFile, Tasks + +INPUT_NAME = 'input' +OUTPUT_NAME = 'output' + + +@EXPORTERS.register_module( + Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans) +class ANSDFSMNExporter(TorchModelExporter): + + def export_onnx(self, output_dir: str, opset=9, **kwargs): + """Export the model as onnx format files. + + Args: + output_dir: The output dir. + opset: The version of the ONNX operator set to use. + kwargs: + device: The device used to forward. + Returns: + A dict containing the model key - model file path pairs. + """ + model = self.model if 'model' not in kwargs else kwargs.pop('model') + device_name = 'cpu' if 'device' not in kwargs else kwargs.pop('device') + model_bin_file = os.path.join(model.model_dir, + ModelFile.TORCH_MODEL_BIN_FILE) + if os.path.exists(model_bin_file): + checkpoint = torch.load(model_bin_file, map_location='cpu') + model.load_state_dict(checkpoint) + onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE) + + with torch.no_grad(): + model.eval() + device = torch.device(device_name) + model.to(device) + model_script = torch.jit.script(model) + fbank_input = torch.zeros((1, 3, 120), dtype=torch.float32) + torch.onnx.export( + model_script, + fbank_input, + onnx_file, + opset_version=opset, + input_names=[INPUT_NAME], + output_names=[OUTPUT_NAME], + dynamic_axes={ + INPUT_NAME: { + 0: 'batch_size', + 1: 'number_of_frame' + }, + OUTPUT_NAME: { + 0: 'batch_size', + 1: 'number_of_frame' + } + }) + return {'model': onnx_file} diff --git a/tests/export/test_export_speech_signal_process.py b/tests/export/test_export_speech_signal_process.py new file mode 100644 index 00000000..d3f6fe14 --- /dev/null +++ b/tests/export/test_export_speech_signal_process.py @@ -0,0 +1,83 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. + +import os +import pickle +import shutil +import tempfile +import unittest + +import torch + +from modelscope.exporters import Exporter +from modelscope.models import Model +from modelscope.utils.logger import get_logger +from modelscope.utils.regress_test_utils import (compare_arguments_nested, + numpify_tensor_nested) +from modelscope.utils.test_utils import test_level + +INPUT_PKL = 'data/test/audios/input.pkl' + +INPUT_NAME = 'input' +OUTPUT_NAME = 'output' + +logger = get_logger() + + +class ExportSpeechSignalProcessTest(unittest.TestCase): + + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + self.tmp_dir = tempfile.TemporaryDirectory().name + if not os.path.exists(self.tmp_dir): + os.makedirs(self.tmp_dir) + + def tearDown(self): + shutil.rmtree(self.tmp_dir) + super().tearDown() + + @unittest.skipUnless(test_level() >= 0, 'skip test in current test level') + def test_export_ans_dfsmn(self): + model_id = 'damo/speech_dfsmn_ans_psm_48k_causal' + model = Model.from_pretrained(model_id) + onnx_info = Exporter.from_model(model).export_onnx( + output_dir=self.tmp_dir) + + with open(os.path.join(os.getcwd(), INPUT_PKL), 'rb') as f: + fbank_input = pickle.load(f).cpu() + self.assertTrue( + self._validate_onnx_model(fbank_input, model, onnx_info['model']), + 'export onnx failed because of validation error.') + + @staticmethod + def _validate_onnx_model(dummy_inputs, model, output): + try: + import onnx + import onnxruntime as ort + except ImportError: + logger.warning( + 'Cannot validate the exported onnx file, because ' + 'the installation of onnx or onnxruntime cannot be found') + return + onnx_model = onnx.load(output) + onnx.checker.check_model(onnx_model) + ort_session = ort.InferenceSession(output) + with torch.no_grad(): + model.eval() + outputs_origin = model.forward(dummy_inputs) + outputs_origin = numpify_tensor_nested(outputs_origin) + + input_feed = {INPUT_NAME: dummy_inputs.numpy()} + outputs = ort_session.run( + None, + input_feed, + ) + outputs = numpify_tensor_nested(outputs[0]) + + print(outputs) + print(outputs_origin) + return compare_arguments_nested('Onnx model output match failed', + outputs, outputs_origin) + + +if __name__ == '__main__': + unittest.main()