2022-09-20 17:49:31 +08:00
|
|
|
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
|
|
|
|
2022-06-20 17:23:11 +08:00
|
|
|
import unittest
|
|
|
|
|
|
|
|
|
|
# NOTICE: Tensorflow 1.15 does not seem to be fully compatible with pytorch.
|
|
|
|
|
# A segmentation fault may be raised by the pytorch cpp library
|
|
|
|
|
# if 'import tensorflow' comes before 'import torch'.
|
|
|
|
|
# Putting an 'import torch' here can bypass this incompatibility.
|
|
|
|
|
import torch
|
|
|
|
|
from scipy.io.wavfile import write
|
|
|
|
|
|
2022-09-27 22:09:30 +08:00
|
|
|
from modelscope.models import Model
|
2022-07-14 16:25:55 +08:00
|
|
|
from modelscope.outputs import OutputKeys
|
2022-06-20 17:23:11 +08:00
|
|
|
from modelscope.pipelines import pipeline
|
2022-08-06 12:22:17 +08:00
|
|
|
from modelscope.utils.constant import Tasks
|
2022-09-08 14:08:51 +08:00
|
|
|
from modelscope.utils.demo_utils import DemoCompatibilityCheck
|
2022-06-20 17:23:11 +08:00
|
|
|
from modelscope.utils.logger import get_logger
|
2022-06-23 16:55:48 +08:00
|
|
|
from modelscope.utils.test_utils import test_level
|
2022-06-20 17:23:11 +08:00
|
|
|
|
2022-07-06 13:20:04 +08:00
|
|
|
import tensorflow as tf # isort:skip
|
|
|
|
|
|
2022-06-20 17:23:11 +08:00
|
|
|
# Module-level logger; test_pipeline uses it to report the voice under test.
logger = get_logger()
|
|
|
|
|
|
|
|
|
|
|
2022-09-08 14:08:51 +08:00
|
|
|
class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase,
                                                DemoCompatibilityCheck):
    """Smoke tests for the 16k SamBERT-HiFiGAN text-to-speech models.

    ``setUp`` builds ``self.tts_test_cases``, a list of dicts with the keys
    ``'voice'``, ``'model_id'`` and ``'text'``. ``test_pipeline`` runs every
    case through a text-to-speech pipeline and writes the resulting PCM data
    to ``output_<voice>.wav`` in the current working directory.
    """

    def setUp(self) -> None:
        """Prepare the task name and the list of (voice, model, text) cases."""
        self.task = Tasks.text_to_speech
        zhcn_text = '今天北京天气怎么样'
        en_text = 'How is the weather in Beijing?'

        # Each entry: (voice names, locale tag embedded in the model id,
        # text to synthesize). Order matters: it fixes the order of
        # self.tts_test_cases.
        voice_groups = (
            (('zhitian_emo', 'zhizhe_emo', 'zhiyan_emo', 'zhibei_emo'),
             'zh-cn', zhcn_text),
            (('andy', 'annie'), 'en-us', en_text),
            (('luca', 'luna'), 'en-gb', en_text),
        )
        self.tts_test_cases = []
        for voices, locale, text in voice_groups:
            for voice in voices:
                self.tts_test_cases.append({
                    'voice': voice,
                    'model_id':
                    'damo/speech_sambert-hifigan_tts_%s_%s_16k' %
                    (voice, locale),
                    'text': text,
                })

        # Model ids that carry only a locale tag and no voice name; they are
        # appended after all voice-specific cases.
        locale_models = (
            ('zhcn', 'damo/speech_sambert-hifigan_tts_zh-cn_16k', zhcn_text),
            ('enus', 'damo/speech_sambert-hifigan_tts_en-us_16k', en_text),
            ('engb', 'damo/speech_sambert-hifigan_tts_en-gb_16k', en_text),
        )
        for label, model_id, text in locale_models:
            self.tts_test_cases.append({
                'voice': label,
                'model_id': model_id,
                'text': text,
            })

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_pipeline(self):
        """Synthesize one utterance per case and save it as a WAV file."""
        for case in self.tts_test_cases:
            # subTest isolates each voice so that one failing model does not
            # abort the loop and hide the results of the remaining cases.
            with self.subTest(voice=case['voice']):
                logger.info('test %s' % case['voice'])
                model = Model.from_pretrained(
                    model_name_or_path=case['model_id'],
                    revision='pytorch_am')
                sambert_hifigan_tts = pipeline(task=self.task, model=model)
                self.assertIsNotNone(sambert_hifigan_tts)
                output = sambert_hifigan_tts(input=case['text'])
                pcm = output[OutputKeys.OUTPUT_PCM]
                self.assertIsNotNone(pcm)
                # 16000 is passed as the WAV sample rate.
                write('output_%s.wav' % case['voice'], 16000, pcm)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
        """Run the demo compatibility check (skipped by default)."""
        # compatibility_check is presumably provided by
        # DemoCompatibilityCheck -- confirm against that mixin.
        self.compatibility_check()
|
|
|
|
|
|
2022-06-20 17:23:11 +08:00
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the test cases in this module when it is executed as a script.
    unittest.main()
|