[to #42322933] support byte input feature and refine fp implementations

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11338137
jiaqi.sjq
2023-01-09 20:56:52 +08:00
parent 9794fbf1c7
commit 453ff1dae3
12 changed files with 477 additions and 195 deletions

View File

@@ -18,7 +18,7 @@ from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import (
from modelscope.utils.logger import get_logger
DATASET_RANDOM_SEED = 1234
torch.multiprocessing.set_sharing_strategy('file_system')
logging = get_logger()
@@ -249,9 +249,27 @@ class VocDataset(KanttsDataset):
mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data),
axis=1)
# make sure the audio length and feature length are matched
wav_data = np.pad(wav_data, (0, self.n_fft), mode='reflect')
wav_data = wav_data[:len(mel_data) * self.hop_length]
# make sure mel_data length exceeds batch_max_frames by at least 1 frame
if mel_data.shape[0] <= self.batch_max_frames:
mel_data = np.concatenate(
(
mel_data,
np.zeros((
self.batch_max_frames - mel_data.shape[0] + 1,
mel_data.shape[1],
)),
),
axis=0,
)
wav_cache = np.zeros(
mel_data.shape[0] * self.hop_length, dtype=np.float32)
wav_cache[:len(wav_data)] = wav_data
wav_data = wav_cache
else:
# make sure the audio length and feature length are matched
wav_data = np.pad(wav_data, (0, self.n_fft), mode='reflect')
wav_data = wav_data[:len(mel_data) * self.hop_length]
assert len(mel_data) * self.hop_length == len(wav_data)
if self.allow_cache:
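A minimal standalone sketch of the short-clip branch above (the else branch that reflect-pads longer clips is omitted; pad_short_clip and the example sizes are illustrative, not part of the repo):

import numpy as np

def pad_short_clip(mel, wav, hop_length, batch_max_frames):
    # zero-pad mel past batch_max_frames, then zero-pad wav to len(mel) * hop_length
    if mel.shape[0] <= batch_max_frames:
        extra = batch_max_frames - mel.shape[0] + 1
        mel = np.concatenate([mel, np.zeros((extra, mel.shape[1]))], axis=0)
        wav_cache = np.zeros(mel.shape[0] * hop_length, dtype=np.float32)
        wav_cache[:len(wav)] = wav
        wav = wav_cache
    assert len(mel) * hop_length == len(wav)
    return mel, wav

mel, wav = pad_short_clip(np.zeros((50, 80)), np.zeros(50 * 256), 256, 100)
print(mel.shape, wav.shape)   # (101, 80) (25856,)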
@@ -561,11 +579,12 @@ class AmDataset(KanttsDataset):
os.path.join(frame_f0_dir, index + '.npy'))
or not os.path.exists(
os.path.join(frame_uv_dir, index + '.npy'))
or not os.path.exists(
os.path.join(duration_dir, index + '.npy'))
or not os.path.exists(
os.path.join(mel_dir, index + '.npy'))):
continue
if os.path.exists(duration_dir) and not os.path.exists(
os.path.join(duration_dir, index + '.npy')):
continue
f.write(line)
with open(valid_meta_file, 'w') as f:
@@ -577,62 +596,86 @@ class AmDataset(KanttsDataset):
os.path.join(frame_f0_dir, index + '.npy'))
or not os.path.exists(
os.path.join(frame_uv_dir, index + '.npy'))
or not os.path.exists(
os.path.join(duration_dir, index + '.npy'))
or not os.path.exists(
os.path.join(mel_dir, index + '.npy'))):
continue
if os.path.exists(duration_dir) and not os.path.exists(
os.path.join(duration_dir, index + '.npy')):
continue
f.write(line)
def collate_fn(self, batch):
data_dict = {}
max_input_length = max((len(x[0][0]) for x in batch))
max_dur_length = max((x[2].shape[0] for x in batch)) + 1
if self.with_duration:
max_dur_length = max((x[2].shape[0] for x in batch)) + 1
# pure linguistic info: sy|tone|syllable_flag|word_segment
lfeat_type = self.ling_unit._lfeat_type_list[0]
inputs_sy = self.padder._prepare_scalar_inputs(
[x[0][0] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# tone
lfeat_type = self.ling_unit._lfeat_type_list[1]
inputs_tone = self.padder._prepare_scalar_inputs(
[x[0][1] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
lfeat_type_index = 0
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
if self.ling_unit.using_byte():
# for byte-based model only
inputs_byte_index = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# syllable_flag
lfeat_type = self.ling_unit._lfeat_type_list[2]
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
[x[0][2] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
data_dict['input_lings'] = torch.stack([inputs_byte_index], dim=2)
else:
# pure linguistic info: sy|tone|syllable_flag|word_segment
# sy
inputs_sy = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# word_segment
lfeat_type = self.ling_unit._lfeat_type_list[3]
inputs_ws = self.padder._prepare_scalar_inputs(
[x[0][3] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# tone
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
inputs_tone = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# syllable_flag
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# word_segment
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
inputs_ws = self.padder._prepare_scalar_inputs(
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
data_dict['input_lings'] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws],
dim=2)
# emotion category
lfeat_type = self.ling_unit._lfeat_type_list[4]
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
data_dict['input_emotions'] = self.padder._prepare_scalar_inputs(
[x[0][4] for x in batch],
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# speaker category
lfeat_type = self.ling_unit._lfeat_type_list[5]
lfeat_type_index = lfeat_type_index + 1
lfeat_type = self.ling_unit._lfeat_type_list[lfeat_type_index]
data_dict['input_speakers'] = self.padder._prepare_scalar_inputs(
[x[0][5] for x in batch],
[x[0][lfeat_type_index] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
@@ -645,8 +688,6 @@ class AmDataset(KanttsDataset):
0,
).long()
data_dict['input_lings'] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
data_dict['valid_input_lengths'] = torch.as_tensor(
[len(x[0][0]) - 1 for x in batch], dtype=torch.long
) # a '~' is appended to the input symbol sequence later, which affects the duration computation, so subtract 1 from the length
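As a rough shape check of the two collate paths, the stacked input_lings tensor ends up with a last dimension of 1 in byte mode and 4 in phone mode; a sketch with made-up batch sizes:

import torch

batch, max_len = 8, 32
# byte path: only the byte indices are stacked
byte_ids = torch.zeros(batch, max_len, dtype=torch.long)
input_lings_byte = torch.stack([byte_ids], dim=2)             # (8, 32, 1)
# phone path: sy, tone, syllable_flag, word_segment are stacked
feats = [torch.zeros(batch, max_len, dtype=torch.long) for _ in range(4)]
input_lings_phone = torch.stack(feats, dim=2)                 # (8, 32, 4)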

View File

@@ -163,7 +163,7 @@ class Generator(torch.nn.Module):
else:
# transconv
up = self.transpose_upsamples[i](x)
x = rep + up
x = rep + up[:, :, :rep.shape[-1]]
xs = None
for j in range(self.num_kernels):
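The trimming added above guards against the transposed-conv branch running a few samples longer than the repeat/upsample branch; a small sketch of the mismatch with made-up layer sizes:

import torch
import torch.nn as nn

x = torch.randn(1, 64, 100)
rep = torch.repeat_interleave(x, 2, dim=-1)                    # (1, 64, 200)
up = nn.ConvTranspose1d(64, 64, kernel_size=5, stride=2)(x)    # (1, 64, 203)
out = rep + up[:, :, :rep.shape[-1]]                           # trim to (1, 64, 200)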

View File

@@ -253,15 +253,25 @@ class TextFftEncoder(nn.Module):
def __init__(self, config):
super(TextFftEncoder, self).__init__()
# linguistic unit lookup table
nb_ling_sy = config['sy']
nb_ling_tone = config['tone']
nb_ling_syllable_flag = config['syllable_flag']
nb_ling_ws = config['word_segment']
d_emb = config['embedding_dim']
self.using_byte = False
if config.get('using_byte', False):
self.using_byte = True
nb_ling_byte_index = config['byte_index']
self.byte_index_emb = nn.Embedding(nb_ling_byte_index, d_emb)
else:
# linguistic unit lookup table
nb_ling_sy = config['sy']
nb_ling_tone = config['tone']
nb_ling_syllable_flag = config['syllable_flag']
nb_ling_ws = config['word_segment']
self.sy_emb = nn.Embedding(nb_ling_sy, d_emb)
self.tone_emb = nn.Embedding(nb_ling_tone, d_emb)
self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb)
self.ws_emb = nn.Embedding(nb_ling_ws, d_emb)
max_len = config['max_len']
d_emb = config['embedding_dim']
nb_layers = config['encoder_num_layers']
nb_heads = config['encoder_num_heads']
d_model = config['encoder_num_units']
@@ -274,11 +284,6 @@ class TextFftEncoder(nn.Module):
self.d_model = d_model
self.sy_emb = nn.Embedding(nb_ling_sy, d_emb)
self.tone_emb = nn.Embedding(nb_ling_tone, d_emb)
self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb)
self.ws_emb = nn.Embedding(nb_ling_ws, d_emb)
position_enc = SinusoidalPositionEncoder(max_len, d_emb)
self.ling_enc = SelfAttentionEncoder(
@@ -298,20 +303,26 @@ class TextFftEncoder(nn.Module):
def forward(self, inputs_ling, masks=None, return_attns=False):
# Parse inputs_ling_seq
inputs_sy = inputs_ling[:, :, 0]
inputs_tone = inputs_ling[:, :, 1]
inputs_syllable_flag = inputs_ling[:, :, 2]
inputs_ws = inputs_ling[:, :, 3]
if self.using_byte:
inputs_byte_index = inputs_ling[:, :, 0]
byte_index_embedding = self.byte_index_emb(inputs_byte_index)
ling_embedding = byte_index_embedding
else:
inputs_sy = inputs_ling[:, :, 0]
inputs_tone = inputs_ling[:, :, 1]
inputs_syllable_flag = inputs_ling[:, :, 2]
inputs_ws = inputs_ling[:, :, 3]
# Lookup table
sy_embedding = self.sy_emb(inputs_sy)
tone_embedding = self.tone_emb(inputs_tone)
syllable_flag_embedding = self.syllable_flag_emb(inputs_syllable_flag)
ws_embedding = self.ws_emb(inputs_ws)
# Lookup table
sy_embedding = self.sy_emb(inputs_sy)
tone_embedding = self.tone_emb(inputs_tone)
syllable_flag_embedding = self.syllable_flag_emb(
inputs_syllable_flag)
ws_embedding = self.ws_emb(inputs_ws)
ling_embedding = (
sy_embedding + tone_embedding + syllable_flag_embedding
+ ws_embedding)
ling_embedding = (
sy_embedding + tone_embedding + syllable_flag_embedding
+ ws_embedding)
enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks,
return_attns)
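A hypothetical config fragment for the new toggle (key names follow the constructor above; the sizes are placeholders, and a byte_index of 258 assumes 256 byte values plus pad and eos; the real config carries additional keys):

# phone-based front end (original path)
config_phone = {
    'sy': 120, 'tone': 8, 'syllable_flag': 4, 'word_segment': 4,
    'embedding_dim': 512, 'max_len': 800,
    'encoder_num_layers': 4, 'encoder_num_heads': 8, 'encoder_num_units': 512,
}
# byte-based front end (new path): one lookup table replaces the four above
config_byte = dict(config_phone, using_byte=True, byte_index=258)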

View File

@@ -420,6 +420,57 @@ def compute_std(data_list, mean_vector, dims=80):
return std_vector
F0_MIN = 0.0
F0_MAX = 800.0
ENERGY_MIN = 0.0
ENERGY_MAX = 200.0
CLIP_FLOOR = 1e-3
def f0_norm_min_max(f0):
zero_idxs = np.where(f0 <= CLIP_FLOOR)[0]
res = (2 * f0 - F0_MIN - F0_MAX) / (F0_MAX - F0_MIN)
res[zero_idxs] = 0.0
return res
def f0_denorm_min_max(f0):
zero_idxs = np.where(f0 == 0.0)[0]
res = (f0 * (F0_MAX - F0_MIN) + F0_MIN + F0_MAX) / 2
res[zero_idxs] = 0.0
return res
def energy_norm_min_max(energy):
zero_idxs = np.where(energy == 0.0)[0]
res = (2 * energy - ENERGY_MIN - ENERGY_MAX) / (ENERGY_MAX - ENERGY_MIN)
res[zero_idxs] = 0.0
return res
def energy_denorm_min_max(energy):
zero_idxs = np.where(energy == 0.0)[0]
res = (energy * (ENERGY_MAX - ENERGY_MIN) + ENERGY_MIN + ENERGY_MAX) / 2
res[zero_idxs] = 0.0
return res
def norm_log(x):
zero_idxs = np.where(x <= CLIP_FLOOR)[0]
x[zero_idxs] = 1.0
res = np.log(x)
return res
def denorm_log(x):
zero_idxs = np.where(x == 0.0)[0]
res = np.exp(x)
res[zero_idxs] = 0.0
return res
def f0_norm_mean_std(x, mean, std):
zero_idxs = np.where(x == 0.0)[0]
x = (x - mean) / std
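A quick round-trip check of the min-max helpers above: unvoiced frames (at or below CLIP_FLOOR) stay pinned to 0 in both directions; the functions are copied from the diff, the sample values are made up:

import numpy as np

F0_MIN, F0_MAX, CLIP_FLOOR = 0.0, 800.0, 1e-3

def f0_norm_min_max(f0):
    zero_idxs = np.where(f0 <= CLIP_FLOOR)[0]
    res = (2 * f0 - F0_MIN - F0_MAX) / (F0_MAX - F0_MIN)
    res[zero_idxs] = 0.0
    return res

def f0_denorm_min_max(f0):
    zero_idxs = np.where(f0 == 0.0)[0]
    res = (f0 * (F0_MAX - F0_MIN) + F0_MIN + F0_MAX) / 2
    res[zero_idxs] = 0.0
    return res

f0 = np.array([0.0, 120.0, 800.0])
print(f0_norm_min_max(f0))                     # [ 0.  -0.7  1. ]
print(f0_denorm_min_max(f0_norm_min_max(f0)))  # [  0. 120. 800.]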

View File

@@ -110,6 +110,8 @@ def process_data(
languages[targetLang]['s2p_map_path'])
logging.info(f'phoneset_path={phoneset_path}')
# dir of plain text/sentences for training the byte-based model
plain_text_dir = os.path.join(voice_input_dir, 'text')
if speaker_name is None:
speaker_name = os.path.basename(voice_input_dir)
@@ -130,28 +132,35 @@ def process_data(
raw_metafile = None
# Script processor
if not skip_script:
tsc = TextScriptConvertor(
phoneset_path,
posset_path,
targetLang,
foreignLang,
f2t_map_path,
s2p_map_path,
emo_tag_path,
speaker_name,
)
tsc.process(
os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
os.path.join(voice_output_dir, 'Script.xml'),
os.path.join(voice_output_dir, 'raw_metafile.txt'),
)
if os.path.exists(plain_text_dir):
TextScriptConvertor.turn_text_into_bytes(
os.path.join(plain_text_dir, 'text.txt'),
os.path.join(voice_output_dir, 'raw_metafile.txt'),
speaker_name,
)
fp_enable = False
else:
tsc = TextScriptConvertor(
phoneset_path,
posset_path,
targetLang,
foreignLang,
f2t_map_path,
s2p_map_path,
emo_tag_path,
speaker_name,
)
tsc.process(
os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
os.path.join(voice_output_dir, 'Script.xml'),
os.path.join(voice_output_dir, 'raw_metafile.txt'),
)
prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
# FP processor
with codecs.open(prosody, 'r', 'utf-8') as f:
lines = f.readlines()
fp_enable = is_fp_line(lines[1])
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
# FP processor
with codecs.open(prosody, 'r', 'utf-8') as f:
lines = f.readlines()
fp_enable = is_fp_line(lines[1])
if fp_enable:
FP = FpProcessor()
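With the plain-text path above, the voice package only needs <voice_input_dir>/text/text.txt, one tab-separated id and sentence per line (the format turn_text_into_bytes splits on '\t'); a sketch with made-up ids:

import os

# hypothetical layout: <voice_input_dir>/text/text.txt with "<sentence_id>\t<sentence>" per line
os.makedirs('voice_input_dir/text', exist_ok=True)
with open('voice_input_dir/text/text.txt', 'w', encoding='utf-8') as f:
    f.write('utt_0001\t今天北京天气怎么样\n')
    f.write('utt_0002\tHow is the weather in Beijing?\n')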

View File

@@ -22,8 +22,20 @@ def is_fp_line(line):
class FpProcessor:
def __init__(self):
# TODO: Add more audio processing methods.
self.res = []
def is_fp_line(line):
fp_category_list = ['FP', 'I', 'N', 'Q']
elements = line.strip().split(' ')
res = True
for ele in elements:
if ele not in fp_category_list:
res = False
break
return res
# TODO: adjust idx judgment rule
def addfp(self, voice_output_dir, prosody, raw_metafile_lines):
fp_category_list = ['FP', 'I', 'N']
@@ -35,15 +47,28 @@ class FpProcessor:
idx = ''
fp = ''
fp_label_dict = {}
for i in range(len(prosody_lines)):
if i % 5 == 0:
i = 0
while i < len(prosody_lines):
if len(prosody_lines[i].strip().split('\t')) == 2:
idx = prosody_lines[i].strip().split('\t')[0]
elif i % 5 == 1: # according to prosody.txt
fp = prosody_lines[i].strip().split('\t')[0].split(' ')
for label in fp:
if label not in fp_category_list:
logging.warning('fp label not in fp_category_list')
break
i += 1
else:
fp_enable = is_fp_line(prosody_lines[i])
if fp_enable:
fp = prosody_lines[i].strip().split('\t')[0].split(' ')
for label in fp:
if label not in fp_category_list:
logging.warning('fp label not in fp_category_list')
break
i += 4
else:
fp = [
'N' for _ in range(
len(prosody_lines[i].strip().split('\t')
[0].replace('/ ', '').replace('. ', '').split(
' ')))
]
i += 1
fp_label_dict[idx] = fp
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
@@ -76,9 +101,12 @@ class FpProcessor:
error_flag = True
out_str = out_str + this_symbol_sequence + ' '
if idx != len(fp_label_dict[uttname]):
logging.warning('{} length mismatch, length: {} '.format(
idx, len(fp_label_dict[uttname])))
# if idx != len(fp_label_dict[uttname]):
# logging.warning(
# "{} length mismatch, length: {} ".format(
# idx, len(fp_label_dict[uttname])
# )
# )
if not error_flag:
f_out.write(out_str.strip() + '\n')
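For reference, the module-level is_fp_line above accepts only lines made entirely of FP/I/N/Q tokens; an equivalent one-liner with illustrative inputs:

fp_category_list = ['FP', 'I', 'N', 'Q']

def is_fp_line(line):
    return all(ele in fp_category_list for ele in line.strip().split(' '))

print(is_fp_line('N N FP N'))   # True  -> treated as a filled-pause label line
print(is_fp_line('ni3 hao3'))   # False -> an ordinary prosody/pinyin line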

View File

@@ -99,14 +99,18 @@ def format_prosody(src_prosody):
formatted_lines = []
with codecs.open(src_prosody, 'r', 'utf-8') as f:
lines = f.readlines()
fp_enable = is_fp_line(lines[1])
for i in range(0, len(lines)):
line = do_character_normalization(lines[i])
if fp_enable:
if i % 5 == 1 or i % 5 == 2 or i % 5 == 3:
continue
idx = 0
while idx < len(lines):
line = do_character_normalization(lines[idx])
if len(line.strip().split('\t')) == 2:
line = do_prosody_text_normalization(line)
else:
fp_enable = is_fp_line(line)
if fp_enable:
idx += 3
continue
formatted_lines.append(line)
idx += 1
return formatted_lines

View File

@@ -4,6 +4,7 @@ import argparse
import os
import re
from bitstring import BitArray
from tqdm import tqdm
from modelscope.utils.logger import get_logger
@@ -461,3 +462,39 @@ class TextScriptConvertor:
logging.info('TextScriptConvertor.process:\nSave metafile to: %s',
outputMetafile)
@staticmethod
def turn_text_into_bytes(plain_text_path, output_meta_file_path, speaker):
meta_lines = []
with open(plain_text_path, 'r') as in_file:
for text_line in in_file:
[sentence_id, sentence] = text_line.strip().split('\t')
sequence = []
for character in sentence:
hex_string = character.encode('utf-8').hex()
i = 0
while i < len(hex_string):
byte_hex = hex_string[i:i + 2]
bit_array = BitArray(hex=byte_hex)
integer = bit_array.uint
if integer > 255:
logging.error(
'TextScriptConvertor.turn_text_into_bytes: invalid byte conversion in sentence {} \
character {}: (uint) {} - (hex) {}'.
format(
sentence_id,
character,
integer,
character.encode('utf-8').hex(),
))
continue
sequence.append('{{{}$emotion_neutral${}}}'.format(
integer, speaker))
i += 2
if sequence[-1][1:].split('$')[0] not in ['33', '46', '63']:
sequence.append(
'{{46$emotion_neutral${}}}'.format(speaker))
meta_lines.append('{}\t{}\n'.format(sentence_id,
' '.join(sequence)))
with open(output_meta_file_path, 'w') as out_file:
out_file.writelines(meta_lines)
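A sanity-check sketch of the byte conversion above: each character expands to its UTF-8 bytes, and a '.' byte (46) is appended when the sentence does not already end in '!' (33), '.' (46) or '?' (63). The helper to_byte_symbols is illustrative (it iterates the bytes directly instead of going through hex/BitArray) and 'F7' is an assumed speaker name:

def to_byte_symbols(sentence, speaker='F7'):
    symbols = ['{{{}$emotion_neutral${}}}'.format(b, speaker)
               for b in sentence.encode('utf-8')]
    if symbols[-1][1:].split('$')[0] not in ['33', '46', '63']:
        symbols.append('{{46$emotion_neutral${}}}'.format(speaker))
    return ' '.join(symbols)

print(to_byte_symbols('你好'))
# 6 byte symbols ({228$emotion_neutral$F7} ... {189$emotion_neutral$F7})
# plus a trailing {46$emotion_neutral$F7} for the '.' byte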

View File

@@ -26,9 +26,10 @@ def _clean_text(text, cleaner_names):
def get_fpdict(config):
# emotion_neutral(F7) can be other emotion (speaker) types from the corresponding list in the config file.
en_sy = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {en_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
a_sy = '{ga$tone5$s_begin$word_begin$emotion_neutral$F7} {a_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
e_sy = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {e_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
default_sp = config['linguistic_unit']['speaker_list'].split(',')[0]
en_sy = f'{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{en_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501
a_sy = f'{{ga$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{a_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501
e_sy = f'{{ge$tone5$s_begin$word_begin$emotion_neutral${default_sp}}} {{e_c$tone5$s_end$word_end$emotion_neutral${default_sp}}} {{#3$tone_none$s_none$word_none$emotion_neutral${default_sp}}}' # NOQA: E501
ling_unit = KanTtsLinguisticUnit(config)
en_lings = ling_unit.encode_symbol_sequence(en_sy)
@@ -39,7 +40,7 @@ def get_fpdict(config):
a_ling = np.stack(a_lings, axis=1)[:3, :4]
e_ling = np.stack(e_lings, axis=1)[:3, :4]
fp_dict = {1: a_ling, 2: en_ling, 3: e_ling}
fp_dict = {1: en_ling, 2: a_ling, 3: e_ling}
return fp_dict
@@ -92,12 +93,18 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
self.build()
def using_byte(self):
return 'byte_index' in self._lfeat_type_list
def get_unit_size(self):
ling_unit_size = {}
ling_unit_size['sy'] = len(self.sy)
ling_unit_size['tone'] = len(self.tone)
ling_unit_size['syllable_flag'] = len(self.syllable_flag)
ling_unit_size['word_segment'] = len(self.word_segment)
if self.using_byte():
ling_unit_size['byte_index'] = len(self.byte_index)
else:
ling_unit_size['sy'] = len(self.sy)
ling_unit_size['tone'] = len(self.tone)
ling_unit_size['syllable_flag'] = len(self.syllable_flag)
ling_unit_size['word_segment'] = len(self.word_segment)
if 'emo_category' in self._lfeat_type_list:
ling_unit_size['emotion'] = len(self.emo_category)
@@ -107,77 +114,96 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
return ling_unit_size
def build(self):
self._sub_unit_dim = {}
self._sub_unit_pad = {}
# sy sub-unit
_characters = ''
if self.using_byte():
# Export all byte indices:
self.byte_index = ['@' + str(idx) for idx in range(256)] + [
self._pad,
self._eos,
]
if self.has_mask:
self.byte_index.append(self._mask)
self._byte_index_to_id = {
s: i
for i, s in enumerate(self.byte_index)
}
self._id_to_byte_index = {
i: s
for i, s in enumerate(self.byte_index)
}
self._sub_unit_dim['byte_index'] = len(self.byte_index)
self._sub_unit_pad['byte_index'] = self._byte_index_to_id['_']
else:
# sy sub-unit
_characters = ''
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ['@' + s for s in cmudict.valid_symbols]
_arpabet = ['@' + s for s in self.lang_phones]
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ['@' + s for s in cmudict.valid_symbols]
_arpabet = ['@' + s for s in self.lang_phones]
# Export all symbols:
self.sy = list(_characters) + _arpabet + [self._pad, self._eos]
if self.has_mask:
self.sy.append(self._mask)
self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
self._sub_unit_dim['sy'] = len(self.sy)
self._sub_unit_pad['sy'] = self._sy_to_id['_']
# Export all symbols:
self.sy = list(_characters) + _arpabet + [self._pad, self._eos]
if self.has_mask:
self.sy.append(self._mask)
self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
self._sub_unit_dim['sy'] = len(self.sy)
self._sub_unit_pad['sy'] = self._sy_to_id['_']
# tone sub-unit
_characters = ''
# tone sub-unit
_characters = ''
# Export all tones:
self.tone = (
list(_characters) + self.lang_tones + [self._pad, self._eos])
if self.has_mask:
self.tone.append(self._mask)
self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
self._sub_unit_dim['tone'] = len(self.tone)
self._sub_unit_pad['tone'] = self._tone_to_id['_']
# Export all tones:
self.tone = (
list(_characters) + self.lang_tones + [self._pad, self._eos])
if self.has_mask:
self.tone.append(self._mask)
self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
self._sub_unit_dim['tone'] = len(self.tone)
self._sub_unit_pad['tone'] = self._tone_to_id['_']
# syllable flag sub-unit
_characters = ''
# syllable flag sub-unit
_characters = ''
# Export all syllable_flags:
self.syllable_flag = (
list(_characters) + self.lang_syllable_flags
+ [self._pad, self._eos])
if self.has_mask:
self.syllable_flag.append(self._mask)
self._syllable_flag_to_id = {
s: i
for i, s in enumerate(self.syllable_flag)
}
self._id_to_syllable_flag = {
i: s
for i, s in enumerate(self.syllable_flag)
}
self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag)
self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id['_']
# Export all syllable_flags:
self.syllable_flag = (
list(_characters) + self.lang_syllable_flags
+ [self._pad, self._eos])
if self.has_mask:
self.syllable_flag.append(self._mask)
self._syllable_flag_to_id = {
s: i
for i, s in enumerate(self.syllable_flag)
}
self._id_to_syllable_flag = {
i: s
for i, s in enumerate(self.syllable_flag)
}
self._sub_unit_dim['syllable_flag'] = len(self.syllable_flag)
self._sub_unit_pad['syllable_flag'] = self._syllable_flag_to_id[
'_']
# word segment sub-unit
_characters = ''
# word segment sub-unit
_characters = ''
# Export all word_segments:
self.word_segment = (
list(_characters) + self.lang_word_segments
+ [self._pad, self._eos])
if self.has_mask:
self.word_segment.append(self._mask)
self._word_segment_to_id = {
s: i
for i, s in enumerate(self.word_segment)
}
self._id_to_word_segment = {
i: s
for i, s in enumerate(self.word_segment)
}
self._sub_unit_dim['word_segment'] = len(self.word_segment)
self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_']
# Export all word_segments:
self.word_segment = (
list(_characters) + self.lang_word_segments
+ [self._pad, self._eos])
if self.has_mask:
self.word_segment.append(self._mask)
self._word_segment_to_id = {
s: i
for i, s in enumerate(self.word_segment)
}
self._id_to_word_segment = {
i: s
for i, s in enumerate(self.word_segment)
}
self._sub_unit_dim['word_segment'] = len(self.word_segment)
self._sub_unit_pad['word_segment'] = self._word_segment_to_id['_']
if 'emo_category' in self._lfeat_type_list:
# emotion category sub-unit
@@ -247,6 +273,8 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
sequence_item = sequence[i].tolist()
if lfeat_type == 'sy':
s = self.decode_sy(sequence_item)
elif lfeat_type == 'byte_index':
s = self.decode_byte_index(sequence_item)
elif lfeat_type == 'tone':
s = self.decode_tone(sequence_item)
elif lfeat_type == 'syllable_flag':
@@ -261,7 +289,7 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
raise Exception('Unknown lfeat type: %s' % lfeat_type)
result.append('%s:%s' % (lfeat_type, s))
return result
return
def encode_sub_unit(self, this_lfeat_symbol, lfeat_type):
sequence = []
@@ -276,6 +304,8 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
index = index + 1
sequence = self.encode_text(this_lfeat_symbol_format,
self._cleaner_names)
elif lfeat_type == 'byte_index':
sequence = self.encode_byte_index(this_lfeat_symbol)
elif lfeat_type == 'tone':
sequence = self.encode_tone(this_lfeat_symbol)
elif lfeat_type == 'syllable_flag':
@@ -288,7 +318,6 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
sequence = self.encode_speaker_category(this_lfeat_symbol)
else:
raise Exception('Unknown lfeat type: %s' % lfeat_type)
return sequence
def encode_text(self, text, cleaner_names):
@@ -323,6 +352,20 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
def encode_arpanet(self, text):
return self.encode_sy(['@' + s for s in text.split()])
def encode_byte_index(self, byte_index):
byte_indices = ['@' + s for s in byte_index.strip().split(' ')]
sequence = []
for this_byte_index in byte_indices:
sequence.append(self._byte_index_to_id[this_byte_index])
sequence.append(self._byte_index_to_id['~'])
return sequence
def decode_byte_index(self, id):
s = self._id_to_byte_index[id]
if len(s) > 1 and s[0] == '@':
s = s[1:]
return s
def encode_tone(self, tone):
tones = tone.strip().split(' ')
sequence = []
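A small usage sketch of the new byte-index codec, assuming the vocabulary built above ('@0'…'@255' followed by pad '_' and eos '~'); the free functions mirror the class methods for illustration:

byte_index = ['@' + str(i) for i in range(256)] + ['_', '~']
_byte_index_to_id = {s: i for i, s in enumerate(byte_index)}
_id_to_byte_index = {i: s for i, s in enumerate(byte_index)}

def encode_byte_index(text_bytes):
    ids = [_byte_index_to_id['@' + tok] for tok in text_bytes.strip().split(' ')]
    ids.append(_byte_index_to_id['~'])   # eos, as in the method above
    return ids

def decode_byte_index(i):
    s = _id_to_byte_index[i]
    return s[1:] if len(s) > 1 and s[0] == '@' else s

ids = encode_byte_index('228 189 160')        # -> [228, 189, 160, 257]
print([decode_byte_index(i) for i in ids])    # ['228', '189', '160', '~']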

View File

@@ -123,21 +123,47 @@ class Voice:
with torch.no_grad():
inputs_feat_lst = self.__ling_unit.encode_symbol_sequence(
symbol_seq)
inputs_sy = torch.from_numpy(inputs_feat_lst[0]).long().to(
self.__device)
inputs_tone = torch.from_numpy(inputs_feat_lst[1]).long().to(
self.__device)
inputs_syllable = torch.from_numpy(
inputs_feat_lst[2]).long().to(self.__device)
inputs_ws = torch.from_numpy(inputs_feat_lst[3]).long().to(
self.__device)
inputs_ling = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws],
dim=-1).unsqueeze(0)
inputs_emo = torch.from_numpy(inputs_feat_lst[4]).long().to(
self.__device).unsqueeze(0)
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to(
self.__device).unsqueeze(0)
inputs_feat_index = 0
if self.__ling_unit.using_byte():
inputs_byte_index = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device))
inputs_ling = torch.stack([inputs_byte_index],
dim=-1).unsqueeze(0)
else:
inputs_sy = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device))
inputs_feat_index = inputs_feat_index + 1
inputs_tone = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device))
inputs_feat_index = inputs_feat_index + 1
inputs_syllable = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device))
inputs_feat_index = inputs_feat_index + 1
inputs_ws = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device))
inputs_ling = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable, inputs_ws],
dim=-1).unsqueeze(0)
inputs_feat_index = inputs_feat_index + 1
inputs_emo = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device).unsqueeze(0))
inputs_feat_index = inputs_feat_index + 1
inputs_spk = (
torch.from_numpy(
inputs_feat_lst[inputs_feat_index]).long().to(
self.__device).unsqueeze(0))
inputs_len = (torch.zeros(1).to(self.__device).long()
+ inputs_emo.size(1) - 1) # minus 1 for "~"
res = self.__am(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
@@ -148,9 +174,19 @@ class Voice:
postnet_outputs = postnet_outputs[0, :valid_length, :].cpu()
return postnet_outputs
def __binarize(mel, threshold=0.6):
# vuv binarize
res_mel = mel.clone()
index = torch.where(mel[:, -1] < threshold)[0]
res_mel[:, -1] = 1.0
res_mel[:, -1][index] = 0.0
return res_mel
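The helper above binarizes the last mel channel (the voiced/unvoiced flag) before the NSF vocoder; a tiny equivalent sketch with the same 0.6 threshold:

import torch

def binarize(mel, threshold=0.6):
    res = mel.clone()
    res[:, -1] = (mel[:, -1] >= threshold).float()  # same effect as the masked 1.0/0.0 assignment
    return res

mel = torch.tensor([[0.1, 0.2, 0.95],
                    [0.3, 0.4, 0.30]])
print(binarize(mel)[:, -1])   # tensor([1., 0.])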
def __vocoder_forward(self, melspec):
with torch.no_grad():
x = melspec.to(self.__device)
if self.__voc_model.nsf_enable:
x = self.__binarize(x)
x = x.transpose(1, 0).unsqueeze(0)
y = self.__voc_model(x)
if hasattr(self.__voc_model, 'pqmf'):

View File

@@ -1,3 +1,4 @@
bitstring
easyasr>=0.0.2
espnet==202204
funasr>=0.1.4
@@ -38,6 +39,6 @@ speechbrain>=0.5
torchaudio
tqdm
traitlets>=5.3.0
ttsfrd>=0.0.3
ttsfrd>=0.1.1
unidecode
wcwidth>=0.2.5

View File

@@ -27,12 +27,33 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase,
self.task = Tasks.text_to_speech
self.zhcn_text = '今天北京天气怎么样'
self.en_text = 'How is the weather in Beijing?'
self.kokr_text = '오늘날씨가어때요'
self.ru_text = 'Какая сегодня погода?'
self.test_model_name = [
'pretrain_16k', 'pretrain_24k', 'zhitian_emo', 'zhizhe_emo',
'zhiyan_emo', 'zhibei_emo', 'zhcn_16k', 'luca', 'luna', 'andy',
'annie', 'engb_16k', 'enus_16k'
'chuangirl', 'jiajia', 'xiaoda', 'kyong', 'masha', 'pretrain_16k',
'pretrain_24k', 'zhitian_emo', 'zhizhe_emo', 'zhiyan_emo',
'zhibei_emo', 'zhcn_16k', 'luca', 'luna', 'andy', 'annie',
'engb_16k', 'enus_16k'
]
self.test_models = [{
'model':
'speech_tts/speech_sambert-hifigan_tts_chuangirl_Sichuan_16k',
'text': self.zhcn_text
}, {
'model':
'speech_tts/speech_sambert-hifigan_tts_jiajia_Cantonese_16k',
'text': self.zhcn_text
}, {
'model':
'speech_tts/speech_sambert-hifigan_tts_xiaoda_WuuShanghai_16k',
'text': self.zhcn_text
}, {
'model': 'speech_tts/speech_sambert-hifigan_tts_kyong_Korean_16k',
'text': self.kokr_text
}, {
'model': 'speech_tts/speech_sambert-hifigan_tts_masha_Russian_16k',
'text': self.ru_text
}, {
'model':
'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k',
'text': self.zhcn_text