mirror of
https://github.com/modelscope/modelscope.git
synced 2025-12-18 01:07:44 +01:00
# Speaker Diarization Dialogue Detection CR 本模型是Speaker Diarization(`audio/speaker diarization`,语音/说话人日志)任务下的一个子模块。 本次提交的是基于文本进行判断的模型,其IO和中间过程和 `nlp/text-classification` 很像,且本地模型的初始模型也是基于huggingface训练的,因此此提交中复用了部分 `nlp/text-classification` 模型的代码。为了方便后续维护以及与nlp方面代码解耦,在model、pipeline以及preprocessor中 **单独** 创建了相应模块并重新register。 Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/13269649 * start to add speaker_diarization_dialogue_detection files; Need to change constant and test * add sv/speaker_diarization_dialogue_detection to branch sv/semantic-dialogue-detection * update test case * add comments for speaker diarization dialogue detection pipelines * add outputs type and inputs type for speaker_diarization_dialogue_detection
59 lines
2.3 KiB
Python
59 lines
2.3 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
from typing import Any, Dict, List, Tuple

import numpy as np

from modelscope.models import Model
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
|
|
|
|
# Module-level logger used by the test cases below to report results.
logger = get_logger()
|
|
|
|
|
|
class SpeakerDiarizationDialogueDetectionTest(unittest.TestCase):
    """Smoke test for the speaker diarization dialogue detection pipeline.

    Each sample sentence is fed through the pipeline and the label with the
    highest score is logged.
    """

    # Sample inputs. NOTE(review): the 'label' values are not read anywhere
    # in this class; they look like the expected outputs, but nothing is
    # asserted against them -- confirm whether that is intentional.
    test_datasets = [{
        'sentence':
        '还有什么双面胶啥的不都是她写的吗?然后这部剧她为了写这部剧,她还说亲自去武汉。然后体验了一下穿防护服的感受。据说那个防护服要穿好几层那种。',
        'label': False
    }, {
        'sentence':
        '你们那儿小区不能用健康宝吗?不能。北一区都可以了。外面进去的就像是那个快递员儿呀,或者是外卖小哥呀,要健康宝。然后本小区的要出入证,都问有出入证吗?',
        'label': True
    }, {
        'sentence': '侦探小说从19世纪中期开始发展。美国作家埃德加‧爱伦‧坡被认为是西方侦探小说的鼻祖。',
        'label': False
    }]

    # Identifier of the model passed to ``pipeline`` by the test below.
    dialogue_detection_model_id = 'damo/speech_bert_dialogue-detetction_speaker-diarization_chinese'

    def setUp(self) -> None:
        """Select the task under test before each test method runs."""
        self.task = Tasks.speaker_diarization_dialogue_detection

    def run_pipeline(self,
                     model_id: str,
                     model_revision=None) -> List[Tuple[str, Dict[str, Any]]]:
        """Run the dialogue detection pipeline on every sample sentence.

        Args:
            model_id: Identifier of the model handed to ``pipeline``.
            model_revision: Optional model revision handed to ``pipeline``.

        Returns:
            A list with one ``(sentence, pipeline_output)`` pair per entry
            of ``test_datasets``, in the same order.
        """
        dialogue_detection = pipeline(
            task=self.task, model=model_id, model_revision=model_revision)
        return [(item['sentence'], dialogue_detection(item['sentence']))
                for item in self.test_datasets]

    @unittest.skipUnless(test_level() >= 0, 'Skip test in current test level')
    def test_dialogue_detection_model(self):
        """Run the pipeline and log the top-scoring label per sentence."""
        logger.info('Run speaker diarization dialogue detection from modelhub')
        output_list = self.run_pipeline(
            model_id=self.dialogue_detection_model_id, model_revision='v0.5.3')
        for sentence, result in output_list:
            # Pick the label whose score is the largest.
            label = result['labels'][np.argmax(result['scores'])]
            logger.info(f'Sentence = {sentence}, label = {label}')
|
|
|
|
|
|
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
|