Add split_sentence func to BaseTTS

This commit is contained in:
yuxumin
2023-12-16 00:28:18 +08:00
parent 47942d50bf
commit 2acbfe1fe2
3 changed files with 111 additions and 44 deletions

67
api.py
View File

@@ -1,16 +1,15 @@
import torch
import torch.nn as nn
import utils
from models import SynthesizerTrn
import torchaudio
import commons
import os
from mel_processing import spectrogram_torch, spectrogram_torch_conv
import librosa
import numpy as np
from text import text_to_sequence
import re
import soundfile
import utils
import commons
import os
import librosa
from text import text_to_sequence
from mel_processing import spectrogram_torch
from models import SynthesizerTrn
class OpenVoiceBaseClass(object):
def __init__(self,
@@ -53,23 +52,47 @@ class BaseSpeakerTTS(OpenVoiceBaseClass):
text_norm = torch.LongTensor(text_norm)
return text_norm
@staticmethod
def audio_numpy_concat(segment_data_list, sr, speed=1.):
audio_segments = []
for segment_data in segment_data_list:
audio_segments += segment_data.reshape(-1).tolist()
audio_segments += [0] * int((sr * 0.05)/speed)
audio_segments = np.array(audio_segments).astype(np.float32)
return audio_segments
@staticmethod
def split_sentences_into_pieces(text):
texts = utils.split_sentences_latin(text)
print(" > Text splitted to sentences.")
print('\n'.join(texts))
print(" > ===========================")
return texts
def tts(self, text, output_path, speaker, language='English', speed=1.0):
mark = self.language_marks.get(language.lower(), None)
assert mark is not None, f"language {language} is not supported"
text = re.sub(r'([a-z])([A-Z])', r'\1 \2', text)
text = mark + text + mark
stn_tst = self.get_text(text, self.hps, False)
device = self.device
speaker_id = self.hps.speakers[speaker]
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
sid = torch.LongTensor([speaker_id]).to(device)
audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
texts = self.split_sentences_into_pieces(text)
audio_list = []
for t in texts:
t = re.sub(r'([a-z])([A-Z])', r'\1 \2', t)
t = mark + t + mark
stn_tst = self.get_text(t, self.hps, False)
device = self.device
speaker_id = self.hps.speakers[speaker]
with torch.no_grad():
x_tst = stn_tst.unsqueeze(0).to(device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
sid = torch.LongTensor([speaker_id]).to(device)
audio = self.model.infer(x_tst, x_tst_lengths, sid=sid, noise_scale=0.667, noise_scale_w=0.6,
length_scale=1.0 / speed)[0][0, 0].data.cpu().float().numpy()
audio_list.append(audio)
audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
if output_path is None:
return audio.numpy()
return audio
else:
soundfile.write(output_path, audio, self.hps.data.sampling_rate)

View File

@@ -1,26 +1,9 @@
import os
from pydub import AudioSegment
import string
from faster_whisper import WhisperModel
import glob
import random
import torch
import numpy as np
from glob import glob
import librosa
from mel_processing import spectrogram_torch
def is_english(s):
valid_chars = string.ascii_letters + string.digits + string.whitespace + string.punctuation
return all(char in valid_chars for char in s)
def is_chinese(sentence):
valid_chars = string.whitespace + string.punctuation
for char in sentence:
if (char < '\u4e00' or char > '\u9fff') and char not in valid_chars:
return False
return True
from pydub import AudioSegment
from faster_whisper import WhisperModel
model_size = "medium"
# Run on GPU with FP16
@@ -102,4 +85,4 @@ def get_se(audio_path, vc_model, target_dir='processed'):
if len(audio_segs) == 0:
raise NotImplementedError('No audio segments found!')
return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name
return vc_model.extract_se(audio_segs, se_save_path=se_path), audio_name

View File

@@ -1,6 +1,6 @@
import re
import json
import numpy as np
import torch
def get_hparams_from_file(config_path):
@@ -72,4 +72,65 @@ def bits_to_string(bits_array):
# Convert ASCII values to characters
output_string = ''.join(chr(value) for value in ascii_values)
return output_string
return output_string
def split_sentences_latin(text, min_len=10):
"""Split Long sentences into list of short ones
Args:
str: Input sentences.
Returns:
List[str]: list of output sentences.
"""
# deal with dirty sentences
text = re.sub('[。!?;]', '.', text)
text = re.sub('[]', ',', text)
text = re.sub('[“”]', '"', text)
text = re.sub('[]', "'", text)
text = re.sub(r"[\<\>\(\)\[\]\"\«\»]+", "", text)
text = re.sub('[\n\t ]+', ' ', text)
text = re.sub('([,.!?;])', r'\1 $#!', text)
# split
sentences = [s.strip() for s in text.split('$#!')]
if len(sentences[-1]) == 0: del sentences[-1]
new_sentences = []
new_sent = []
count_len = 0
for ind, sent in enumerate(sentences):
# print(sent)
new_sent.append(sent)
count_len += len(sent.split(" "))
if count_len > min_len or ind == len(sentences) - 1:
count_len = 0
new_sentences.append(' '.join(new_sent))
new_sent = []
return merge_short_sentences_latin(new_sentences)
def merge_short_sentences_latin(sens):
"""Avoid short sentences by merging them with the following sentence.
Args:
List[str]: list of input sentences.
Returns:
List[str]: list of output sentences.
"""
sens_out = []
for s in sens:
# If the previous sentense is too short, merge them with
# the current sentence.
if len(sens_out) > 0 and len(sens_out[-1].split(" ")) <= 2:
sens_out[-1] = sens_out[-1] + " " + s
else:
sens_out.append(s)
try:
if len(sens_out[-1].split(" ")) <= 2:
sens_out[-2] = sens_out[-2] + " " + sens_out[-1]
sens_out.pop(-1)
except:
pass
return sens_out