mirror of
https://github.com/modelscope/modelscope.git
synced 2026-02-24 12:10:09 +01:00
[to #42322933] add ONNX exporter for ans dfsmn
This commit is contained in:
Submodule data/test updated: c117008caa...0a61e00de4
22
modelscope/exporters/audio/__init__.py
Normal file
22
modelscope/exporters/audio/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .ans_dfsmn_exporter import ANSDFSMNExporter
|
||||
else:
|
||||
_import_structure = {
|
||||
'ans_dfsmn_exporter': ['ANSDFSMNExporter'],
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
62
modelscope/exporters/audio/ans_dfsmn_exporter.py
Normal file
62
modelscope/exporters/audio/ans_dfsmn_exporter.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.exporters.builder import EXPORTERS
|
||||
from modelscope.exporters.torch_model_exporter import TorchModelExporter
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
|
||||
INPUT_NAME = 'input'
|
||||
OUTPUT_NAME = 'output'
|
||||
|
||||
|
||||
@EXPORTERS.register_module(
|
||||
Tasks.acoustic_noise_suppression, module_name=Models.speech_dfsmn_ans)
|
||||
class ANSDFSMNExporter(TorchModelExporter):
|
||||
|
||||
def export_onnx(self, output_dir: str, opset=9, **kwargs):
|
||||
"""Export the model as onnx format files.
|
||||
|
||||
Args:
|
||||
output_dir: The output dir.
|
||||
opset: The version of the ONNX operator set to use.
|
||||
kwargs:
|
||||
device: The device used to forward.
|
||||
Returns:
|
||||
A dict containing the model key - model file path pairs.
|
||||
"""
|
||||
model = self.model if 'model' not in kwargs else kwargs.pop('model')
|
||||
device_name = 'cpu' if 'device' not in kwargs else kwargs.pop('device')
|
||||
model_bin_file = os.path.join(model.model_dir,
|
||||
ModelFile.TORCH_MODEL_BIN_FILE)
|
||||
if os.path.exists(model_bin_file):
|
||||
checkpoint = torch.load(model_bin_file, map_location='cpu')
|
||||
model.load_state_dict(checkpoint)
|
||||
onnx_file = os.path.join(output_dir, ModelFile.ONNX_MODEL_FILE)
|
||||
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
device = torch.device(device_name)
|
||||
model.to(device)
|
||||
model_script = torch.jit.script(model)
|
||||
fbank_input = torch.zeros((1, 3, 120), dtype=torch.float32)
|
||||
torch.onnx.export(
|
||||
model_script,
|
||||
fbank_input,
|
||||
onnx_file,
|
||||
opset_version=opset,
|
||||
input_names=[INPUT_NAME],
|
||||
output_names=[OUTPUT_NAME],
|
||||
dynamic_axes={
|
||||
INPUT_NAME: {
|
||||
0: 'batch_size',
|
||||
1: 'number_of_frame'
|
||||
},
|
||||
OUTPUT_NAME: {
|
||||
0: 'batch_size',
|
||||
1: 'number_of_frame'
|
||||
}
|
||||
})
|
||||
return {'model': onnx_file}
|
||||
83
tests/export/test_export_speech_signal_process.py
Normal file
83
tests/export/test_export_speech_signal_process.py
Normal file
@@ -0,0 +1,83 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import pickle
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from modelscope.exporters import Exporter
|
||||
from modelscope.models import Model
|
||||
from modelscope.utils.logger import get_logger
|
||||
from modelscope.utils.regress_test_utils import (compare_arguments_nested,
|
||||
numpify_tensor_nested)
|
||||
from modelscope.utils.test_utils import test_level
|
||||
|
||||
INPUT_PKL = 'data/test/audios/input.pkl'
|
||||
|
||||
INPUT_NAME = 'input'
|
||||
OUTPUT_NAME = 'output'
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class ExportSpeechSignalProcessTest(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
print(('Testing %s.%s' % (type(self).__name__, self._testMethodName)))
|
||||
self.tmp_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.tmp_dir):
|
||||
os.makedirs(self.tmp_dir)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tmp_dir)
|
||||
super().tearDown()
|
||||
|
||||
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
|
||||
def test_export_ans_dfsmn(self):
|
||||
model_id = 'damo/speech_dfsmn_ans_psm_48k_causal'
|
||||
model = Model.from_pretrained(model_id)
|
||||
onnx_info = Exporter.from_model(model).export_onnx(
|
||||
output_dir=self.tmp_dir)
|
||||
|
||||
with open(os.path.join(os.getcwd(), INPUT_PKL), 'rb') as f:
|
||||
fbank_input = pickle.load(f).cpu()
|
||||
self.assertTrue(
|
||||
self._validate_onnx_model(fbank_input, model, onnx_info['model']),
|
||||
'export onnx failed because of validation error.')
|
||||
|
||||
@staticmethod
|
||||
def _validate_onnx_model(dummy_inputs, model, output):
|
||||
try:
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
'Cannot validate the exported onnx file, because '
|
||||
'the installation of onnx or onnxruntime cannot be found')
|
||||
return
|
||||
onnx_model = onnx.load(output)
|
||||
onnx.checker.check_model(onnx_model)
|
||||
ort_session = ort.InferenceSession(output)
|
||||
with torch.no_grad():
|
||||
model.eval()
|
||||
outputs_origin = model.forward(dummy_inputs)
|
||||
outputs_origin = numpify_tensor_nested(outputs_origin)
|
||||
|
||||
input_feed = {INPUT_NAME: dummy_inputs.numpy()}
|
||||
outputs = ort_session.run(
|
||||
None,
|
||||
input_feed,
|
||||
)
|
||||
outputs = numpify_tensor_nested(outputs[0])
|
||||
|
||||
print(outputs)
|
||||
print(outputs_origin)
|
||||
return compare_arguments_nested('Onnx model output match failed',
|
||||
outputs, outputs_origin)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user