Files
modelscope/tests/pipelines/test_key_word_spotting.py

250 lines
9.2 KiB
Python
Raw Normal View History

# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tarfile
import unittest
from typing import Any, Dict, List, Union
import requests
from modelscope.pipelines import pipeline
from modelscope.utils.constant import ColorCodes, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import download_and_untar, test_level
# Module-level logger shared by the result-checking helpers of the test class.
logger = get_logger()
# Single-utterance test clips; relative paths, resolved against the
# working directory the tests are launched from.
POS_WAV_FILE = 'data/test/audios/kws_xiaoyunxiaoyun.wav'
BOFANGYINYUE_WAV_FILE = 'data/test/audios/kws_bofangyinyue.wav'

# Positive test-set archive: local file name and download URL.
POS_TESTSETS_FILE = 'pos_testsets.tar.gz'
POS_TESTSETS_URL = ('https://isv-data.oss-cn-hangzhou.aliyuncs.com/'
                    'ics/MaaS/KWS/pos_testsets.tar.gz')

# Negative test-set archive: local file name and download URL.
NEG_TESTSETS_FILE = 'neg_testsets.tar.gz'
NEG_TESTSETS_URL = ('https://isv-data.oss-cn-hangzhou.aliyuncs.com/'
                    'ics/MaaS/KWS/neg_testsets.tar.gz')
