2022-06-27 11:59:44 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
import os
|
|
|
|
|
import shutil
|
|
|
|
|
import unittest
|
2022-07-25 22:37:15 +08:00
|
|
|
from typing import Any, Dict, List, Union
|
2022-06-27 11:59:44 +08:00
|
|
|
|
2022-08-04 11:49:26 +08:00
|
|
|
import numpy as np
|
|
|
|
|
import soundfile
|
2022-06-27 11:59:44 +08:00
|
|
|
|
2022-08-04 11:49:26 +08:00
|
|
|
from modelscope.outputs import OutputKeys
|
2022-06-27 11:59:44 +08:00
|
|
|
from modelscope.pipelines import pipeline
|
2022-07-25 22:37:15 +08:00
|
|
|
from modelscope.utils.constant import ColorCodes, Tasks
|
2022-09-08 14:08:51 +08:00
|
|
|
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
2022-07-25 22:37:15 +08:00
|
|
|
from modelscope.utils.logger import get_logger
|
|
|
|
|
from modelscope.utils.test_utils import download_and_untar, test_level
|
2022-06-27 11:59:44 +08:00
|
|
|
|
2022-07-25 22:37:15 +08:00
|
|
|
# Module-level logger shared by the result-checking helpers of the test class.
logger = get_logger()
|
2022-06-27 11:59:44 +08:00
|
|
|
|
2022-06-30 22:40:38 +08:00
|
|
|
# Local audio fixtures (relative paths, resolved against the working directory).
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'

# Remote positive / negative test-set archives. Each URL is derived from its
# archive file name so the local file name and the download URL cannot drift.
_KWS_TESTSETS_BASE_URL = 'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/KWS/'

POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = _KWS_TESTSETS_BASE_URL + POS_TESTSETS_FILE

NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = _KWS_TESTSETS_BASE_URL + NEG_TESTSETS_FILE
|
|
|
|
|
|
|
|
|
|
|
2022-09-08 14:08:51 +08:00
|
|
|
class KeyWordSpottingTest(unittest.TestCase, DemoCompatibilityCheck):
    """End-to-end tests for the keyword-spotting pipeline.

    Each test runs the pipeline for ``self.model_id`` and validates the
    output with :meth:`check_result` against the matching ``action_info``
    entry.
    """

    # Per-test expectations, keyed by test-method name:
    #   'checking_item'  - key/index path walked into the pipeline output;
    #   'checking_value' - (optional) value expected at the end of that path;
    #   'example'        - reference output logged when a check fails.
    action_info = {
        'test_run_with_wav': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '小云小云',
            'example': {
                'wav_count': 1,
                'kws_type': 'wav',
                'kws_list': [{
                    'keyword': '小云小云',
                    'offset': 5.76,
                    'length': 9.132938,
                    'confidence': 0.990368
                }]
            }
        },
        'test_run_with_pcm': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '小云小云',
            'example': {
                'wav_count': 1,
                'kws_type': 'pcm',
                'kws_list': [{
                    'keyword': '小云小云',
                    'offset': 5.76,
                    'length': 9.132938,
                    'confidence': 0.990368
                }]
            }
        },
        'test_run_with_wav_by_customized_keywords': {
            'checking_item': [OutputKeys.KWS_LIST, 0, 'keyword'],
            'checking_value': '播放音乐',
            'example': {
                'wav_count': 1,
                'kws_type': 'wav',
                'kws_list': [{
                    'keyword': '播放音乐',
                    'offset': 0.87,
                    'length': 2.158313,
                    'confidence': 0.646237
                }]
            }
        },
        'test_run_with_pos_testsets': {
            'checking_item': ['recall'],
            'example': {
                'wav_count': 450,
                'kws_type': 'pos_testsets',
                'wav_time': 3013.75925,
                'keywords': ['小云小云'],
                'recall': 0.953333,
                'detected_count': 429,
                'rejected_count': 21,
                'rejected': ['yyy.wav', 'zzz.wav']
            }
        },
        'test_run_with_neg_testsets': {
            'checking_item': ['fa_rate'],
            'example': {
                'wav_count': 751,
                'kws_type': 'neg_testsets',
                'wav_time': 3572.180813,
                'keywords': ['小云小云'],
                'fa_rate': 0.001332,
                'fa_per_hour': 1.007788,
                'detected_count': 1,
                'rejected_count': 750,
                'detected': [{
                    '6.wav': {
                        'confidence': '0.321170',
                        'keyword': '小云小云'
                    }
                }]
            }
        },
        'test_run_with_roc': {
            'checking_item': ['keywords', 0],
            'checking_value': '小云小云',
            'example': {
                'kws_type': 'roc',
                'keywords': ['小云小云'],
                '小云小云': [{
                    'threshold': 0.0,
                    'recall': 0.953333,
                    'fa_per_hour': 1.007788
                }, {
                    'threshold': 0.001,
                    'recall': 0.953333,
                    'fa_per_hour': 1.007788
                }, {
                    'threshold': 0.999,
                    'recall': 0.004444,
                    'fa_per_hour': 0.0
                }]
            }
        }
    }

    def setUp(self) -> None:
        self.model_id = 'damo/speech_charctc_kws_phone-xiaoyun'
        # Scratch directory for downloaded test sets; removed in tearDown.
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        # exist_ok avoids the check-then-create race of exists() + mkdir().
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        # remove workspace dir (.tmp)
        shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(self,
                     model_id: str,
                     audio_in: Union[List[str], str, bytes],
                     keywords: Union[List[Any], None] = None) -> Dict[str, Any]:
        """Build a keyword-spotting pipeline for ``model_id`` and run it once.

        Args:
            model_id: ModelScope model id passed to ``pipeline``.
            audio_in: A wav path, raw PCM bytes, or (as the test-set tests
                use it) a ``[pos_path, neg_path]`` list where either entry
                may be ``None``.
            keywords: Optional customized keywords forwarded unchanged to the
                pipeline; the tests here pass a list of ``{'keyword': ...}``
                dicts.

        Returns:
            Whatever the pipeline returns for this input.
        """
        kws_pipeline = pipeline(task=Tasks.keyword_spotting, model=model_id)
        return kws_pipeline(audio_in=audio_in, keywords=keywords)

    def log_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failed check for test ``functions`` and abort it.

        Args:
            functions: Test-method name; also the ``action_info`` key whose
                ``'example'`` is logged as the reference output.
            result: The pipeline output that failed the check; it is logged
                alongside the reference example.

        Raises:
            ValueError: always.
        """
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        logger.error(ColorCodes.MAGENTA + functions
                     + ' correct result example: ' + ColorCodes.YELLOW
                     + str(self.action_info[functions]['example'])
                     + ColorCodes.END)
        # Show what the pipeline actually produced so the mismatch can be
        # diagnosed from the log alone.
        logger.error(ColorCodes.MAGENTA + functions + ' actual result: '
                     + ColorCodes.YELLOW + str(result) + ColorCodes.END)

        raise ValueError('kws result is mismatched')

    def check_result(self, functions: str, result: Dict[str, Any]) -> None:
        """Validate ``result`` against ``action_info[functions]`` and log it.

        Walks the ``'checking_item'`` key/index path into ``result``. Every
        step must exist and be neither ``None`` nor the string ``'None'``.
        If the entry defines ``'checking_value'``, the value at the end of
        the path must equal it. Any failure goes through :meth:`log_error`,
        which raises ``ValueError``.

        On success, ``test_run_with_roc`` logs one line per ROC point
        (threshold / recall / fa_per_hour); other tests log the whole result.
        """
        expectation = self.action_info[functions]

        result_item = result
        for check_item in expectation['checking_item']:
            try:
                result_item = result_item[check_item]
            except (KeyError, IndexError, TypeError):
                # Missing key, empty list, or non-subscriptable value: the
                # output lacks the expected structure. Report it through the
                # same diagnostic path as a value mismatch (log_error raises).
                self.log_error(functions, result)
            if result_item is None or result_item == 'None':
                self.log_error(functions, result)

        if 'checking_value' in expectation:
            if result_item != expectation['checking_value']:
                self.log_error(functions, result)

        logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                    + ColorCodes.END)
        if functions == 'test_run_with_roc':
            find_keyword = result['keywords'][0]
            for item in result[find_keyword]:
                threshold: float = item['threshold']
                recall: float = item['recall']
                fa_per_hour: float = item['fa_per_hour']
                logger.info(ColorCodes.YELLOW + ' threshold:' + str(threshold)
                            + ' recall:' + str(recall) + ' fa_per_hour:'
                            + str(fa_per_hour) + ColorCodes.END)
        else:
            logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)

    def wav2bytes(self, wav_file) -> bytes:
        """Read ``wav_file`` and return its samples as raw 16-bit PCM bytes.

        Samples from ``soundfile.read`` are treated as floats in [-1, 1]
        (soundfile's default float64 output). They are scaled to the int16
        range, clipped, and serialized with ``ndarray.tobytes`` in native
        byte order.
        """
        audio, _ = soundfile.read(wav_file)

        # float -> int16
        audio = np.asarray(audio)
        dtype = np.dtype('int16')
        info = np.iinfo(dtype)
        abs_max = 2**(info.bits - 1)
        offset = info.min + abs_max
        audio = (audio * abs_max + offset).clip(info.min,
                                                info.max).astype(dtype)

        # int16(PCM_16) -> bytes
        return audio.tobytes()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=POS_WAV_FILE)
        self.check_result('test_run_with_wav', kws_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_pcm(self):
        audio = self.wav2bytes(os.path.join(os.getcwd(), POS_WAV_FILE))

        kws_result = self.run_pipeline(model_id=self.model_id, audio_in=audio)
        self.check_result('test_run_with_pcm', kws_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav_by_customized_keywords(self):
        keywords = [{'keyword': '播放音乐'}]

        kws_result = self.run_pipeline(
            model_id=self.model_id,
            audio_in=BOFANGYINYUE_WAV_FILE,
            keywords=keywords)
        self.check_result('test_run_with_wav_by_customized_keywords',
                          kws_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_pos_testsets(self):
        wav_file_path = download_and_untar(
            os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
            self.workspace)
        # [positive set, negative set]; only the positive set is evaluated.
        audio_list = [wav_file_path, None]

        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=audio_list)
        self.check_result('test_run_with_pos_testsets', kws_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_neg_testsets(self):
        wav_file_path = download_and_untar(
            os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
            self.workspace)
        # [positive set, negative set]; only the negative set is evaluated.
        audio_list = [None, wav_file_path]

        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=audio_list)
        self.check_result('test_run_with_neg_testsets', kws_result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_roc(self):
        pos_file_path = download_and_untar(
            os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
            self.workspace)
        neg_file_path = download_and_untar(
            os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
            self.workspace)
        # Supplying both sets produces the ROC-style output checked below.
        audio_list = [pos_file_path, neg_file_path]

        kws_result = self.run_pipeline(
            model_id=self.model_id, audio_in=audio_list)
        self.check_result('test_run_with_roc', kws_result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_demo_compatibility(self):
        self.compatibility_check()
|
|
|
|
|
|
2022-06-27 11:59:44 +08:00
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Allow running this test module directly as a script.
    unittest.main()
|