[to #42322933] Update tts task inputs

Refactor tts task inputs
        Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/9412937
This commit is contained in:
jiaqi.sjq
2022-07-18 17:50:59 +08:00
parent 6b3325088e
commit a17f29ce54
3 changed files with 10 additions and 15 deletions

View File

@@ -25,22 +25,19 @@ class TextToSpeechSambertHifiganPipeline(Pipeline):
"""
super().__init__(model=model, **kwargs)
def forward(self, inputs: Dict[str, str]) -> Dict[str, np.ndarray]:
def forward(self, input: str, **forward_params) -> Dict[str, np.ndarray]:
"""synthesis text from inputs with pipeline
Args:
inputs (Dict[str, str]): a dictionary that key is the name of
certain testcase and value is the text to synthesis.
input (str): text to synthesis
forward_params: valid param is 'voice' used to setting speaker vocie
Returns:
Dict[str, np.ndarray]: a dictionary with key and value. The key
is the same as inputs' key which is the label of the testcase
and the value is the pcm audio data.
Dict[str, np.ndarray]: {OutputKeys.OUTPUT_PCM : np.ndarray(16bit pcm data)}
"""
output_wav = {}
for label, text in inputs.items():
output_wav[label] = self.model.forward(text, inputs.get('voice'))
output_wav = self.model.forward(input, forward_params.get('voice'))
return {OutputKeys.OUTPUT_PCM: output_wav}
def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
def postprocess(self, inputs: Dict[str, Any],
**postprocess_params) -> Dict[str, Any]:
return inputs
def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:

View File

@@ -10,7 +10,7 @@ nara_wpe
numpy<=1.18
protobuf>3,<=3.20
ptflops
pytorch_wavelets==1.3.0
pytorch_wavelets
PyWavelets>=1.0.0
scikit-learn
SoundFile>0.10

View File

@@ -24,7 +24,6 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
@unittest.skipUnless(test_level() >= 1, 'skip test in current test level')
def test_pipeline(self):
single_test_case_label = 'test_case_label_0'
text = '今天北京天气怎么样?'
model_id = 'damo/speech_sambert-hifigan_tts_zhcn_16k'
voice = 'zhitian_emo'
@@ -32,10 +31,9 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase):
sambert_hifigan_tts = pipeline(
task=Tasks.text_to_speech, model=model_id)
self.assertTrue(sambert_hifigan_tts is not None)
inputs = {single_test_case_label: text, 'voice': voice}
output = sambert_hifigan_tts(inputs)
output = sambert_hifigan_tts(input=text, voice=voice)
self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
pcm = output[OutputKeys.OUTPUT_PCM][single_test_case_label]
pcm = output[OutputKeys.OUTPUT_PCM]
write('output.wav', 16000, pcm)