class KeyWordSpottingTest(unittest.TestCase):
    """Integration tests for the keyword-spotting (KWS) pipeline.

    Each test runs a pipeline built for ``self.model_id`` and validates the
    returned dict against the entry of ``action_info`` keyed by the test's
    method name.
    """

    # Per-test expectations, keyed by test-method name:
    #   'checking_item'  - key that must be present in the pipeline result.
    #   'checking_value' - (only some tests) value the first element of that
    #                      item is compared against.
    #   'example'        - a reference result, logged when a check fails.
    action_info = {
        'test_run_with_wav': {
            'checking_item': 'kws_list',
            'checking_value': '小云小云',
            'example': {
                'wav_count':
                1,
                'kws_set':
                'wav',
                'kws_list': [{
                    'keyword': '小云小云',
                    'offset': 5.76,
                    'length': 9.132938,
                    'confidence': 0.990368
                }]
            }
        },
        'test_run_with_wav_by_customized_keywords': {
            'checking_item': 'kws_list',
            'checking_value': '播放音乐',
            'example': {
                'wav_count':
                1,
                'kws_set':
                'wav',
                'kws_list': [{
                    'keyword': '播放音乐',
                    'offset': 0.87,
                    'length': 2.158313,
                    'confidence': 0.646237
                }]
            }
        },
        'test_run_with_pos_testsets': {
            'checking_item': 'recall',
            'example': {
                'wav_count': 450,
                'kws_set': 'pos_testsets',
                'wav_time': 3013.75925,
                'keywords': ['小云小云'],
                'recall': 0.953333,
                'detected_count': 429,
                'rejected_count': 21,
                'rejected': ['yyy.wav', 'zzz.wav']
            }
        },
        'test_run_with_neg_testsets': {
            'checking_item': 'fa_rate',
            'example': {
                'wav_count':
                751,
                'kws_set':
                'neg_testsets',
                'wav_time':
                3572.180813,
                'keywords': ['小云小云'],
                'fa_rate':
                0.001332,
                'fa_per_hour':
                1.007788,
                'detected_count':
                1,
                'rejected_count':
                750,
                'detected': [{
                    '6.wav': {
                        'confidence': '0.321170',
                        'keyword': '小云小云'
                    }
                }]
            }
        },
        'test_run_with_roc': {
            'checking_item': 'keywords',
            'checking_value': '小云小云',
            'example': {
                'kws_set':
                'roc',
                'keywords': ['小云小云'],
                '小云小云': [{
                    'threshold': 0.0,
                    'recall': 0.953333,
                    'fa_per_hour': 1.007788
                }, {
                    'threshold': 0.001,
                    'recall': 0.953333,
                    'fa_per_hour': 1.007788
                }, {
                    'threshold': 0.999,
                    'recall': 0.004444,
                    'fa_per_hour': 0.0
                }]
            }
        }
    }

    def setUp(self) -> None:
        self.model_id = 'damo/speech_charctc_kws_phone-xiaoyunxiaoyun'
        # Scratch directory for downloaded/extracted test sets; removed
        # again in tearDown.
        self.workspace = os.path.join(os.getcwd(), '.tmp')
        # exist_ok replaces the racy exists()-then-mkdir() pair.
        os.makedirs(self.workspace, exist_ok=True)

    def tearDown(self) -> None:
        # remove workspace dir (.tmp)
        if os.path.exists(self.workspace):
            shutil.rmtree(self.workspace, ignore_errors=True)

    def run_pipeline(
            self,
            model_id: str,
            wav_path: Union[List[Union[str, None]], str],
            keywords: Union[List[Dict[str, Any]], None] = None
    ) -> Dict[str, Any]:
        """Build a KWS pipeline for ``model_id`` and run it on ``wav_path``.

        Args:
            model_id: ModelScope model id to load.
            wav_path: a single wav file path, or a two-element list
                ``[positive_path, negative_path]`` where either entry may be
                ``None`` (as the test-set tests in this class pass it).
            keywords: optional list of ``{'keyword': <text>}`` dicts, as
                passed by ``test_run_with_wav_by_customized_keywords``.
                NOTE(review): ``None`` presumably means the model's built-in
                keyword(s) are used - confirm against the pipeline.

        Returns:
            The result dict produced by the pipeline.
        """
        kws_pipeline = pipeline(
            task=Tasks.auto_speech_recognition, model=model_id)
        return kws_pipeline(wav_path=wav_path, keywords=keywords)

    def print_error(self, functions: str, result: Dict[str, Any]) -> None:
        """Log a failure of test ``functions`` and raise.

        Both the actual ``result`` and the reference example from
        ``action_info`` are logged so the mismatch can be diagnosed.

        Raises:
            ValueError: always.
        """
        logger.error(ColorCodes.MAGENTA + functions + ': FAILED.'
                     + ColorCodes.END)
        logger.error(ColorCodes.MAGENTA + functions + ' actual result: '
                     + ColorCodes.YELLOW + str(result) + ColorCodes.END)
        logger.error(ColorCodes.MAGENTA + functions
                     + ' correct result example: ' + ColorCodes.YELLOW
                     + str(self.action_info[functions]['example'])
                     + ColorCodes.END)
        raise ValueError('kws result is mismatched')

    def check_and_print_result(self, functions: str,
                               result: Dict[str, Any]) -> None:
        """Validate ``result`` for test ``functions`` and log it.

        The key named by ``action_info[functions]['checking_item']`` must be
        present in ``result``. For the ROC and single-wav tests, the first
        element of that item must also match ``checking_value``; an empty
        item counts as a mismatch. On success the result is logged (ROC
        results one threshold per line). On failure ``print_error`` logs
        diagnostics and raises ValueError.
        """
        info = self.action_info[functions]
        checking_key = info['checking_item']
        if checking_key not in result:
            self.print_error(functions, result)  # raises

        checking_item = result[checking_key]
        if functions == 'test_run_with_roc':
            # ROC results list the evaluated keywords as plain strings.
            first = checking_item[0] if checking_item else None
            if first != info['checking_value']:
                self.print_error(functions, result)
        elif functions in ('test_run_with_wav',
                           'test_run_with_wav_by_customized_keywords'):
            # Single-wav results list detections as dicts with a 'keyword'.
            first = checking_item[0] if checking_item else {}
            if first.get('keyword') != info['checking_value']:
                self.print_error(functions, result)

        logger.info(ColorCodes.MAGENTA + functions + ': SUCCESS.'
                    + ColorCodes.END)
        if functions == 'test_run_with_roc':
            # Log one line per threshold of the first evaluated keyword.
            find_keyword = result['keywords'][0]
            for item in result[find_keyword]:
                threshold: float = item['threshold']
                recall: float = item['recall']
                fa_per_hour: float = item['fa_per_hour']
                logger.info(ColorCodes.YELLOW + '  threshold:'
                            + str(threshold) + ' recall:' + str(recall)
                            + ' fa_per_hour:' + str(fa_per_hour)
                            + ColorCodes.END)
        else:
            logger.info(ColorCodes.YELLOW + str(result) + ColorCodes.END)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav(self):
        kws_result = self.run_pipeline(
            model_id=self.model_id, wav_path=POS_WAV_FILE)
        self.check_and_print_result('test_run_with_wav', kws_result)

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_run_with_wav_by_customized_keywords(self):
        keywords = [{'keyword': '播放音乐'}]
        kws_result = self.run_pipeline(
            model_id=self.model_id,
            wav_path=BOFANGYINYUE_WAV_FILE,
            keywords=keywords)
        self.check_and_print_result('test_run_with_wav_by_customized_keywords',
                                    kws_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_pos_testsets(self):
        wav_file_path = download_and_untar(
            os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
            self.workspace)
        # Positive set only: the negative slot is left empty.
        wav_path = [wav_file_path, None]
        kws_result = self.run_pipeline(
            model_id=self.model_id, wav_path=wav_path)
        self.check_and_print_result('test_run_with_pos_testsets', kws_result)

    @unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
    def test_run_with_neg_testsets(self):
        wav_file_path = download_and_untar(
            os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
            self.workspace)
        # Negative set only: the positive slot is left empty.
        wav_path = [None, wav_file_path]
        kws_result = self.run_pipeline(
            model_id=self.model_id, wav_path=wav_path)
        self.check_and_print_result('test_run_with_neg_testsets', kws_result)

    @unittest.skipUnless(test_level() >= 2, 'skip test in current test level')
    def test_run_with_roc(self):
        pos_file_path = download_and_untar(
            os.path.join(self.workspace, POS_TESTSETS_FILE), POS_TESTSETS_URL,
            self.workspace)
        neg_file_path = download_and_untar(
            os.path.join(self.workspace, NEG_TESTSETS_FILE), NEG_TESTSETS_URL,
            self.workspace)
        # Both sets together produce the ROC result.
        wav_path = [pos_file_path, neg_file_path]
        kws_result = self.run_pipeline(
            model_id=self.model_id, wav_path=wav_path)
        self.check_and_print_result('test_run_with_roc', kws_result)
if __name__ == '__main__':
    # Allow running this module directly instead of through a test runner.
    unittest.main()