[to #42322933] Support KAN-TTS inference and finetuning

Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11111331#tab=detail
jiaqi.sjq
2022-12-20 10:45:34 +08:00
parent b992bf278c
commit 8896087034
79 changed files with 10595 additions and 1793 deletions

View File

@@ -363,6 +363,7 @@ class Trainers(object):
# audio trainers
speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
speech_kantts_trainer = 'speech-kantts-trainer'
class Preprocessors(object):
@@ -429,6 +430,7 @@ class Preprocessors(object):
text_to_tacotron_symbols = 'text-to-tacotron-symbols'
wav_to_lists = 'wav-to-lists'
wav_to_scp = 'wav-to-scp'
kantts_data_preprocessor = 'kantts-data-preprocessor'
# multi-modal preprocessor
ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'
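
A minimal usage sketch (not part of this commit) of how the new registry entry could be wired through ModelScope's trainer factory; the model id, work_dir and dataset arguments below are placeholders, not values taken from this diff.

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer

# hypothetical arguments; the real ones come from the model card and data preparation
kwargs = dict(
    model='damo/speech_sambert-hifigan_tts_zh-cn_16k',
    work_dir='./kantts_finetune_workdir',
    train_dataset=None,
)
trainer = build_trainer(Trainers.speech_kantts_trainer, default_args=kwargs)
# trainer.train()  # start KAN-TTS finetuning once a dataset is supplied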

View File

@@ -7,12 +7,8 @@ if TYPE_CHECKING:
from .sambert_hifi import SambertHifigan
else:
_import_structure = {
'sambert_hifi': ['SambertHifigan'],
}
_import_structure = {'sambert_hifi': ['SambertHifigan']}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],

View File

@@ -0,0 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .datasets.dataset import get_am_datasets, get_voc_datasets
from .models import model_builder
from .models.hifigan.hifigan import Generator
from .train.loss import criterion_builder
from .train.trainer import GAN_Trainer, Sambert_Trainer
from .utils.ling_unit.ling_unit import KanTtsLinguisticUnit

View File

@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
from scipy.io import wavfile
DATA_TYPE_DICT = {
'txt': {
'load_func': np.loadtxt,
'desc': 'plain txt file or readable by np.loadtxt',
},
'wav': {
'load_func': lambda x: wavfile.read(x)[1],
'desc': 'wav file or readable by soundfile.read',
},
'npy': {
'load_func': np.load,
'desc': 'any .npy format file',
},
# PCM data type can be loaded by binary format
'bin_f32': {
'load_func': lambda x: np.fromfile(x, dtype=np.float32),
'desc': 'binary file with float32 format',
},
'bin_f64': {
'load_func': lambda x: np.fromfile(x, dtype=np.float64),
'desc': 'binary file with float64 format',
},
'bin_i32': {
'load_func': lambda x: np.fromfile(x, dtype=np.int32),
'desc': 'binary file with int32 format',
},
'bin_i16': {
'load_func': lambda x: np.fromfile(x, dtype=np.int16),
'desc': 'binary file with int16 format',
},
}
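
A minimal sketch (an assumption, not from this diff) of how the table above is meant to be consumed: each entry maps a data-type tag to a loader that turns a file path into a numpy array. The helper name and the example path are hypothetical.

def load_feature(path, data_type):
    # look up the loader registered for this data-type tag and apply it to the path
    return DATA_TYPE_DICT[data_type]['load_func'](path)

# e.g. a raw float32 PCM dump would come back as a 1-D float32 array:
# feat = load_feature('dump/000001.f32', 'bin_f32')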

View File

@@ -0,0 +1,989 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import functools
import glob
import math
import os
import random
from multiprocessing import Manager
import librosa
import numpy as np
import torch
from scipy.stats import betabinom
from tqdm import tqdm
from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import (
KanTtsLinguisticUnit, emotion_types)
from modelscope.utils.logger import get_logger
DATASET_RANDOM_SEED = 1234
logging = get_logger()
@functools.lru_cache(maxsize=256)
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
P = phoneme_count
M = mel_count
x = np.arange(0, P)
mel_text_probs = []
for i in range(1, M + 1):
a, b = scaling * i, scaling * (M + 1 - i)
rv = betabinom(P, a, b)
mel_i_prob = rv.pmf(x)
mel_text_probs.append(mel_i_prob)
return torch.tensor(np.array(mel_text_probs))
class Padder(object):
def __init__(self):
super(Padder, self).__init__()
pass
def _pad1D(self, x, length, pad):
return np.pad(
x, (0, length - x.shape[0]), mode='constant', constant_values=pad)
def _pad2D(self, x, length, pad):
return np.pad(
x, [(0, length - x.shape[0]), (0, 0)],
mode='constant',
constant_values=pad)
def _pad_durations(self, duration, max_in_len, max_out_len):
framenum = np.sum(duration)
symbolnum = duration.shape[0]
if framenum < max_out_len:
padframenum = max_out_len - framenum
duration = np.insert(
duration, symbolnum, values=padframenum, axis=0)
duration = np.insert(
duration,
symbolnum + 1,
values=[0] * (max_in_len - symbolnum - 1),
axis=0,
)
else:
if symbolnum < max_in_len:
duration = np.insert(
duration,
symbolnum,
values=[0] * (max_in_len - symbolnum),
axis=0)
return duration
def _round_up(self, x, multiple):
remainder = x % multiple
return x if remainder == 0 else x + multiple - remainder
def _prepare_scalar_inputs(self, inputs, max_len, pad):
return torch.from_numpy(
np.stack([self._pad1D(x, max_len, pad) for x in inputs]))
def _prepare_targets(self, targets, max_len, pad):
return torch.from_numpy(
np.stack([self._pad2D(t, max_len, pad) for t in targets])).float()
def _prepare_durations(self, durations, max_in_len, max_out_len):
return torch.from_numpy(
np.stack([
self._pad_durations(t, max_in_len, max_out_len)
for t in durations
])).long()
class KanttsDataset(torch.utils.data.Dataset):
def __init__(
self,
metafile,
root_dir,
):
self.meta = []
if not isinstance(metafile, list):
metafile = [metafile]
if not isinstance(root_dir, list):
root_dir = [root_dir]
for meta_file, data_dir in zip(metafile, root_dir):
if not os.path.exists(meta_file):
logging.error('meta file not found: {}'.format(meta_file))
raise ValueError(
'[Dataset] meta file: {} not found'.format(meta_file))
if not os.path.exists(data_dir):
logging.error('data directory not found: {}'.format(data_dir))
raise ValueError(
'[Dataset] data dir: {} not found'.format(data_dir))
self.meta.extend(self.load_meta(meta_file, data_dir))
def load_meta(self, meta_file, data_dir):
pass
class VocDataset(KanttsDataset):
"""
Provides (mel, audio) data pairs.
"""
def __init__(
self,
metafile,
root_dir,
config,
):
self.config = config
self.sampling_rate = config['audio_config']['sampling_rate']
self.n_fft = config['audio_config']['n_fft']
self.hop_length = config['audio_config']['hop_length']
self.batch_max_steps = config['batch_max_steps']
self.batch_max_frames = self.batch_max_steps // self.hop_length
self.aux_context_window = 0
self.start_offset = self.aux_context_window
self.end_offset = -(self.batch_max_frames + self.aux_context_window)
self.nsf_enable = (
config['Model']['Generator']['params'].get('nsf_params', None)
is not None)
super().__init__(metafile, root_dir)
# Load from training data directory
if len(self.meta) == 0 and isinstance(root_dir, str):
wav_dir = os.path.join(root_dir, 'wav')
mel_dir = os.path.join(root_dir, 'mel')
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
raise ValueError('wav or mel directory not found')
self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
elif len(self.meta) == 0 and isinstance(root_dir, list):
for d in root_dir:
wav_dir = os.path.join(d, 'wav')
mel_dir = os.path.join(d, 'mel')
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
raise ValueError('wav or mel directory not found')
self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
self.allow_cache = config['allow_cache']
if self.allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(self.meta))]
@staticmethod
def gen_metafile(wav_dir, out_dir, split_ratio=0.98):
wav_files = glob.glob(os.path.join(wav_dir, '*.wav'))
frame_f0_dir = os.path.join(out_dir, 'frame_f0')
frame_uv_dir = os.path.join(out_dir, 'frame_uv')
mel_dir = os.path.join(out_dir, 'mel')
random.seed(DATASET_RANDOM_SEED)
random.shuffle(wav_files)
num_train = int(len(wav_files) * split_ratio) - 1
with open(os.path.join(out_dir, 'train.lst'), 'w') as f:
for wav_file in wav_files[:num_train]:
index = os.path.splitext(os.path.basename(wav_file))[0]
if (not os.path.exists(
os.path.join(frame_f0_dir, index + '.npy'))
or not os.path.exists(
os.path.join(frame_uv_dir, index + '.npy'))
or not os.path.exists(
os.path.join(mel_dir, index + '.npy'))):
continue
f.write('{}\n'.format(index))
with open(os.path.join(out_dir, 'valid.lst'), 'w') as f:
for wav_file in wav_files[num_train:]:
index = os.path.splitext(os.path.basename(wav_file))[0]
if (not os.path.exists(
os.path.join(frame_f0_dir, index + '.npy'))
or not os.path.exists(
os.path.join(frame_uv_dir, index + '.npy'))
or not os.path.exists(
os.path.join(mel_dir, index + '.npy'))):
continue
f.write('{}\n'.format(index))
def load_meta(self, metafile, data_dir):
with open(metafile, 'r') as f:
lines = f.readlines()
wav_dir = os.path.join(data_dir, 'wav')
mel_dir = os.path.join(data_dir, 'mel')
frame_f0_dir = os.path.join(data_dir, 'frame_f0')
frame_uv_dir = os.path.join(data_dir, 'frame_uv')
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
raise ValueError('wav or mel directory not found')
items = []
logging.info('Loading metafile...')
for name in tqdm(lines):
name = name.strip()
mel_file = os.path.join(mel_dir, name + '.npy')
wav_file = os.path.join(wav_dir, name + '.wav')
frame_f0_file = os.path.join(frame_f0_dir, name + '.npy')
frame_uv_file = os.path.join(frame_uv_dir, name + '.npy')
items.append((wav_file, mel_file, frame_f0_file, frame_uv_file))
return items
def load_meta_from_dir(self, wav_dir, mel_dir):
wav_files = glob.glob(os.path.join(wav_dir, '*.wav'))
items = []
for wav_file in wav_files:
mel_file = os.path.join(mel_dir, os.path.basename(wav_file))
if os.path.exists(mel_file):
items.append((wav_file, mel_file))
return items
def __len__(self):
return len(self.meta)
def __getitem__(self, idx):
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx]
wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0]
mel_data = np.load(mel_file)
if self.nsf_enable:
frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data),
axis=1)
# make sure the audio length and feature length match
wav_data = np.pad(wav_data, (0, self.n_fft), mode='reflect')
wav_data = wav_data[:len(mel_data) * self.hop_length]
assert len(mel_data) * self.hop_length == len(wav_data)
if self.allow_cache:
self.caches[idx] = (wav_data, mel_data)
return (wav_data, mel_data)
def collate_fn(self, batch):
wav_data, mel_data = [item[0]
for item in batch], [item[1] for item in batch]
mel_lengths = [len(mel) for mel in mel_data]
start_frames = np.array([
np.random.randint(self.start_offset, length + self.end_offset)
for length in mel_lengths
])
wav_start = start_frames * self.hop_length
wav_end = wav_start + self.batch_max_steps
# aux window works as padding
mel_start = start_frames - self.aux_context_window
mel_end = mel_start + self.batch_max_frames + self.aux_context_window
wav_batch = [
x[start:end] for x, start, end in zip(wav_data, wav_start, wav_end)
]
mel_batch = [
c[start:end] for c, start, end in zip(mel_data, mel_start, mel_end)
]
# (B, 1, T)
wav_batch = torch.tensor(
np.asarray(wav_batch), dtype=torch.float32).unsqueeze(1)
# (B, C, T)
mel_batch = torch.tensor(
np.asarray(mel_batch), dtype=torch.float32).transpose(2, 1)
return wav_batch, mel_batch
def get_voc_datasets(
config,
root_dir,
split_ratio=0.98,
):
if isinstance(root_dir, str):
root_dir = [root_dir]
train_meta_lst = []
valid_meta_lst = []
for data_dir in root_dir:
train_meta = os.path.join(data_dir, 'train.lst')
valid_meta = os.path.join(data_dir, 'valid.lst')
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
VocDataset.gen_metafile(
os.path.join(data_dir, 'wav'), data_dir, split_ratio)
train_meta_lst.append(train_meta)
valid_meta_lst.append(valid_meta)
train_dataset = VocDataset(
train_meta_lst,
root_dir,
config,
)
valid_dataset = VocDataset(
valid_meta_lst,
root_dir,
config,
)
return train_dataset, valid_dataset
def get_fp_label(aug_ling_txt):
token_lst = aug_ling_txt.split(' ')
emo_lst = [token.strip('{}').split('$')[4] for token in token_lst]
syllable_lst = [token.strip('{}').split('$')[0] for token in token_lst]
# EOS token append
emo_lst.append(emotion_types[0])
syllable_lst.append('EOS')
# According to the original emotion tag, set each token's fp label.
if emo_lst[0] != emotion_types[3]:
emo_lst[0] = emotion_types[0]
emo_lst[1] = emotion_types[0]
for i in range(len(emo_lst) - 2, 1, -1):
if emo_lst[i] != emotion_types[3] and emo_lst[i
- 1] != emotion_types[3]:
emo_lst[i] = emotion_types[0]
elif emo_lst[i] != emotion_types[3] and emo_lst[
i - 1] == emotion_types[3]:
emo_lst[i] = emotion_types[3]
if syllable_lst[i - 2] == 'ga':
emo_lst[i + 1] = emotion_types[1]
elif syllable_lst[i - 2] == 'ge' and syllable_lst[i - 1] == 'en_c':
emo_lst[i + 1] = emotion_types[2]
else:
emo_lst[i + 1] = emotion_types[4]
fp_label = []
for i in range(len(emo_lst)):
if emo_lst[i] == emotion_types[0]:
fp_label.append(0)
elif emo_lst[i] == emotion_types[1]:
fp_label.append(1)
elif emo_lst[i] == emotion_types[2]:
fp_label.append(2)
elif emo_lst[i] == emotion_types[3]:
continue
elif emo_lst[i] == emotion_types[4]:
fp_label.append(3)
else:
pass
return np.array(fp_label)
class AmDataset(KanttsDataset):
"""
Provides (ling, emo, speaker, mel) tuples.
"""
def __init__(
self,
metafile,
root_dir,
config,
lang_dir=None,
allow_cache=False,
):
self.config = config
self.with_duration = True
self.nsf_enable = self.config['Model']['KanTtsSAMBERT']['params'].get(
'NSF', False)
self.fp_enable = self.config['Model']['KanTtsSAMBERT']['params'].get(
'FP', False)
super().__init__(metafile, root_dir)
self.allow_cache = allow_cache
self.ling_unit = KanTtsLinguisticUnit(config, lang_dir)
self.padder = Padder()
self.r = self.config['Model']['KanTtsSAMBERT']['params'][
'outputs_per_step']
if allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(self.meta))]
def __len__(self):
return len(self.meta)
def __getitem__(self, idx):
if self.allow_cache and len(self.caches[idx]) != 0:
return self.caches[idx]
(
ling_txt,
mel_file,
dur_file,
f0_file,
energy_file,
frame_f0_file,
frame_uv_file,
aug_ling_txt,
) = self.meta[idx]
ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
mel_data = np.load(mel_file)
dur_data = np.load(dur_file) if dur_file is not None else None
f0_data = np.load(f0_file)
energy_data = np.load(energy_file)
# generate fp position label according to fpadd_meta
if self.fp_enable and aug_ling_txt is not None:
fp_label = get_fp_label(aug_ling_txt)
else:
fp_label = None
if self.with_duration:
attn_prior = None
else:
attn_prior = beta_binomial_prior_distribution(
len(ling_data[0]), mel_data.shape[0])
# Concat frame-level f0 and uv to mel_data
if self.nsf_enable:
frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data],
axis=1)
if self.allow_cache:
self.caches[idx] = (
ling_data,
mel_data,
dur_data,
f0_data,
energy_data,
attn_prior,
fp_label,
)
return (
ling_data,
mel_data,
dur_data,
f0_data,
energy_data,
attn_prior,
fp_label,
)
def load_meta(self, metafile, data_dir):
with open(metafile, 'r') as f:
lines = f.readlines()
aug_ling_dict = {}
if self.fp_enable:
add_fp_metafile = metafile.replace('fprm', 'fpadd')
with open(add_fp_metafile, 'r') as f:
fpadd_lines = f.readlines()
for line in fpadd_lines:
index, aug_ling_txt = line.split('\t')
aug_ling_dict[index] = aug_ling_txt
mel_dir = os.path.join(data_dir, 'mel')
dur_dir = os.path.join(data_dir, 'duration')
f0_dir = os.path.join(data_dir, 'f0')
energy_dir = os.path.join(data_dir, 'energy')
frame_f0_dir = os.path.join(data_dir, 'frame_f0')
frame_uv_dir = os.path.join(data_dir, 'frame_uv')
self.with_duration = os.path.exists(dur_dir)
items = []
logging.info('Loading metafile...')
for line in tqdm(lines):
line = line.strip()
index, ling_txt = line.split('\t')
mel_file = os.path.join(mel_dir, index + '.npy')
if self.with_duration:
dur_file = os.path.join(dur_dir, index + '.npy')
else:
dur_file = None
f0_file = os.path.join(f0_dir, index + '.npy')
energy_file = os.path.join(energy_dir, index + '.npy')
frame_f0_file = os.path.join(frame_f0_dir, index + '.npy')
frame_uv_file = os.path.join(frame_uv_dir, index + '.npy')
aug_ling_txt = aug_ling_dict.get(index, None)
if self.fp_enable and aug_ling_txt is None:
logging.warning(f'Missing fpadd meta for {index}')
continue
items.append((
ling_txt,
mel_file,
dur_file,
f0_file,
energy_file,
frame_f0_file,
frame_uv_file,
aug_ling_txt,
))
return items
def load_fpadd_meta(self, metafile):
with open(metafile, 'r') as f:
lines = f.readlines()
items = []
logging.info('Loading fpadd metafile...')
for line in tqdm(lines):
line = line.strip()
index, ling_txt = line.split('\t')
items.append((ling_txt, ))
return items
@staticmethod
def gen_metafile(
raw_meta_file,
out_dir,
train_meta_file,
valid_meta_file,
badlist=None,
split_ratio=0.98,
):
with open(raw_meta_file, 'r') as f:
lines = f.readlines()
frame_f0_dir = os.path.join(out_dir, 'frame_f0')
frame_uv_dir = os.path.join(out_dir, 'frame_uv')
mel_dir = os.path.join(out_dir, 'mel')
duration_dir = os.path.join(out_dir, 'duration')
random.seed(DATASET_RANDOM_SEED)
random.shuffle(lines)
num_train = int(len(lines) * split_ratio) - 1
with open(train_meta_file, 'w') as f:
for line in lines[:num_train]:
index = line.split('\t')[0]
if badlist is not None and index in badlist:
continue
if (not os.path.exists(
os.path.join(frame_f0_dir, index + '.npy'))
or not os.path.exists(
os.path.join(frame_uv_dir, index + '.npy'))
or not os.path.exists(
os.path.join(duration_dir, index + '.npy'))
or not os.path.exists(
os.path.join(mel_dir, index + '.npy'))):
continue
f.write(line)
with open(valid_meta_file, 'w') as f:
for line in lines[num_train:]:
index = line.split('\t')[0]
if badlist is not None and index in badlist:
continue
if (not os.path.exists(
os.path.join(frame_f0_dir, index + '.npy'))
or not os.path.exists(
os.path.join(frame_uv_dir, index + '.npy'))
or not os.path.exists(
os.path.join(duration_dir, index + '.npy'))
or not os.path.exists(
os.path.join(mel_dir, index + '.npy'))):
continue
f.write(line)
def collate_fn(self, batch):
data_dict = {}
max_input_length = max((len(x[0][0]) for x in batch))
max_dur_length = max((x[2].shape[0] for x in batch)) + 1
# pure linguistic info: sy|tone|syllable_flag|word_segment
lfeat_type = self.ling_unit._lfeat_type_list[0]
inputs_sy = self.padder._prepare_scalar_inputs(
[x[0][0] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# tone
lfeat_type = self.ling_unit._lfeat_type_list[1]
inputs_tone = self.padder._prepare_scalar_inputs(
[x[0][1] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# syllable_flag
lfeat_type = self.ling_unit._lfeat_type_list[2]
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
[x[0][2] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# word_segment
lfeat_type = self.ling_unit._lfeat_type_list[3]
inputs_ws = self.padder._prepare_scalar_inputs(
[x[0][3] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# emotion category
lfeat_type = self.ling_unit._lfeat_type_list[4]
data_dict['input_emotions'] = self.padder._prepare_scalar_inputs(
[x[0][4] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# speaker category
lfeat_type = self.ling_unit._lfeat_type_list[5]
data_dict['input_speakers'] = self.padder._prepare_scalar_inputs(
[x[0][5] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# fp label category
if self.fp_enable:
data_dict['fp_label'] = self.padder._prepare_scalar_inputs(
[x[6] for x in batch],
max_input_length,
0,
).long()
data_dict['input_lings'] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
data_dict['valid_input_lengths'] = torch.as_tensor(
[len(x[0][0]) - 1 for x in batch], dtype=torch.long
) # the input symbol sequence gets a trailing '~' appended, which would affect the duration computation, so subtract 1 from the length
data_dict['valid_output_lengths'] = torch.as_tensor(
[len(x[1]) for x in batch], dtype=torch.long)
max_output_length = torch.max(data_dict['valid_output_lengths']).item()
max_output_round_length = self.padder._round_up(
max_output_length, self.r)
data_dict['mel_targets'] = self.padder._prepare_targets(
[x[1] for x in batch], max_output_round_length, 0.0)
if self.with_duration:
data_dict['durations'] = self.padder._prepare_durations(
[x[2] for x in batch], max_dur_length, max_output_round_length)
else:
data_dict['durations'] = None
if self.with_duration:
if self.fp_enable:
feats_padding_length = max_dur_length
else:
feats_padding_length = max_input_length
else:
feats_padding_length = max_output_round_length
data_dict['pitch_contours'] = self.padder._prepare_scalar_inputs(
[x[3] for x in batch], feats_padding_length, 0.0).float()
data_dict['energy_contours'] = self.padder._prepare_scalar_inputs(
[x[4] for x in batch], feats_padding_length, 0.0).float()
if self.with_duration:
data_dict['attn_priors'] = None
else:
data_dict['attn_priors'] = torch.zeros(
len(batch), max_output_round_length, max_input_length)
for i in range(len(batch)):
attn_prior = batch[i][5]
data_dict['attn_priors'][
i, :attn_prior.shape[0], :attn_prior.shape[1]] = attn_prior
return data_dict
def get_am_datasets(
metafile,
root_dir,
lang_dir,
config,
allow_cache,
split_ratio=0.98,
):
if not isinstance(root_dir, list):
root_dir = [root_dir]
if not isinstance(metafile, list):
metafile = [metafile]
train_meta_lst = []
valid_meta_lst = []
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
if fp_enable:
am_train_fn = 'am_fprm_train.lst'
am_valid_fn = 'am_fprm_valid.lst'
else:
am_train_fn = 'am_train.lst'
am_valid_fn = 'am_valid.lst'
for raw_metafile, data_dir in zip(metafile, root_dir):
train_meta = os.path.join(data_dir, am_train_fn)
valid_meta = os.path.join(data_dir, am_valid_fn)
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
AmDataset.gen_metafile(raw_metafile, data_dir, train_meta,
valid_meta, split_ratio)
train_meta_lst.append(train_meta)
valid_meta_lst.append(valid_meta)
train_dataset = AmDataset(train_meta_lst, root_dir, config, lang_dir,
allow_cache)
valid_dataset = AmDataset(valid_meta_lst, root_dir, config, lang_dir,
allow_cache)
return train_dataset, valid_dataset
class MaskingActor(object):
def __init__(self, mask_ratio=0.15):
super(MaskingActor, self).__init__()
self.mask_ratio = mask_ratio
pass
def _get_random_mask(self, length, p1=0.15):
mask = np.random.uniform(0, 1, length)
index = 0
while index < len(mask):
if mask[index] < p1:
mask[index] = 1
else:
mask[index] = 0
index += 1
return mask
def _input_bert_masking(
self,
sequence_array,
nb_symbol_category,
mask_symbol_id,
mask,
p2=0.8,
p3=0.1,
p4=0.1,
):
sequence_array_mask = sequence_array.copy()
mask_id = np.where(mask == 1)[0]
mask_len = len(mask_id)
rand = np.arange(mask_len)
np.random.shuffle(rand)
# [MASK]
mask_id_p2 = mask_id[rand[0:int(math.floor(mask_len * p2))]]
if len(mask_id_p2) > 0:
sequence_array_mask[mask_id_p2] = mask_symbol_id
# rand
mask_id_p3 = mask_id[
rand[int(math.floor(mask_len * p2)):int(math.floor(mask_len * p2))
+ int(math.floor(mask_len * p3))]]
if len(mask_id_p3) > 0:
sequence_array_mask[mask_id_p3] = random.randint(
0, nb_symbol_category - 1)
# ori
# do nothing
return sequence_array_mask
class BERTTextDataset(torch.utils.data.Dataset):
"""
Provides (ling, ling_sy_masked, bert_mask) tuples.
"""
def __init__(
self,
config,
metafile,
root_dir,
lang_dir=None,
allow_cache=False,
):
self.meta = []
self.config = config
if not isinstance(metafile, list):
metafile = [metafile]
if not isinstance(root_dir, list):
root_dir = [root_dir]
for meta_file, data_dir in zip(metafile, root_dir):
if not os.path.exists(meta_file):
logging.error('meta file not found: {}'.format(meta_file))
raise ValueError(
'[BERT_Text_Dataset] meta file: {} not found'.format(
meta_file))
if not os.path.exists(data_dir):
logging.error('data dir not found: {}'.format(data_dir))
raise ValueError(
'[BERT_Text_Dataset] data dir: {} not found'.format(
data_dir))
self.meta.extend(self.load_meta(meta_file, data_dir))
self.allow_cache = allow_cache
self.ling_unit = KanTtsLinguisticUnit(config, lang_dir)
self.padder = Padder()
self.masking_actor = MaskingActor(
self.config['Model']['KanTtsTextsyBERT']['params']['mask_ratio'])
if allow_cache:
self.manager = Manager()
self.caches = self.manager.list()
self.caches += [() for _ in range(len(self.meta))]
def __len__(self):
return len(self.meta)
def __getitem__(self, idx):
if self.allow_cache and len(self.caches[idx]) != 0:
ling_data = self.caches[idx][0]
bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
return (ling_data, ling_sy_masked_data, bert_mask)
ling_txt = self.meta[idx]
ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
if self.allow_cache:
self.caches[idx] = (ling_data, )
return (ling_data, ling_sy_masked_data, bert_mask)
def load_meta(self, metafile, data_dir):
with open(metafile, 'r') as f:
lines = f.readlines()
items = []
logging.info('Loading metafile...')
for line in tqdm(lines):
line = line.strip()
index, ling_txt = line.split('\t')
items.append(ling_txt)
return items
@staticmethod
def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
with open(raw_meta_file, 'r') as f:
lines = f.readlines()
random.seed(DATASET_RANDOM_SEED)
random.shuffle(lines)
num_train = int(len(lines) * split_ratio) - 1
with open(os.path.join(out_dir, 'bert_train.lst'), 'w') as f:
for line in lines[:num_train]:
f.write(line)
with open(os.path.join(out_dir, 'bert_valid.lst'), 'w') as f:
for line in lines[num_train:]:
f.write(line)
def bert_masking(self, ling_data):
length = len(ling_data[0])
mask = self.masking_actor._get_random_mask(
length, p1=self.masking_actor.mask_ratio)
mask[-1] = 0
# sy_masked
sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
ling_sy_masked_data = self.masking_actor._input_bert_masking(
ling_data[0],
self.ling_unit.get_unit_size()['sy'],
sy_mask_symbol_id,
mask,
p2=0.8,
p3=0.1,
p4=0.1,
)
return (mask, ling_sy_masked_data)
def collate_fn(self, batch):
data_dict = {}
max_input_length = max((len(x[0][0]) for x in batch))
# pure linguistic info: sy|tone|syllable_flag|word_segment
# sy
lfeat_type = self.ling_unit._lfeat_type_list[0]
targets_sy = self.padder._prepare_scalar_inputs(
[x[0][0] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# sy masked
inputs_sy = self.padder._prepare_scalar_inputs(
[x[1] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# tone
lfeat_type = self.ling_unit._lfeat_type_list[1]
inputs_tone = self.padder._prepare_scalar_inputs(
[x[0][1] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# syllable_flag
lfeat_type = self.ling_unit._lfeat_type_list[2]
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
[x[0][2] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
# word_segment
lfeat_type = self.ling_unit._lfeat_type_list[3]
inputs_ws = self.padder._prepare_scalar_inputs(
[x[0][3] for x in batch],
max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type],
).long()
data_dict['input_lings'] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
data_dict['valid_input_lengths'] = torch.as_tensor(
[len(x[0][0]) - 1 for x in batch], dtype=torch.long
) # the input symbol sequence gets a trailing '~' appended, which would affect the duration computation, so subtract 1 from the length
data_dict['targets'] = targets_sy
data_dict['bert_masks'] = self.padder._prepare_scalar_inputs(
[x[2] for x in batch], max_input_length, 0.0)
return data_dict
def get_bert_text_datasets(
metafile,
root_dir,
config,
allow_cache,
split_ratio=0.98,
):
if not isinstance(root_dir, list):
root_dir = [root_dir]
if not isinstance(metafile, list):
metafile = [metafile]
train_meta_lst = []
valid_meta_lst = []
for raw_metafile, data_dir in zip(metafile, root_dir):
train_meta = os.path.join(data_dir, 'bert_train.lst')
valid_meta = os.path.join(data_dir, 'bert_valid.lst')
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
BERTTextDataset.gen_metafile(raw_metafile, data_dir, split_ratio)
train_meta_lst.append(train_meta)
valid_meta_lst.append(valid_meta)
# pass allow_cache by keyword so it does not bind to the lang_dir parameter
train_dataset = BERTTextDataset(config, train_meta_lst, root_dir,
allow_cache=allow_cache)
valid_dataset = BERTTextDataset(config, valid_meta_lst, root_dir,
allow_cache=allow_cache)
return train_dataset, valid_dataset
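
A minimal sketch of driving the vocoder dataset factory defined above, assuming a preprocessed voice directory with wav/ and mel/ subfolders and a config dict shaped like the ones this module reads; the directory name and numeric values are illustrative only.

from torch.utils.data import DataLoader

from modelscope.models.audio.tts.kantts.datasets.dataset import get_voc_datasets

voc_config = {
    'audio_config': {'sampling_rate': 16000, 'n_fft': 1024, 'hop_length': 200},
    'batch_max_steps': 9600,  # 9600 samples = 48 mel frames per training clip
    'allow_cache': False,
    'Model': {'Generator': {'params': {}}},  # no nsf_params -> NSF branch disabled
}
train_set, valid_set = get_voc_datasets(voc_config, root_dir='./training_stage')
train_loader = DataLoader(
    train_set, batch_size=16, shuffle=True, collate_fn=train_set.collate_fn)
# each batch: wav (B, 1, batch_max_steps), mel (B, C, batch_max_steps // hop_length)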

View File

@@ -0,0 +1,158 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from torch.nn.parallel import DistributedDataParallel
import modelscope.models.audio.tts.kantts.train.scheduler as kantts_scheduler
from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import \
get_fpdict
from .hifigan import (Generator, MultiPeriodDiscriminator,
MultiScaleDiscriminator, MultiSpecDiscriminator)
from .pqmf import PQMF
from .sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT
def optimizer_builder(model_params, opt_name, opt_params):
opt_cls = getattr(torch.optim, opt_name)
optimizer = opt_cls(model_params, **opt_params)
return optimizer
def scheduler_builder(optimizer, sche_name, sche_params):
scheduler_cls = getattr(kantts_scheduler, sche_name)
scheduler = scheduler_cls(optimizer, **sche_params)
return scheduler
def hifigan_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model['discriminator'] = {}
optimizer['discriminator'] = {}
scheduler['discriminator'] = {}
for model_name in config['Model'].keys():
if model_name == 'Generator':
params = config['Model'][model_name]['params']
model['generator'] = Generator(**params).to(device)
optimizer['generator'] = optimizer_builder(
model['generator'].parameters(),
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
config['Model'][model_name]['optimizer'].get('params', {}),
)
scheduler['generator'] = scheduler_builder(
optimizer['generator'],
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
config['Model'][model_name]['scheduler'].get('params', {}),
)
else:
params = config['Model'][model_name]['params']
model['discriminator'][model_name] = globals()[model_name](
**params).to(device)
optimizer['discriminator'][model_name] = optimizer_builder(
model['discriminator'][model_name].parameters(),
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
config['Model'][model_name]['optimizer'].get('params', {}),
)
scheduler['discriminator'][model_name] = scheduler_builder(
optimizer['discriminator'][model_name],
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
config['Model'][model_name]['scheduler'].get('params', {}),
)
out_channels = config['Model']['Generator']['params']['out_channels']
if out_channels > 1:
model['pqmf'] = PQMF(
subbands=out_channels, **config.get('pqmf', {})).to(device)
# FIXME: pywavelets buffer leads to gradient error in DDP training
# Solution: https://github.com/pytorch/pytorch/issues/22095
if distributed:
model['generator'] = DistributedDataParallel(
model['generator'],
device_ids=[rank],
output_device=rank,
broadcast_buffers=False,
)
for model_name in model['discriminator'].keys():
model['discriminator'][model_name] = DistributedDataParallel(
model['discriminator'][model_name],
device_ids=[rank],
output_device=rank,
broadcast_buffers=False,
)
return model, optimizer, scheduler
def sambert_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model['KanTtsSAMBERT'] = KanTtsSAMBERT(
config['Model']['KanTtsSAMBERT']['params']).to(device)
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
if fp_enable:
fp_dict = {
k: torch.from_numpy(v).long().unsqueeze(0).to(device)
for k, v in get_fpdict(config).items()
}
model['KanTtsSAMBERT'].fp_dict = fp_dict
optimizer['KanTtsSAMBERT'] = optimizer_builder(
model['KanTtsSAMBERT'].parameters(),
config['Model']['KanTtsSAMBERT']['optimizer'].get('type', 'Adam'),
config['Model']['KanTtsSAMBERT']['optimizer'].get('params', {}),
)
scheduler['KanTtsSAMBERT'] = scheduler_builder(
optimizer['KanTtsSAMBERT'],
config['Model']['KanTtsSAMBERT']['scheduler'].get('type', 'StepLR'),
config['Model']['KanTtsSAMBERT']['scheduler'].get('params', {}),
)
if distributed:
model['KanTtsSAMBERT'] = DistributedDataParallel(
model['KanTtsSAMBERT'], device_ids=[rank], output_device=rank)
return model, optimizer, scheduler
def sybert_model_builder(config, device, rank, distributed):
model = {}
optimizer = {}
scheduler = {}
model['KanTtsTextsyBERT'] = KanTtsTextsyBERT(
config['Model']['KanTtsTextsyBERT']['params']).to(device)
optimizer['KanTtsTextsyBERT'] = optimizer_builder(
model['KanTtsTextsyBERT'].parameters(),
config['Model']['KanTtsTextsyBERT']['optimizer'].get('type', 'Adam'),
config['Model']['KanTtsTextsyBERT']['optimizer'].get('params', {}),
)
scheduler['KanTtsTextsyBERT'] = scheduler_builder(
optimizer['KanTtsTextsyBERT'],
config['Model']['KanTtsTextsyBERT']['scheduler'].get('type', 'StepLR'),
config['Model']['KanTtsTextsyBERT']['scheduler'].get('params', {}),
)
if distributed:
model['KanTtsTextsyBERT'] = DistributedDataParallel(
model['KanTtsTextsyBERT'], device_ids=[rank], output_device=rank)
return model, optimizer, scheduler
model_dict = {
'hifigan': hifigan_model_builder,
'sambert': sambert_model_builder,
'sybert': sybert_model_builder,
}
def model_builder(config, device='cpu', rank=0, distributed=False):
builder_func = model_dict[config['model_type']]
model, optimizer, scheduler = builder_func(config, device, rank,
distributed)
return model, optimizer, scheduler
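
A minimal sketch of calling model_builder, assuming the model's yaml configuration has already been written to disk; the 'model_type' key selects among the hifigan/sambert/sybert builders registered above, and the config path is hypothetical.

import torch
import yaml

from modelscope.models.audio.tts.kantts.models import model_builder

with open('config.yaml') as f:  # hypothetical path to the AM/vocoder config
    config = yaml.safe_load(f)
config['model_type'] = 'sambert'  # or 'hifigan' / 'sybert'

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, optimizer, scheduler = model_builder(
    config, device, rank=0, distributed=False)
acoustic_model = model['KanTtsSAMBERT']  # ready for training or inference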

View File

@@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .hifigan import (Generator, MultiPeriodDiscriminator,
MultiScaleDiscriminator, MultiSpecDiscriminator)

View File

@@ -0,0 +1,613 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import copy
from distutils.version import LooseVersion
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_wavelets import DWT1DForward
from torch.nn.utils import spectral_norm, weight_norm
from modelscope.models.audio.tts.kantts.utils.audio_torch import stft
from .layers import (CausalConv1d, CausalConvTranspose1d, Conv1d,
ConvTranspose1d, ResidualBlock, SourceModule)
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
class Generator(torch.nn.Module):
def __init__(
self,
in_channels=80,
out_channels=1,
channels=512,
kernel_size=7,
upsample_scales=(8, 8, 2, 2),
upsample_kernal_sizes=(16, 16, 4, 4),
resblock_kernel_sizes=(3, 7, 11),
resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
repeat_upsample=True,
bias=True,
causal=True,
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
use_weight_norm=True,
nsf_params=None,
):
super(Generator, self).__init__()
# check hyperparameters are valid
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
assert len(upsample_scales) == len(upsample_kernal_sizes)
assert len(resblock_dilations) == len(resblock_kernel_sizes)
self.upsample_scales = upsample_scales
self.repeat_upsample = repeat_upsample
self.num_upsamples = len(upsample_kernal_sizes)
self.num_kernels = len(resblock_kernel_sizes)
self.out_channels = out_channels
self.nsf_enable = nsf_params is not None
self.transpose_upsamples = torch.nn.ModuleList()
self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
self.conv_blocks = torch.nn.ModuleList()
conv_cls = CausalConv1d if causal else Conv1d
conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
self.conv_pre = conv_cls(
in_channels,
channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2)
for i in range(len(upsample_kernal_sizes)):
self.transpose_upsamples.append(
torch.nn.Sequential(
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
conv_transposed_cls(
channels // (2**i),
channels // (2**(i + 1)),
upsample_kernal_sizes[i],
upsample_scales[i],
padding=(upsample_kernal_sizes[i] - upsample_scales[i])
// 2,
),
))
if repeat_upsample:
self.repeat_upsamples.append(
nn.Sequential(
nn.Upsample(
mode='nearest', scale_factor=upsample_scales[i]),
getattr(torch.nn, nonlinear_activation)(
**nonlinear_activation_params),
conv_cls(
channels // (2**i),
channels // (2**(i + 1)),
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
),
))
for j in range(len(resblock_kernel_sizes)):
self.conv_blocks.append(
ResidualBlock(
channels=channels // (2**(i + 1)),
kernel_size=resblock_kernel_sizes[j],
dilation=resblock_dilations[j],
nonlinear_activation=nonlinear_activation,
nonlinear_activation_params=nonlinear_activation_params,
causal=causal,
))
self.conv_post = conv_cls(
channels // (2**(i + 1)),
out_channels,
kernel_size,
1,
padding=(kernel_size - 1) // 2,
)
if self.nsf_enable:
self.source_module = SourceModule(
nb_harmonics=nsf_params['nb_harmonics'],
upsample_ratio=np.cumprod(self.upsample_scales)[-1],
sampling_rate=nsf_params['sampling_rate'],
)
self.source_downs = nn.ModuleList()
self.downsample_rates = [1] + self.upsample_scales[::-1][:-1]
self.downsample_cum_rates = np.cumprod(self.downsample_rates)
for i, u in enumerate(self.downsample_cum_rates[::-1]):
if u == 1:
self.source_downs.append(
Conv1d(1, channels // (2**(i + 1)), 1, 1))
else:
self.source_downs.append(
conv_cls(
1,
channels // (2**(i + 1)),
u * 2,
u,
padding=u // 2,
))
def forward(self, x):
if self.nsf_enable:
mel = x[:, :-2, :]
pitch = x[:, -2:-1, :]
uv = x[:, -1:, :]
excitation = self.source_module(pitch, uv)
else:
mel = x
x = self.conv_pre(mel)
for i in range(self.num_upsamples):
# FIXME: sin function here seems to be causing issues
x = torch.sin(x) + x
rep = self.repeat_upsamples[i](x)
if self.nsf_enable:
# Downsampling the excitation signal
e = self.source_downs[i](excitation)
# augment inputs with the excitation
x = rep + e
else:
# transconv
up = self.transpose_upsamples[i](x)
x = rep + up
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.conv_blocks[i * self.num_kernels + j](x)
else:
xs += self.conv_blocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
print('Removing weight norm...')
for layer in self.transpose_upsamples:
layer[-1].remove_weight_norm()
for layer in self.repeat_upsamples:
layer[-1].remove_weight_norm()
for layer in self.conv_blocks:
layer.remove_weight_norm()
self.conv_pre.remove_weight_norm()
self.conv_post.remove_weight_norm()
if self.nsf_enable:
self.source_module.remove_weight_norm()
for layer in self.source_downs:
layer.remove_weight_norm()
class PeriodDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
period=3,
kernel_sizes=[5, 3],
channels=32,
downsample_scales=[3, 3, 3, 3, 1],
max_downsample_channels=1024,
bias=True,
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
use_spectral_norm=False,
):
super(PeriodDiscriminator, self).__init__()
self.period = period
norm_f = weight_norm if not use_spectral_norm else spectral_norm
self.convs = nn.ModuleList()
in_chs, out_chs = in_channels, channels
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
in_chs,
out_chs,
(kernel_sizes[0], 1),
(downsample_scale, 1),
padding=((kernel_sizes[0] - 1) // 2, 0),
)),
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
in_chs = out_chs
out_chs = min(out_chs * 4, max_downsample_channels)
self.conv_post = nn.Conv2d(
out_chs,
out_channels,
(kernel_sizes[1] - 1, 1),
1,
padding=((kernel_sizes[1] - 1) // 2, 0),
)
def forward(self, x):
fmap = []
# 1d to 2d
b, c, t = x.shape
if t % self.period != 0: # pad first
n_pad = self.period - (t % self.period)
x = F.pad(x, (0, n_pad), 'reflect')
t = t + n_pad
x = x.view(b, c, t // self.period, self.period)
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiPeriodDiscriminator(torch.nn.Module):
def __init__(
self,
periods=[2, 3, 5, 7, 11],
discriminator_params={
'in_channels': 1,
'out_channels': 1,
'kernel_sizes': [5, 3],
'channels': 32,
'downsample_scales': [3, 3, 3, 3, 1],
'max_downsample_channels': 1024,
'bias': True,
'nonlinear_activation': 'LeakyReLU',
'nonlinear_activation_params': {
'negative_slope': 0.1
},
'use_spectral_norm': False,
},
):
super(MultiPeriodDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for period in periods:
params = copy.deepcopy(discriminator_params)
params['period'] = period
self.discriminators += [PeriodDiscriminator(**params)]
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class ScaleDiscriminator(torch.nn.Module):
def __init__(
self,
in_channels=1,
out_channels=1,
kernel_sizes=[15, 41, 5, 3],
channels=128,
max_downsample_channels=1024,
max_groups=16,
bias=True,
downsample_scales=[2, 2, 4, 4, 1],
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
use_spectral_norm=False,
):
super(ScaleDiscriminator, self).__init__()
norm_f = weight_norm if not use_spectral_norm else spectral_norm
assert len(kernel_sizes) == 4
for ks in kernel_sizes:
assert ks % 2 == 1
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_channels,
channels,
kernel_sizes[0],
bias=bias,
padding=(kernel_sizes[0] - 1) // 2,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
in_chs = channels
out_chs = channels
groups = 4
for downsample_scale in downsample_scales:
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[1],
stride=downsample_scale,
padding=(kernel_sizes[1] - 1) // 2,
groups=groups,
bias=bias,
)),
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
in_chs = out_chs
out_chs = min(in_chs * 2, max_downsample_channels)
groups = min(groups * 4, max_groups)
out_chs = min(in_chs * 2, max_downsample_channels)
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv1d(
in_chs,
out_chs,
kernel_size=kernel_sizes[2],
stride=1,
padding=(kernel_sizes[2] - 1) // 2,
bias=bias,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
self.conv_post = norm_f(
nn.Conv1d(
out_chs,
out_channels,
kernel_size=kernel_sizes[3],
stride=1,
padding=(kernel_sizes[3] - 1) // 2,
bias=bias,
))
def forward(self, x):
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = torch.flatten(x, 1, -1)
return x, fmap
class MultiScaleDiscriminator(torch.nn.Module):
def __init__(
self,
scales=3,
downsample_pooling='DWT',
# follow the official implementation setting
downsample_pooling_params={
'kernel_size': 4,
'stride': 2,
'padding': 2,
},
discriminator_params={
'in_channels': 1,
'out_channels': 1,
'kernel_sizes': [15, 41, 5, 3],
'channels': 128,
'max_downsample_channels': 1024,
'max_groups': 16,
'bias': True,
'downsample_scales': [2, 2, 4, 4, 1],
'nonlinear_activation': 'LeakyReLU',
'nonlinear_activation_params': {
'negative_slope': 0.1
},
},
follow_official_norm=False,
):
super(MultiScaleDiscriminator, self).__init__()
self.discriminators = torch.nn.ModuleList()
# add discriminators
for i in range(scales):
params = copy.deepcopy(discriminator_params)
if follow_official_norm:
params['use_spectral_norm'] = True if i == 0 else False
self.discriminators += [ScaleDiscriminator(**params)]
if downsample_pooling == 'DWT':
self.meanpools = nn.ModuleList(
[DWT1DForward(wave='db3', J=1),
DWT1DForward(wave='db3', J=1)])
self.aux_convs = nn.ModuleList([
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
])
else:
self.meanpools = nn.ModuleList(
[nn.AvgPool1d(4, 2, padding=2),
nn.AvgPool1d(4, 2, padding=2)])
self.aux_convs = None
def forward(self, y):
y_d_rs = []
fmap_rs = []
for i, d in enumerate(self.discriminators):
if i != 0:
if self.aux_convs is None:
y = self.meanpools[i - 1](y)
else:
yl, yh = self.meanpools[i - 1](y)
y = torch.cat([yl, yh[0]], dim=1)
y = self.aux_convs[i - 1](y)
y = F.leaky_relu(y, 0.1)
y_d_r, fmap_r = d(y)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
return y_d_rs, fmap_rs
class SpecDiscriminator(torch.nn.Module):
def __init__(
self,
channels=32,
init_kernel=15,
kernel_size=11,
stride=2,
use_spectral_norm=False,
fft_size=1024,
shift_size=120,
win_length=600,
window='hann_window',
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
):
super(SpecDiscriminator, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
# fft_size // 2 + 1
norm_f = weight_norm if not use_spectral_norm else spectral_norm
final_kernel = 5
post_conv_kernel = 3
blocks = 3
self.convs = nn.ModuleList()
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
fft_size // 2 + 1,
channels,
(init_kernel, 1),
(1, 1),
padding=(init_kernel - 1) // 2,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
for i in range(blocks):
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(kernel_size, 1),
(stride, 1),
padding=(kernel_size - 1) // 2,
)),
getattr(
torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
self.convs.append(
torch.nn.Sequential(
norm_f(
nn.Conv2d(
channels,
channels,
(final_kernel, 1),
(1, 1),
padding=(final_kernel - 1) // 2,
)),
getattr(torch.nn,
nonlinear_activation)(**nonlinear_activation_params),
))
self.conv_post = norm_f(
nn.Conv2d(
channels,
1,
(post_conv_kernel, 1),
(1, 1),
padding=((post_conv_kernel - 1) // 2, 0),
))
self.register_buffer('window', getattr(torch, window)(win_length))
def forward(self, wav):
with torch.no_grad():
wav = torch.squeeze(wav, 1)
x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length,
self.window)
x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
fmap = []
for layer in self.convs:
x = layer(x)
fmap.append(x)
x = self.conv_post(x)
fmap.append(x)
x = x.squeeze(-1)
return x, fmap
class MultiSpecDiscriminator(torch.nn.Module):
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
discriminator_params={
'channels': 15,
'init_kernel': 1,
'kernel_sizes': 11,
'stride': 2,
'use_spectral_norm': False,
'window': 'hann_window',
'nonlinear_activation': 'LeakyReLU',
'nonlinear_activation_params': {
'negative_slope': 0.1
},
},
):
super(MultiSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList()
for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes,
win_lengths):
params = copy.deepcopy(discriminator_params)
params['fft_size'] = fft_size
params['shift_size'] = hop_size
params['win_length'] = win_length
self.discriminators += [SpecDiscriminator(**params)]
def forward(self, y):
y_d = []
fmap = []
for i, d in enumerate(self.discriminators):
x, x_map = d(y)
y_d.append(x)
fmap.append(x_map)
return y_d, fmap
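
A minimal shape smoke test for the Generator added above, assuming the default hyperparameters (80-dim mel input, upsample scales 8*8*2*2 = 256x) and that the kantts dependencies such as pytorch_wavelets are installed.

import torch

from modelscope.models.audio.tts.kantts import Generator

gen = Generator()  # defaults: in_channels=80, upsample_scales=(8, 8, 2, 2)
mel = torch.randn(2, 80, 50)  # (B, n_mels, frames)
with torch.no_grad():
    wav = gen(mel)
print(wav.shape)  # expected: (2, 1, 50 * 256) = (2, 1, 12800)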

View File

@@ -0,0 +1,288 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions.normal import Normal
from torch.distributions.uniform import Uniform
from torch.nn.utils import remove_weight_norm, weight_norm
from modelscope.models.audio.tts.kantts.models.utils import init_weights
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class Conv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
):
super(Conv1d, self).__init__()
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
))
self.conv1d.apply(init_weights)
def forward(self, x):
x = self.conv1d(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class CausalConv1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros',
):
super(CausalConv1d, self).__init__()
self.pad = (kernel_size - 1) * dilation
self.conv1d = weight_norm(
nn.Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode,
))
self.conv1d.apply(init_weights)
def forward(self, x): # bdt
x = F.pad(
x, (self.pad, 0, 0, 0, 0, 0), 'constant'
) # pad widths are specified starting from the last dimension and moving forward
# x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
x = self.conv1d(x)[:, :, :x.size(2)]
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class ConvTranspose1d(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
super(ConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=padding,
output_padding=0,
))
self.deconv.apply(init_weights)
def forward(self, x):
return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
# FIXME: HACK to get shape right
class CausalConvTranspose1d(torch.nn.Module):
"""CausalConvTranspose1d module with customized initialization."""
def __init__(
self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
):
"""Initialize CausalConvTranspose1d module."""
super(CausalConvTranspose1d, self).__init__()
self.deconv = weight_norm(
nn.ConvTranspose1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
output_padding=0,
))
self.stride = stride
self.deconv.apply(init_weights)
self.pad = kernel_size - stride
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_in).
Returns:
Tensor: Output tensor (B, out_channels, T_out).
"""
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
return self.deconv(x)[:, :, :-self.pad]
# return self.deconv(x)
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
class ResidualBlock(torch.nn.Module):
def __init__(
self,
channels,
kernel_size=3,
dilation=(1, 3, 5),
nonlinear_activation='LeakyReLU',
nonlinear_activation_params={'negative_slope': 0.1},
causal=False,
):
super(ResidualBlock, self).__init__()
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
conv_cls = CausalConv1d if causal else Conv1d
self.convs1 = nn.ModuleList([
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=dilation[i],
padding=get_padding(kernel_size, dilation[i]),
) for i in range(len(dilation))
])
self.convs2 = nn.ModuleList([
conv_cls(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
) for i in range(len(dilation))
])
self.activation = getattr(
torch.nn, nonlinear_activation)(**nonlinear_activation_params)
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = self.activation(x)
xt = c1(xt)
xt = self.activation(xt)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for layer in self.convs1:
layer.remove_weight_norm()
for layer in self.convs2:
layer.remove_weight_norm()
class SourceModule(torch.nn.Module):
def __init__(self,
nb_harmonics,
upsample_ratio,
sampling_rate,
alpha=0.1,
sigma=0.003):
super(SourceModule, self).__init__()
self.nb_harmonics = nb_harmonics
self.upsample_ratio = upsample_ratio
self.sampling_rate = sampling_rate
self.alpha = alpha
self.sigma = sigma
self.ffn = nn.Sequential(
weight_norm(
nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
nn.Tanh(),
)
def forward(self, pitch, uv):
"""
:param pitch: [B, 1, frame_len], Hz
:param uv: [B, 1, frame_len] vuv flag
:return: [B, 1, sample_len]
"""
with torch.no_grad():
pitch_samples = F.interpolate(
pitch, scale_factor=(self.upsample_ratio), mode='nearest')
uv_samples = F.interpolate(
uv, scale_factor=(self.upsample_ratio), mode='nearest')
F_mat = torch.zeros(
(pitch_samples.size(0), self.nb_harmonics + 1,
pitch_samples.size(-1))).to(pitch_samples.device)
for i in range(self.nb_harmonics + 1):
F_mat[:, i:i
+ 1, :] = pitch_samples * (i + 1) / self.sampling_rate
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
u_dist = Uniform(low=-np.pi, high=np.pi)
phase_vec = u_dist.sample(
sample_shape=(pitch.size(0), self.nb_harmonics + 1,
1)).to(F_mat.device)
phase_vec[:, 0, :] = 0
n_dist = Normal(loc=0.0, scale=self.sigma)
noise = n_dist.sample(
sample_shape=(
pitch_samples.size(0),
self.nb_harmonics + 1,
pitch_samples.size(-1),
)).to(F_mat.device)
e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
e_unvoice = self.alpha / 3 / self.sigma * noise
e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
return self.ffn(e)
def remove_weight_norm(self):
remove_weight_norm(self.ffn[0])
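
A minimal sketch checking two properties of the CausalConv1d layer defined above: the output keeps the input length, and perturbing a future sample leaves earlier outputs unchanged. The import path is the one implied by this diff.

import torch

from modelscope.models.audio.tts.kantts.models.hifigan.layers import CausalConv1d

conv = CausalConv1d(1, 1, kernel_size=3, dilation=2)
x = torch.randn(1, 1, 16)
y1 = conv(x)
x2 = x.clone()
x2[..., -1] += 1.0  # perturb only the last time step
y2 = conv(x2)
assert y1.shape == x.shape  # length-preserving
assert torch.allclose(y1[..., :-1], y2[..., :-1])  # causal: earlier outputs unchanged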

View File

@@ -0,0 +1,133 @@
# The implementation is adopted from kan-bayashi's ParallelWaveGAN,
# made publicly available under the MIT License at https://github.com/kan-bayashi/ParallelWaveGAN
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import kaiser
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
"""Design prototype filter for PQMF.
This method is based on `A Kaiser window approach for the design of prototype
filters of cosine modulated filterbanks`_.
Args:
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
Returns:
ndarray: Impulse response of prototype filter (taps + 1,).
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
https://ieeexplore.ieee.org/abstract/document/681427
"""
# check the arguments are valid
assert taps % 2 == 0, 'The number of taps must be an even number.'
assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.'
# make initial filter
omega_c = np.pi * cutoff_ratio
with np.errstate(invalid='ignore'):
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
np.pi * (np.arange(taps + 1) - 0.5 * taps))
h_i[taps
// 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
# apply kaiser window
w = kaiser(taps + 1, beta)
h = h_i * w
return h
class PQMF(torch.nn.Module):
"""PQMF module.
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
https://ieeexplore.ieee.org/document/258122
"""
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
"""Initilize PQMF module.
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
Args:
subbands (int): The number of subbands.
taps (int): The number of filter taps.
cutoff_ratio (float): Cut-off frequency ratio.
beta (float): Beta coefficient for kaiser window.
"""
super(PQMF, self).__init__()
# build analysis & synthesis filter coefficients
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
h_analysis = np.zeros((subbands, len(h_proto)))
h_synthesis = np.zeros((subbands, len(h_proto)))
for k in range(subbands):
h_analysis[k] = (
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
(np.pi / (2 * subbands)) * # noqa W504
(np.arange(taps + 1) - (taps / 2))
+ (-1)**k * np.pi / 4))
h_synthesis[k] = (
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
(np.pi / (2 * subbands)) * # noqa W504
(np.arange(taps + 1) - (taps / 2))
- (-1)**k * np.pi / 4))
# convert to tensor
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
# register coefficients as buffer
self.register_buffer('analysis_filter', analysis_filter)
self.register_buffer('synthesis_filter', synthesis_filter)
# filter for downsampling & upsampling
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
for k in range(subbands):
updown_filter[k, k, 0] = 1.0
self.register_buffer('updown_filter', updown_filter)
self.subbands = subbands
# keep padding info
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
def analysis(self, x):
"""Analysis with PQMF.
Args:
x (Tensor): Input tensor (B, 1, T).
Returns:
Tensor: Output tensor (B, subbands, T // subbands).
"""
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
return F.conv1d(x, self.updown_filter, stride=self.subbands)
def synthesis(self, x):
"""Synthesis with PQMF.
Args:
x (Tensor): Input tensor (B, subbands, T // subbands).
Returns:
Tensor: Output tensor (B, 1, T).
"""
# NOTE(kan-bayashi): Power will be decreased, so multiply by # subbands here.
# Not sure this is the correct way; it would be better to check again.
x = F.conv_transpose1d(
x, self.updown_filter * self.subbands, stride=self.subbands)
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
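
A minimal round-trip shape sketch for the PQMF module above, using the default 4-subband configuration; analysis splits the waveform into subband signals at a quarter of the rate, and synthesis maps them back to a full-rate waveform of the same length (near-perfect reconstruction per the referenced paper). The import path is the one implied by this diff.

import torch

from modelscope.models.audio.tts.kantts.models.pqmf import PQMF

pqmf = PQMF(subbands=4)  # defaults: taps=62, cutoff_ratio=0.142, beta=9.0
wav = torch.randn(1, 1, 4 * 4000)  # (B, 1, T) with T divisible by subbands
subband = pqmf.analysis(wav)  # (B, 4, T // 4)
recon = pqmf.synthesis(subband)  # (B, 1, T)
print(subband.shape, recon.shape)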

View File

@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
@@ -42,7 +41,7 @@ class Prenet(nn.Module):
self.fcs.append(nn.ReLU())
self.fcs.append(nn.Dropout(0.5))
if (out_units):
if out_units:
self.fcs.append(nn.Linear(prenet_units[-1], out_units))
def forward(self, input):
@@ -105,7 +104,7 @@ class MultiHeadSelfAttention(nn.Module):
-1)) # b x l x (n*d)
output = self.dropout(self.fc(output))
if (output.size(-1) == residual.size(-1)):
if output.size(-1) == residual.size(-1):
output = output + residual
return output, attn
@@ -162,16 +161,18 @@ class PositionwiseConvFeedForward(nn.Module):
class FFTBlock(nn.Module):
"""FFT Block"""
def __init__(self,
d_in,
d_model,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0):
def __init__(
self,
d_in,
d_model,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(FFTBlock, self).__init__()
self.slf_attn = MultiHeadSelfAttention(
n_head,
@@ -239,7 +240,7 @@ class MultiHeadPNCAAttention(nn.Module):
x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
if (self.x_state_size):
if self.x_state_size:
self.x_k = torch.cat([self.x_k, x_k], dim=1)
self.x_v = torch.cat([self.x_v, x_v], dim=1)
else:
@@ -251,7 +252,7 @@ class MultiHeadPNCAAttention(nn.Module):
return x_q, x_k, x_v
def update_h_state(self, h):
if (self.h_state_size == h.size(1)):
if self.h_state_size == h.size(1):
return None, None
d_head, n_head = self.d_head, self.n_head
@@ -323,16 +324,18 @@ class MultiHeadPNCAAttention(nn.Module):
class PNCABlock(nn.Module):
"""PNCA Block"""
def __init__(self,
d_model,
d_mem,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0):
def __init__(
self,
d_model,
d_mem,
n_head,
d_head,
d_inner,
kernel_size,
dropout,
dropout_attn=0.0,
dropout_relu=0.0,
):
super(PNCABlock, self).__init__()
self.pnca_attn = MultiHeadPNCAAttention(
n_head,

View File

@@ -1,10 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn as nn
import torch.nn.functional as F
from .base import Prenet
from . import Prenet
from .fsmn import FsmnEncoderV2
@@ -22,8 +21,8 @@ class LengthRegulator(nn.Module):
reps_cumsum = torch.cumsum(
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
range_ = torch.arange(max_len).to(inputs.device)[None, :, None]
mult = ((reps_cumsum[:, :, :-1] <= range_)
& (reps_cumsum[:, :, 1:] > range_)) # yapf:disable
mult = (reps_cumsum[:, :, :-1] <= range_) & (
reps_cumsum[:, :, 1:] > range_)
mult = mult.float()
out = torch.matmul(mult, inputs)
@@ -32,7 +31,7 @@ class LengthRegulator(nn.Module):
seq_len = out.size(1)
padding = self.r - int(seq_len) % self.r
if (padding < self.r):
if padding < self.r:
out = F.pad(
out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0)
out = out.transpose(1, 2)
@@ -51,7 +50,8 @@ class VarRnnARPredictor(nn.Module):
rnn_units,
num_layers=2,
batch_first=True,
bidirectional=False)
bidirectional=False,
)
self.fc = nn.Linear(rnn_units, 1)
def forward(self, inputs, cond, h=None, masks=None):
@@ -89,19 +89,35 @@ class VarRnnARPredictor(nn.Module):
class VarFsmnRnnNARPredictor(nn.Module):
def __init__(self, in_dim, filter_size, fsmn_num_layers, num_memory_units,
ffn_inner_dim, dropout, shift, lstm_units):
def __init__(
self,
in_dim,
filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
lstm_units,
):
super(VarFsmnRnnNARPredictor, self).__init__()
self.fsmn = FsmnEncoderV2(filter_size, fsmn_num_layers, in_dim,
num_memory_units, ffn_inner_dim, dropout,
shift)
self.fsmn = FsmnEncoderV2(
filter_size,
fsmn_num_layers,
in_dim,
num_memory_units,
ffn_inner_dim,
dropout,
shift,
)
self.blstm = nn.LSTM(
num_memory_units,
lstm_units,
num_layers=1,
batch_first=True,
bidirectional=True)
bidirectional=True,
)
self.fc = nn.Linear(2 * lstm_units, 1)
def forward(self, inputs, masks=None):

View File

@@ -0,0 +1,73 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numba as nb
import numpy as np
@nb.jit(nopython=True)
def mas(attn_map, width=1):
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_j = np.arange(max(0, j - width), j + 1)
prev_log = np.array(
[log_p[i - 1, prev_idx] for prev_idx in prev_j])
ind = np.argmax(prev_log)
log_p[i, j] = attn_map[i, j] + prev_log[ind]
prev_ind[i, j] = prev_j[ind]
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
@nb.jit(nopython=True)
def mas_width1(attn_map):
"""mas with hardcoded width=1"""
# assumes mel x text
opt = np.zeros_like(attn_map)
attn_map = np.log(attn_map)
attn_map[0, 1:] = -np.inf
log_p = np.zeros_like(attn_map)
log_p[0, :] = attn_map[0, :]
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
for i in range(1, attn_map.shape[0]):
for j in range(attn_map.shape[1]): # for each text dim
prev_log = log_p[i - 1, j]
prev_j = j
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
prev_log = log_p[i - 1, j - 1]
prev_j = j - 1
log_p[i, j] = attn_map[i, j] + prev_log
prev_ind[i, j] = prev_j
# now backtrack
curr_text_idx = attn_map.shape[1] - 1
for i in range(attn_map.shape[0] - 1, -1, -1):
opt[i, curr_text_idx] = 1
curr_text_idx = prev_ind[i, curr_text_idx]
opt[0, curr_text_idx] = 1
return opt
@nb.jit(nopython=True, parallel=True)
def b_mas(b_attn_map, in_lens, out_lens, width=1):
assert width == 1
attn_out = np.zeros_like(b_attn_map)
for b in nb.prange(b_attn_map.shape[0]):
out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]])
attn_out[b, 0, :out_lens[b], :in_lens[b]] = out
return attn_out
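An illustrative sketch of how b_mas could be applied to a batch of soft alignment maps; the shapes, lengths, and random inputs are assumptions, not values from the diff:
import numpy as np
# Hypothetical usage: turn soft attention probabilities into hard monotonic alignments.
B, T_mel, T_text = 2, 120, 40
soft_attn = np.random.rand(B, 1, T_mel, T_text).astype(np.float32)
soft_attn /= soft_attn.sum(-1, keepdims=True)       # normalize over the text axis
text_lens = np.array([40, 35], dtype=np.int64)
mel_lens = np.array([120, 100], dtype=np.int64)
hard_attn = b_mas(soft_attn, text_lens, mel_lens, 1)
# hard_attn[b, 0] selects exactly one text position per mel frame, monotonically increasing.
durations = hard_attn.sum(axis=2)                    # (B, 1, T_text): frames assigned per symbol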

View File

@@ -0,0 +1,131 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
from torch import nn
class ConvNorm(torch.nn.Module):
def __init__(
self,
in_channels,
out_channels,
kernel_size=1,
stride=1,
padding=None,
dilation=1,
bias=True,
w_init_gain='linear',
):
super(ConvNorm, self).__init__()
if padding is None:
assert kernel_size % 2 == 1
padding = int(dilation * (kernel_size - 1) / 2)
self.conv = torch.nn.Conv1d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
dilation=dilation,
bias=bias,
)
torch.nn.init.xavier_uniform_(
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
def forward(self, signal):
conv_signal = self.conv(signal)
return conv_signal
class ConvAttention(torch.nn.Module):
def __init__(
self,
n_mel_channels=80,
n_text_channels=512,
n_att_channels=80,
temperature=1.0,
use_query_proj=True,
):
super(ConvAttention, self).__init__()
self.temperature = temperature
self.att_scaling_factor = np.sqrt(n_att_channels)
self.softmax = torch.nn.Softmax(dim=3)
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
self.use_query_proj = bool(use_query_proj)
self.key_proj = nn.Sequential(
ConvNorm(
n_text_channels,
n_text_channels * 2,
kernel_size=3,
bias=True,
w_init_gain='relu',
),
torch.nn.ReLU(),
ConvNorm(
n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
)
self.query_proj = nn.Sequential(
ConvNorm(
n_mel_channels,
n_mel_channels * 2,
kernel_size=3,
bias=True,
w_init_gain='relu',
),
torch.nn.ReLU(),
ConvNorm(
n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
torch.nn.ReLU(),
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
)
def forward(self, queries, keys, mask=None, attn_prior=None):
"""Attention mechanism for flowtron parallel
Unlike in Flowtron, we have no restrictions such as causality etc,
since we only need this during training.
Args:
queries (torch.tensor): B x C x T1 tensor
(probably going to be mel data)
keys (torch.tensor): B x C2 x T2 tensor (text data)
mask (torch.tensor): uint8 binary mask for variable length entries
(should be in the T2 domain)
Output:
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
Final dim T2 should sum to 1
"""
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
# Beware can only do this since query_dim = attn_dim = n_mel_channels
if self.use_query_proj:
queries_enc = self.query_proj(queries)
else:
queries_enc = queries
# different ways of computing attn,
# one is isotropic gaussians (per phoneme)
# Simplistic Gaussian Isotropic Attention
# B x n_attn_dims x T1 x T2
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None])**2
# compute log likelihood from a gaussian
attn = -0.0005 * attn.sum(1, keepdim=True)
if attn_prior is not None:
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]
+ 1e-8)
attn_logprob = attn.clone()
if mask is not None:
attn.data.masked_fill_(
mask.unsqueeze(1).unsqueeze(1), -float('inf'))
attn = self.softmax(attn) # Softmax along T2
return attn, attn_logprob
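A hedged usage sketch for ConvAttention; batch size, sequence lengths, and the random tensors are assumptions for illustration:
import torch
# Hypothetical usage: soft-align mel frames (queries) to text embeddings (keys).
attn_layer = ConvAttention(n_mel_channels=80, n_text_channels=512, n_att_channels=80)
mels = torch.randn(2, 80, 120)                      # B x n_mel_channels x T1
text = torch.randn(2, 512, 40)                      # B x n_text_channels x T2
text_mask = torch.zeros(2, 40, dtype=torch.bool)    # True marks padded key positions
attn, attn_logprob = attn_layer(mels, text, mask=text_mask)
# attn: B x 1 x T1 x T2, softmax-normalized along the text axis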

View File

@@ -1,8 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
"""
FSMN Pytorch Version
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -27,7 +23,8 @@ class FeedForwardNet(nn.Module):
d_out,
kernel_size=kernel_size[1],
padding=(kernel_size[1] - 1) // 2,
bias=False)
bias=False,
)
self.dropout = nn.Dropout(dropout)
@@ -63,8 +60,10 @@ class MemoryBlockV2(nn.Module):
x = F.pad(
input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0)
output = self.conv_dw(x.contiguous().transpose(
1, 2)).contiguous().transpose(1, 2)
output = (
self.conv_dw(x.contiguous().transpose(1,
2)).contiguous().transpose(
1, 2))
output += input
output = self.dropout(output)
@@ -76,14 +75,16 @@ class MemoryBlockV2(nn.Module):
class FsmnEncoderV2(nn.Module):
def __init__(self,
filter_size,
fsmn_num_layers,
input_dim,
num_memory_units,
ffn_inner_dim,
dropout=0.0,
shift=0):
def __init__(
self,
filter_size,
fsmn_num_layers,
input_dim,
num_memory_units,
ffn_inner_dim,
dropout=0.0,
shift=0,
):
super(FsmnEncoderV2, self).__init__()
self.filter_size = filter_size
@@ -119,7 +120,7 @@ class FsmnEncoderV2(nn.Module):
context = ffn(x)
memory = memory_block(context, mask)
memory = F.dropout(memory, self.dropout, self.training)
if (memory.size(-1) == x.size(-1)):
if memory.size(-1) == x.size(-1):
memory += x
x = memory

File diff suppressed because it is too large

View File

@@ -1,5 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
@@ -15,14 +14,16 @@ class SinusoidalPositionEncoder(nn.Module):
self.depth = depth
self.position_enc = nn.Parameter(
self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
requires_grad=False)
requires_grad=False,
)
def forward(self, input):
bz_in, len_in, _ = input.size()
if len_in > self.max_len:
self.max_len = len_in
self.position_enc.data = self.get_sinusoid_encoding_table(
self.max_len, self.depth).unsqueeze(0).to(input.device)
self.position_enc.data = (
self.get_sinusoid_encoding_table(
self.max_len, self.depth).unsqueeze(0).to(input.device))
output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)
@@ -74,8 +75,8 @@ class DurSinusoidalPositionEncoder(nn.Module):
reps_cumsum = torch.cumsum(
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
range_ = torch.arange(max_len).to(durations.device)[None, :, None]
mult = ((reps_cumsum[:, :, :-1] <= range_)
& (reps_cumsum[:, :, 1:] > range_)) # yapf:disable
mult = (reps_cumsum[:, :, :-1] <= range_) & (
reps_cumsum[:, :, 1:] > range_)
mult = mult.float()
offsets = torch.matmul(mult,
reps_cumsum[:,
@@ -88,7 +89,7 @@ class DurSinusoidalPositionEncoder(nn.Module):
seq_len = dur_pos.size(1)
padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
if (padding < self.outputs_per_step):
if padding < self.outputs_per_step:
dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)
position_embedding = dur_pos[:, :, None] / self.inv_timescales[None,

View File

@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from distutils.version import LooseVersion
import torch
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(mean, std)
def get_mask_from_lengths(lengths, max_len=None):
batch_size = lengths.shape[0]
if max_len is None:
max_len = torch.max(lengths).item()
ids = (
torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
-1).to(lengths.device))
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
return mask
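For illustration, a tiny example of the mask helper above (the lengths are made up):
import torch
lengths = torch.tensor([5, 3, 4])
mask = get_mask_from_lengths(lengths)   # shape (3, 5), True where the frame index >= length
# mask[1] -> tensor([False, False, False, True, True])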

View File

@@ -0,0 +1,774 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
from concurrent.futures import ProcessPoolExecutor
from glob import glob
import numpy as np
import yaml
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .core.dsp import (load_wav, melspectrogram, save_wav, trim_silence,
trim_silence_with_interval)
from .core.utils import (align_length, average_by_duration, compute_mean,
compute_std, encode_16bits, f0_norm_mean_std,
get_energy, get_pitch, norm_mean_std,
parse_interval_file, volume_normalize)
logging = get_logger()
default_audio_config = {
# Preprocess
'wav_normalize': True,
'trim_silence': True,
'trim_silence_threshold_db': 60,
'preemphasize': False,
# Feature extraction
'sampling_rate': 24000,
'hop_length': 240,
'win_length': 1024,
'n_mels': 80,
'n_fft': 1024,
'fmin': 50.0,
'fmax': 7600.0,
'min_level_db': -100,
'ref_level_db': 20,
'phone_level_feature': True,
'num_workers': 16,
# Normalization
'norm_type': 'mean_std', # 'mean_std', 'global norm'
'max_norm': 1.0,
'symmetric': False,
}
class AudioProcessor:
def __init__(self, config=None):
if not isinstance(config, dict):
logging.warning(
'[AudioProcessor] config is not a dict, falling back to the default config.'
)
self.config = default_audio_config
else:
self.config = config
for key in self.config:
setattr(self, key, self.config[key])
self.min_wav_length = int(self.config['sampling_rate'] * 0.5)
self.badcase_list = []
self.pcm_dict = {}
self.mel_dict = {}
self.f0_dict = {}
self.uv_dict = {}
self.nccf_dict = {}
self.f0uv_dict = {}
self.energy_dict = {}
self.dur_dict = {}
logging.info('[AudioProcessor] Initialize AudioProcessor.')
logging.info('[AudioProcessor] config params:')
for key in self.config:
logging.info('[AudioProcessor] %s: %s', key, self.config[key])
def calibrate_SyllableDuration(self, raw_dur_dir, raw_metafile,
out_cali_duration_dir):
with open(raw_metafile, 'r') as f:
lines = f.readlines()
output_dur_dir = out_cali_duration_dir
os.makedirs(output_dur_dir, exist_ok=True)
for line in lines:
line = line.strip()
index, symbols = line.split('\t')
symbols = [
symbol.strip('{').strip('}').split('$')[0]
for symbol in symbols.strip().split(' ')
]
dur_file = os.path.join(raw_dur_dir, index + '.npy')
phone_file = os.path.join(raw_dur_dir, index + '.phone')
if not os.path.exists(dur_file) or not os.path.exists(phone_file):
logging.warning(
'[AudioProcessor] dur file or phone file does not exist: %s',
index)
continue
with open(phone_file, 'r') as f:
phones = f.readlines()
dur = np.load(dur_file)
cali_duration = []
dur_idx = 0
syll_idx = 0
while dur_idx < len(dur) and syll_idx < len(symbols):
if phones[dur_idx].strip() == 'sil':
dur_idx += 1
continue
if phones[dur_idx].strip(
) == 'sp' and symbols[syll_idx][0] != '#':
dur_idx += 1
continue
if symbols[syll_idx] in ['ga', 'go', 'ge']:
cali_duration.append(0)
syll_idx += 1
# print("NONE", symbols[syll_idx], 0)
continue
if symbols[syll_idx][0] == '#':
if phones[dur_idx].strip() != 'sp':
cali_duration.append(0)
# print("NONE", symbols[syll_idx], 0)
syll_idx += 1
continue
else:
cali_duration.append(dur[dur_idx])
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
dur_idx += 1
syll_idx += 1
continue
# A corresponding phone is found
cali_duration.append(dur[dur_idx])
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
dur_idx += 1
syll_idx += 1
# Add #4 phone duration
cali_duration.append(0)
if len(cali_duration) != len(symbols):
logging.error('[Duration Calibrating] Syllable duration {}\
is not equal to the number of symbols {}, index: {}'.
format(len(cali_duration), len(symbols), index))
continue
# Align with mel frames
durs = np.array(cali_duration)
if len(self.mel_dict) > 0:
pair_mel = self.mel_dict.get(index, None)
if pair_mel is None:
logging.warning(
'[AudioProcessor] Interval file %s has no corresponding mel',
index,
)
continue
mel_frames = pair_mel.shape[0]
dur_frames = np.sum(durs)
if np.sum(durs) > mel_frames:
durs[-2] -= dur_frames - mel_frames
elif np.sum(durs) < mel_frames:
durs[-2] += mel_frames - np.sum(durs)
if durs[-2] < 0:
logging.error(
'[AudioProcessor] Duration calibrating failed for %s, mismatch frames %s',
index,
durs[-2],
)
self.badcase_list.append(index)
continue
self.dur_dict[index] = durs
np.save(
os.path.join(output_dur_dir, index + '.npy'),
self.dur_dict[index])
def amp_normalize(self, src_wav_dir, out_wav_dir):
if self.wav_normalize:
logging.info('[AudioProcessor] Amplitude normalization started')
os.makedirs(out_wav_dir, exist_ok=True)
res = volume_normalize(src_wav_dir, out_wav_dir)
logging.info('[AudioProcessor] Amplitude normalization finished')
return res
else:
logging.info('[AudioProcessor] No amplitude normalization')
os.symlink(src_wav_dir, out_wav_dir, target_is_directory=True)
return True
def get_pcm_dict(self, src_wav_dir):
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
if len(self.pcm_dict) > 0:
return self.pcm_dict
logging.info('[AudioProcessor] Start to load pcm from %s', src_wav_dir)
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_path in wav_list:
future = executor.submit(load_wav, wav_path,
self.sampling_rate)
future.add_done_callback(lambda p: progress.update())
wav_name = os.path.splitext(os.path.basename(wav_path))[0]
futures.append((future, wav_name))
for future, wav_name in futures:
pcm = future.result()
if len(pcm) < self.min_wav_length:
logging.warning('[AudioProcessor] %s is too short, skip',
wav_name)
self.badcase_list.append(wav_name)
continue
self.pcm_dict[wav_name] = pcm
return self.pcm_dict
def trim_silence_wav(self, src_wav_dir, out_wav_dir=None):
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
logging.info('[AudioProcessor] Trim silence started')
if out_wav_dir is None:
out_wav_dir = src_wav_dir
else:
os.makedirs(out_wav_dir, exist_ok=True)
pcm_dict = self.get_pcm_dict(src_wav_dir)
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
trim_silence,
pcm_data,
self.trim_silence_threshold_db,
self.hop_length,
self.win_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in tqdm(futures):
pcm = future.result()
if len(pcm) < self.min_wav_length:
logging.warning('[AudioProcessor] %s is too short, skip',
wav_basename)
self.badcase_list.append(wav_basename)
self.pcm_dict.pop(wav_basename)
continue
self.pcm_dict[wav_basename] = pcm
save_wav(
self.pcm_dict[wav_basename],
os.path.join(out_wav_dir, wav_basename + '.wav'),
self.sampling_rate,
)
logging.info('[AudioProcessor] Trim silence finished')
return True
def trim_silence_wav_with_interval(self,
src_wav_dir,
dur_dir,
out_wav_dir=None):
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
logging.info('[AudioProcessor] Trim silence with interval started')
if out_wav_dir is None:
out_wav_dir = src_wav_dir
else:
os.makedirs(out_wav_dir, exist_ok=True)
pcm_dict = self.get_pcm_dict(src_wav_dir)
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
trim_silence_with_interval,
pcm_data,
self.dur_dict.get(wav_basename, None),
self.hop_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in tqdm(futures):
trimed_pcm = future.result()
if trimed_pcm is None:
continue
if len(trimed_pcm) < self.min_wav_length:
logging.warning('[AudioProcessor] %s is too short, skip',
wav_basename)
self.badcase_list.append(wav_basename)
self.pcm_dict.pop(wav_basename)
continue
self.pcm_dict[wav_basename] = trimed_pcm
save_wav(
self.pcm_dict[wav_basename],
os.path.join(out_wav_dir, wav_basename + '.wav'),
self.sampling_rate,
)
logging.info('[AudioProcessor] Trim silence finished')
return True
def mel_extract(self, src_wav_dir, out_feature_dir):
os.makedirs(out_feature_dir, exist_ok=True)
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
pcm_dict = self.get_pcm_dict(src_wav_dir)
logging.info('[AudioProcessor] Melspec extraction started')
# Get global normed mel spec
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
melspectrogram,
pcm_data,
self.sampling_rate,
self.n_fft,
self.hop_length,
self.win_length,
self.n_mels,
self.max_norm,
self.min_level_db,
self.ref_level_db,
self.fmin,
self.fmax,
self.symmetric,
self.preemphasize,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Melspec extraction failed for %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
else:
melspec = result
self.mel_dict[wav_basename] = melspec
logging.info('[AudioProcessor] Melspec extraction finished')
# FIXME: is this step necessary?
# Do mean std norm on global-normed melspec
logging.info('Melspec statistic proceeding...')
mel_mean = compute_mean(list(self.mel_dict.values()), dims=self.n_mels)
mel_std = compute_std(
list(self.mel_dict.values()), mel_mean, dims=self.n_mels)
logging.info('Melspec statistic done')
np.savetxt(
os.path.join(out_feature_dir, 'mel_mean.txt'),
mel_mean,
fmt='%.6f')
np.savetxt(
os.path.join(out_feature_dir, 'mel_std.txt'), mel_std, fmt='%.6f')
logging.info(
'[AudioProcessor] melspec mean and std saved to:\n{},\n{}'.format(
os.path.join(out_feature_dir, 'mel_mean.txt'),
os.path.join(out_feature_dir, 'mel_std.txt'),
))
logging.info('[AudioProcessor] Melspec mean std norm is proceeding...')
for wav_basename in self.mel_dict:
melspec = self.mel_dict[wav_basename]
norm_melspec = norm_mean_std(melspec, mel_mean, mel_std)
np.save(
os.path.join(out_feature_dir, wav_basename + '.npy'),
norm_melspec)
logging.info('[AudioProcessor] Melspec normalization finished')
logging.info('[AudioProcessor] Normed Melspec saved to %s',
out_feature_dir)
return True
def duration_generate(self, src_interval_dir, out_feature_dir):
os.makedirs(out_feature_dir, exist_ok=True)
interval_list = glob(os.path.join(src_interval_dir, '*.interval'))
logging.info('[AudioProcessor] Duration generation started')
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(interval_list)) as progress:
futures = []
for interval_file_path in interval_list:
future = executor.submit(
parse_interval_file,
interval_file_path,
self.sampling_rate,
self.hop_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future,
os.path.splitext(
os.path.basename(interval_file_path))[0]))
logging.info(
'[AudioProcessor] Duration align with mel is proceeding...')
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Duration generate failed for %s',
wav_basename)
self.badcase_list.append(wav_basename)
else:
durs, phone_list = result
# Align length with melspec
if len(self.mel_dict) > 0:
pair_mel = self.mel_dict.get(wav_basename, None)
if pair_mel is None:
logging.warning(
'[AudioProcessor] Interval file %s has no corresponding mel',
wav_basename,
)
continue
mel_frames = pair_mel.shape[0]
dur_frames = np.sum(durs)
if np.sum(durs) > mel_frames:
durs[-1] -= dur_frames - mel_frames
elif np.sum(durs) < mel_frames:
durs[-1] += mel_frames - np.sum(durs)
if durs[-1] < 0:
logging.error(
'[AudioProcessor] Duration align failed for %s, mismatch frames %s',
wav_basename,
durs[-1],
)
self.badcase_list.append(wav_basename)
continue
self.dur_dict[wav_basename] = durs
np.save(
os.path.join(out_feature_dir, wav_basename + '.npy'),
durs)
with open(
os.path.join(out_feature_dir,
wav_basename + '.phone'), 'w') as f:
f.write('\n'.join(phone_list))
logging.info('[AudioProcessor] Duration generate finished')
return True
def pitch_extract(self, src_wav_dir, out_f0_dir, out_frame_f0_dir,
out_frame_uv_dir):
os.makedirs(out_f0_dir, exist_ok=True)
os.makedirs(out_frame_f0_dir, exist_ok=True)
os.makedirs(out_frame_uv_dir, exist_ok=True)
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
pcm_dict = self.get_pcm_dict(src_wav_dir)
mel_dict = self.mel_dict
logging.info('[AudioProcessor] Pitch extraction started')
# Get raw pitch
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(
get_pitch,
encode_16bits(pcm_data),
self.sampling_rate,
self.hop_length,
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
logging.info(
'[AudioProcessor] Pitch align with mel is proceeding...')
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Pitch extraction failed for %s',
wav_basename)
self.badcase_list.append(wav_basename)
else:
f0, uv, f0uv = result
if len(mel_dict) > 0:
f0 = align_length(f0, mel_dict.get(wav_basename, None))
uv = align_length(uv, mel_dict.get(wav_basename, None))
f0uv = align_length(f0uv,
mel_dict.get(wav_basename, None))
if f0 is None or uv is None or f0uv is None:
logging.warning(
'[AudioProcessor] Pitch length mismatch with mel in %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
continue
self.f0_dict[wav_basename] = f0
self.uv_dict[wav_basename] = uv
self.f0uv_dict[wav_basename] = f0uv
# Normalize f0
logging.info('[AudioProcessor] Pitch normalization is proceeding...')
f0_mean = compute_mean(list(self.f0uv_dict.values()), dims=1)
f0_std = compute_std(list(self.f0uv_dict.values()), f0_mean, dims=1)
np.savetxt(
os.path.join(out_f0_dir, 'f0_mean.txt'), f0_mean, fmt='%.6f')
np.savetxt(os.path.join(out_f0_dir, 'f0_std.txt'), f0_std, fmt='%.6f')
logging.info(
'[AudioProcessor] f0 mean and std saved to:\n{},\n{}'.format(
os.path.join(out_f0_dir, 'f0_mean.txt'),
os.path.join(out_f0_dir, 'f0_std.txt'),
))
logging.info('[AudioProcessor] Pitch mean std norm is proceeding...')
for wav_basename in self.f0uv_dict:
f0 = self.f0uv_dict[wav_basename]
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
self.f0uv_dict[wav_basename] = norm_f0
for wav_basename in self.f0_dict:
f0 = self.f0_dict[wav_basename]
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
self.f0_dict[wav_basename] = norm_f0
# save frame f0 to a specific dir
for wav_basename in self.f0_dict:
np.save(
os.path.join(out_frame_f0_dir, wav_basename + '.npy'),
self.f0_dict[wav_basename].reshape(-1),
)
for wav_basename in self.uv_dict:
np.save(
os.path.join(out_frame_uv_dir, wav_basename + '.npy'),
self.uv_dict[wav_basename].reshape(-1),
)
# phone level average
# if there is no duration then save the frame-level f0
if self.phone_level_feature and len(self.dur_dict) > 0:
logging.info(
'[AudioProcessor] Pitch turn to phone-level is proceeding...')
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(self.f0uv_dict)) as progress:
futures = []
for wav_basename in self.f0uv_dict:
future = executor.submit(
average_by_duration,
self.f0uv_dict.get(wav_basename, None),
self.dur_dict.get(wav_basename, None),
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Pitch extraction failed in phone level avg for: %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
else:
avg_f0 = result
self.f0uv_dict[wav_basename] = avg_f0
for wav_basename in self.f0uv_dict:
np.save(
os.path.join(out_f0_dir, wav_basename + '.npy'),
self.f0uv_dict[wav_basename].reshape(-1),
)
logging.info('[AudioProcessor] Pitch normalization finished')
logging.info('[AudioProcessor] Normed f0 saved to %s', out_f0_dir)
logging.info('[AudioProcessor] Pitch extraction finished')
return True
def energy_extract(self, src_wav_dir, out_energy_dir,
out_frame_energy_dir):
os.makedirs(out_energy_dir, exist_ok=True)
os.makedirs(out_frame_energy_dir, exist_ok=True)
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
pcm_dict = self.get_pcm_dict(src_wav_dir)
mel_dict = self.mel_dict
logging.info('[AudioProcessor] Energy extraction started')
# Get raw energy
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(wav_list)) as progress:
futures = []
for wav_basename, pcm_data in pcm_dict.items():
future = executor.submit(get_energy, pcm_data, self.hop_length,
self.win_length, self.n_fft)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Energy extraction failed for %s',
wav_basename)
self.badcase_list.append(wav_basename)
else:
energy = result
if len(mel_dict) > 0:
energy = align_length(energy,
mel_dict.get(wav_basename, None))
if energy is None:
logging.warning(
'[AudioProcessor] Energy length mismatch with mel in %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
continue
self.energy_dict[wav_basename] = energy
logging.info('Energy statistic proceeding...')
# Normalize energy
energy_mean = compute_mean(list(self.energy_dict.values()), dims=1)
energy_std = compute_std(
list(self.energy_dict.values()), energy_mean, dims=1)
np.savetxt(
os.path.join(out_energy_dir, 'energy_mean.txt'),
energy_mean,
fmt='%.6f')
np.savetxt(
os.path.join(out_energy_dir, 'energy_std.txt'),
energy_std,
fmt='%.6f')
logging.info(
'[AudioProcessor] energy mean and std saved to:\n{},\n{}'.format(
os.path.join(out_energy_dir, 'energy_mean.txt'),
os.path.join(out_energy_dir, 'energy_std.txt'),
))
logging.info('[AudioProcessor] Energy mean std norm is proceeding...')
for wav_basename in self.energy_dict:
energy = self.energy_dict[wav_basename]
norm_energy = f0_norm_mean_std(energy, energy_mean, energy_std)
self.energy_dict[wav_basename] = norm_energy
# save frame energy to a specific dir
for wav_basename in self.energy_dict:
np.save(
os.path.join(out_frame_energy_dir, wav_basename + '.npy'),
self.energy_dict[wav_basename].reshape(-1),
)
# phone level average
# if there is no duration then save the frame-level energy
if self.phone_level_feature and len(self.dur_dict) > 0:
with ProcessPoolExecutor(
max_workers=self.num_workers) as executor, tqdm(
total=len(self.energy_dict)) as progress:
futures = []
for wav_basename in self.energy_dict:
future = executor.submit(
average_by_duration,
self.energy_dict.get(wav_basename, None),
self.dur_dict.get(wav_basename, None),
)
future.add_done_callback(lambda p: progress.update())
futures.append((future, wav_basename))
for future, wav_basename in futures:
result = future.result()
if result is None:
logging.warning(
'[AudioProcessor] Energy extraction failed in phone level avg for: %s',
wav_basename,
)
self.badcase_list.append(wav_basename)
else:
avg_energy = result
self.energy_dict[wav_basename] = avg_energy
for wav_basename in self.energy_dict:
np.save(
os.path.join(out_energy_dir, wav_basename + '.npy'),
self.energy_dict[wav_basename].reshape(-1),
)
logging.info('[AudioProcessor] Energy normalization finished')
logging.info('[AudioProcessor] Normed Energy saved to %s',
out_energy_dir)
logging.info('[AudioProcessor] Energy extraction finished')
return True
def process(self, src_voice_dir, out_data_dir, aux_metafile=None):
succeed = True
raw_wav_dir = os.path.join(src_voice_dir, 'wav')
src_interval_dir = os.path.join(src_voice_dir, 'interval')
out_mel_dir = os.path.join(out_data_dir, 'mel')
out_f0_dir = os.path.join(out_data_dir, 'f0')
out_frame_f0_dir = os.path.join(out_data_dir, 'frame_f0')
out_frame_uv_dir = os.path.join(out_data_dir, 'frame_uv')
out_energy_dir = os.path.join(out_data_dir, 'energy')
out_frame_energy_dir = os.path.join(out_data_dir, 'frame_energy')
out_duration_dir = os.path.join(out_data_dir, 'raw_duration')
out_cali_duration_dir = os.path.join(out_data_dir, 'duration')
os.makedirs(out_data_dir, exist_ok=True)
with_duration = os.path.exists(src_interval_dir)
train_wav_dir = os.path.join(out_data_dir, 'wav')
succeed = self.amp_normalize(raw_wav_dir, train_wav_dir)
if not succeed:
logging.error('[AudioProcessor] amp_normalize failed, exit')
return False
if with_duration:
# Raw duration, non-trimmed
succeed = self.duration_generate(src_interval_dir,
out_duration_dir)
if not succeed:
logging.error(
'[AudioProcessor] duration_generate failed, exit')
return False
if self.trim_silence:
if with_duration:
succeed = self.trim_silence_wav_with_interval(
train_wav_dir, out_duration_dir)
if not succeed:
logging.error(
'[AudioProcessor] trim_silence_wav_with_interval failed, exit'
)
return False
else:
succeed = self.trim_silence_wav(train_wav_dir)
if not succeed:
logging.error(
'[AudioProcessor] trim_silence_wav failed, exit')
return False
succeed = self.mel_extract(train_wav_dir, out_mel_dir)
if not succeed:
logging.error('[AudioProcessor] mel_extract failed, exit')
return False
if aux_metafile is not None and with_duration:
self.calibrate_SyllableDuration(out_duration_dir, aux_metafile,
out_cali_duration_dir)
succeed = self.pitch_extract(train_wav_dir, out_f0_dir,
out_frame_f0_dir, out_frame_uv_dir)
if not succeed:
logging.error('[AudioProcessor] pitch_extract failed, exit')
return False
succeed = self.energy_extract(train_wav_dir, out_energy_dir,
out_frame_energy_dir)
if not succeed:
logging.error('[AudioProcessor] energy_extract failed, exit')
return False
# recording badcase list
with open(os.path.join(out_data_dir, 'badlist.txt'), 'w') as f:
f.write('\n'.join(self.badcase_list))
logging.info('[AudioProcessor] All features extracted successfully!')
return succeed
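A hedged usage sketch of the AudioProcessor pipeline above; the directory layout (wav/ plus optional interval/) and all paths are assumptions for illustration:
# Hypothetical usage: run the full feature-extraction pipeline with the default config.
processor = AudioProcessor()                  # config=None falls back to default_audio_config
ok = processor.process(
    src_voice_dir='/path/to/voice',           # must contain wav/, optionally interval/
    out_data_dir='/path/to/features',         # mel/, f0/, energy/, duration/ are created here
    aux_metafile=None,                        # pass raw_metafile.txt to calibrate syllable durations
)
if not ok:
    print('feature extraction failed; skipped utterances are listed in badlist.txt')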

View File

@@ -0,0 +1,240 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
def _stft(y, hop_length, win_length, n_fft):
return librosa.stft(
y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
def _istft(y, hop_length, win_length):
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
def _db_to_amp(x):
return np.power(10.0, x * 0.05)
def _amp_to_db(x):
return 20 * np.log10(np.maximum(1e-5, x))
def load_wav(path, sr):
return librosa.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
if wav.dtype == np.float32 or wav.dtype == np.float64:
quant_wav = 32767 * wav
else:
quant_wav = wav
# maximize the volume to avoid clipping
# wav *= 32767 / max(0.01, np.max(np.abs(wav)))
wavfile.write(path, sr, quant_wav.astype(np.int16))
def trim_silence(wav, top_db, hop_length, win_length):
trimed_wav, _ = librosa.effects.trim(
wav, top_db=top_db, frame_length=win_length, hop_length=hop_length)
return trimed_wav
def trim_silence_with_interval(wav, interval, hop_length):
if interval is None:
return None
leading_sil = interval[0]
tailing_sil = interval[-1]
trim_wav = wav[leading_sil * hop_length:-tailing_sil * hop_length]
return trim_wav
def preemphasis(wav, k=0.98, preemphasize=False):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k=0.98, inv_preemphasize=False):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def _normalize(S, max_norm=1.0, min_level_db=-100, symmetric=False):
if symmetric:
return np.clip(
(2 * max_norm) * ((S - min_level_db) / (-min_level_db)) - max_norm,
-max_norm,
max_norm,
)
else:
return np.clip(max_norm * ((S - min_level_db) / (-min_level_db)), 0,
max_norm)
def _denormalize(D, max_norm=1.0, min_level_db=-100, symmetric=False):
if symmetric:
return ((np.clip(D, -max_norm, max_norm) + max_norm) * -min_level_db
/ # noqa W504
(2 * max_norm)) + min_level_db
else:
return (np.clip(D, 0, max_norm) * -min_level_db
/ max_norm) + min_level_db
def _griffin_lim(S, n_fft, hop_length, win_length, griffin_lim_iters=60):
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
S_complex = np.abs(S).astype(complex)
y = _istft(
S_complex * angles, hop_length=hop_length, win_length=win_length)
for i in range(griffin_lim_iters):
angles = np.exp(1j * np.angle(
_stft(
y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)))
y = _istft(
S_complex * angles, hop_length=hop_length, win_length=win_length)
return y
def spectrogram(
y,
n_fft=1024,
hop_length=256,
win_length=1024,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
symmetric=False,
):
D = _stft(preemphasis(y), hop_length, win_length, n_fft)
S = _amp_to_db(np.abs(D)) - ref_level_db
return _normalize(S, max_norm, min_level_db, symmetric)
def inv_spectrogram(
spectrogram,
n_fft=1024,
hop_length=256,
win_length=1024,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
symmetric=False,
power=1.5,
):
S = _db_to_amp(
_denormalize(spectrogram, max_norm, min_level_db, symmetric)
+ ref_level_db)
return _griffin_lim(S**power, n_fft, hop_length, win_length)
def _build_mel_basis(sample_rate, n_fft=1024, fmin=50, fmax=8000, n_mels=80):
assert fmax <= sample_rate // 2
return librosa.filters.mel(
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
# mel linear Conversions
_mel_basis = None
_inv_mel_basis = None
def _linear_to_mel(spectogram,
sample_rate,
n_fft=1024,
fmin=50,
fmax=8000,
n_mels=80):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels)
return np.dot(_mel_basis, spectogram)
def _mel_to_linear(mel_spectrogram,
sample_rate,
n_fft=1024,
fmin=50,
fmax=8000,
n_mels=80):
global _inv_mel_basis
if _inv_mel_basis is None:
_inv_mel_basis = np.linalg.pinv(
_build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels))
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
def melspectrogram(
y,
sample_rate,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mels=80,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
fmin=50,
fmax=8000,
symmetric=False,
preemphasize=False,
):
D = _stft(
preemphasis(y, preemphasize=preemphasize),
hop_length=hop_length,
win_length=win_length,
n_fft=n_fft,
)
S = (
_amp_to_db(
_linear_to_mel(
np.abs(D),
sample_rate=sample_rate,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
n_mels=n_mels,
)) - ref_level_db)
return _normalize(
S, max_norm=max_norm, min_level_db=min_level_db, symmetric=symmetric).T
def inv_mel_spectrogram(
mel_spectrogram,
sample_rate,
n_fft=1024,
hop_length=256,
win_length=1024,
n_mels=80,
max_norm=1.0,
min_level_db=-100,
ref_level_db=20,
fmin=50,
fmax=8000,
power=1.5,
symmetric=False,
preemphasize=False,
):
D = _denormalize(
mel_spectrogram,
max_norm=max_norm,
min_level_db=min_level_db,
symmetric=symmetric,
)
S = _mel_to_linear(
_db_to_amp(D + ref_level_db),
sample_rate=sample_rate,
n_fft=n_fft,
fmin=fmin,
fmax=fmax,
n_mels=n_mels,
)
return inv_preemphasis(
_griffin_lim(S**power, n_fft, hop_length, win_length),
preemphasize=preemphasize,
)
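A minimal sketch of a mel round trip with the DSP helpers above; the file path is an assumption and the parameter values simply mirror the default audio config:
# Hypothetical usage: extract a normalized mel spectrogram and invert it with Griffin-Lim.
sr = 24000
wav = load_wav('/path/to/sample.wav', sr)
mel = melspectrogram(wav, sample_rate=sr, n_fft=1024, hop_length=240,
                     win_length=1024, n_mels=80, fmin=50.0, fmax=7600.0)
# mel has shape (frames, n_mels) with values normalized to [0, max_norm]
recon = inv_mel_spectrogram(mel.T, sample_rate=sr, n_fft=1024, hop_length=240,
                            win_length=1024, n_mels=80, fmin=50.0, fmax=7600.0)
save_wav(recon, '/path/to/recon.wav', sr)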

View File

@@ -0,0 +1,480 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from concurrent.futures import ProcessPoolExecutor
from glob import glob
import librosa
import numpy as np
import pysptk
import sox
from scipy.io import wavfile
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .dsp import _stft
logging = get_logger()
anchor_hist = np.array([
0.0,
0.00215827,
0.00354383,
0.00442313,
0.00490274,
0.00532907,
0.00602185,
0.00690115,
0.00810019,
0.00948574,
0.0120437,
0.01489475,
0.01873168,
0.02302158,
0.02872369,
0.03669065,
0.04636291,
0.05843325,
0.07700506,
0.11052491,
0.16802558,
0.25997868,
0.37942979,
0.50730083,
0.62006395,
0.71092459,
0.76877165,
0.80762057,
0.83458566,
0.85672795,
0.87660538,
0.89251266,
0.90578204,
0.91569411,
0.92541966,
0.93383959,
0.94162004,
0.94940048,
0.95539568,
0.96136424,
0.9670397,
0.97290168,
0.97705835,
0.98116174,
0.98465228,
0.98814282,
0.99152678,
0.99421796,
0.9965894,
0.99840128,
1.0,
])
anchor_bins = np.array([
0.033976,
0.03529014,
0.03660428,
0.03791842,
0.03923256,
0.0405467,
0.04186084,
0.04317498,
0.04448912,
0.04580326,
0.0471174,
0.04843154,
0.04974568,
0.05105982,
0.05237396,
0.0536881,
0.05500224,
0.05631638,
0.05763052,
0.05894466,
0.0602588,
0.06157294,
0.06288708,
0.06420122,
0.06551536,
0.0668295,
0.06814364,
0.06945778,
0.07077192,
0.07208606,
0.0734002,
0.07471434,
0.07602848,
0.07734262,
0.07865676,
0.0799709,
0.08128504,
0.08259918,
0.08391332,
0.08522746,
0.0865416,
0.08785574,
0.08916988,
0.09048402,
0.09179816,
0.0931123,
0.09442644,
0.09574058,
0.09705472,
0.09836886,
0.099683,
])
hist_bins = 50
def amp_info(wav_file_path):
"""
Returns the amplitude info of the wav file.
"""
stats = sox.file_info.stat(wav_file_path)
amp_rms = stats['RMS amplitude']
amp_max = stats['Maximum amplitude']
amp_mean = stats['Mean amplitude']
length = stats['Length (seconds)']
return {
'amp_rms': amp_rms,
'amp_max': amp_max,
'amp_mean': amp_mean,
'length': length,
'basename': os.path.basename(wav_file_path),
}
def statistic_amplitude(src_wav_dir):
"""
Returns amplitude info for every wav file in the directory.
"""
wav_lst = glob(os.path.join(src_wav_dir, '*.wav'))
with ProcessPoolExecutor(max_workers=8) as executor, tqdm(
total=len(wav_lst)) as progress:
futures = []
for wav_file_path in wav_lst:
future = executor.submit(amp_info, wav_file_path)
future.add_done_callback(lambda p: progress.update())
futures.append(future)
amp_info_lst = [future.result() for future in futures]
amp_info_lst = sorted(amp_info_lst, key=lambda x: x['amp_rms'])
logging.info('Average amplitude RMS : {}'.format(
np.mean([x['amp_rms'] for x in amp_info_lst])))
return amp_info_lst
def volume_normalize(src_wav_dir, out_wav_dir):
logging.info('Volume statistic proceeding...')
amp_info_lst = statistic_amplitude(src_wav_dir)
logging.info('Volume statistic done.')
rms_amp_lst = [x['amp_rms'] for x in amp_info_lst]
src_hist, src_bins = np.histogram(
rms_amp_lst, bins=hist_bins, density=True)
src_hist = src_hist / np.sum(src_hist)
src_hist = np.cumsum(src_hist)
src_hist = np.insert(src_hist, 0, 0.0)
logging.info('Volume normalization proceeding...')
for amp_info in tqdm(amp_info_lst):
rms_amp = amp_info['amp_rms']
rms_amp = np.clip(rms_amp, src_bins[0], src_bins[-1])
src_idx = np.where(rms_amp >= src_bins)[0][-1]
src_pos = src_hist[src_idx]
anchor_idx = np.where(src_pos >= anchor_hist)[0][-1]
if src_idx == hist_bins or anchor_idx == hist_bins:
rms_amp = anchor_bins[-1]
else:
rms_amp = (rms_amp - src_bins[src_idx]) / (
src_bins[src_idx + 1] - src_bins[src_idx]) * (
anchor_bins[anchor_idx + 1]
- anchor_bins[anchor_idx]) + anchor_bins[anchor_idx]
scale = rms_amp / amp_info['amp_rms']
# FIXME: This is a hack to avoid sound clipping.
sr, data = wavfile.read(
os.path.join(src_wav_dir, amp_info['basename']))
wavfile.write(
os.path.join(out_wav_dir, amp_info['basename']),
sr,
(data * scale).astype(np.int16),
)
logging.info('Volume normalization done.')
return True
def interp_f0(f0_data):
"""
linear interpolation
"""
f0_data[f0_data < 1] = 0
xp = np.nonzero(f0_data)
yp = f0_data[xp]
x = np.arange(f0_data.size)
contour_f0 = np.interp(x, xp[0], yp).astype(np.float32)
return contour_f0
def frame_nccf(x, y):
norm_coef = (np.sum(x**2.0) * np.sum(y**2.0) + 1e-30)**0.5
return (np.sum(x * y) / norm_coef + 1.0) / 2.0
def get_nccf(pcm_data, f0, min_f0=40, max_f0=800, fs=160, sr=16000):
if pcm_data.dtype == np.int16:
pcm_data = pcm_data.astype(np.float32) / 32768
frame_len = int(sr / 200)
frame_num = int(len(pcm_data) // fs)
frame_num = min(frame_num, len(f0))
pad_len = int(sr / min_f0) + frame_len
pad_zeros = np.zeros([pad_len], dtype=np.float32)
data = np.hstack((pad_zeros, pcm_data.astype(np.float32), pad_zeros))
nccf = np.zeros((frame_num), dtype=np.float32)
for i in range(frame_num):
curr_f0 = np.clip(f0[i], min_f0, max_f0)
lag = int(sr / curr_f0 + 0.5)
j = i * fs + pad_len - frame_len // 2
l_data = data[j:j + frame_len]
l_data -= l_data.mean()
r_data = data[j + lag:j + lag + frame_len]
r_data -= r_data.mean()
nccf[i] = frame_nccf(l_data, r_data)
return nccf
def smooth(data, win_len):
if win_len % 2 == 0:
win_len += 1
hwin = win_len // 2
win = np.hanning(win_len)
win /= win.sum()
data = data.reshape([-1])
pad_data = np.pad(data, hwin, mode='edge')
for i in range(data.shape[0]):
data[i] = np.dot(win, pad_data[i:i + win_len])
return data.reshape([-1, 1])
# supported: rapt, swipe
# unsupported: reaper, world (DIO)
def RAPT_FUNC(v1, v2, v3, v4, v5):
return pysptk.sptk.rapt(
v1.astype(np.float32), fs=v2, hopsize=v3, min=v4, max=v5)
def SWIPE_FUNC(v1, v2, v3, v4, v5):
return pysptk.sptk.swipe(
v1.astype(np.float64), fs=v2, hopsize=v3, min=v4, max=v5)
def PYIN_FUNC(v1, v2, v3, v4, v5):
f0_mel = librosa.pyin(
v1.astype(np.float32), sr=v2, frame_length=v3 * 4, fmin=v4, fmax=v5)[0]
f0_mel = np.where(np.isnan(f0_mel), 0.0, f0_mel)
return f0_mel
def get_pitch(pcm_data, sampling_rate=16000, hop_length=160):
log_f0_list = []
uv_list = []
low, high = 40, 800
cali_f0 = pysptk.sptk.rapt(
pcm_data.astype(np.float32),
fs=sampling_rate,
hopsize=hop_length,
min=low,
max=high,
)
f0_range = np.sort(np.unique(cali_f0))
if len(f0_range) > 20:
low = max(f0_range[10] - 50, low)
high = min(f0_range[-10] + 50, high)
func_dict = {'rapt': RAPT_FUNC, 'swipe': SWIPE_FUNC}
for func_name in func_dict:
f0 = func_dict[func_name](pcm_data, sampling_rate, hop_length, low,
high)
uv = f0 > 0
if len(f0) < 10 or f0.max() < low:
logging.error('{} method: calc F0 is too low.'.format(func_name))
continue
else:
f0 = np.clip(f0, 1e-30, high)
log_f0 = np.log(f0)
contour_log_f0 = interp_f0(log_f0)
log_f0_list.append(contour_log_f0)
uv_list.append(uv)
if len(log_f0_list) == 0:
logging.error('F0 estimation failed.')
return None
min_len = float('inf')
for log_f0 in log_f0_list:
min_len = min(min_len, log_f0.shape[0])
multi_log_f0 = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
multi_uv = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
for i in range(len(log_f0_list)):
multi_log_f0[i, :] = log_f0_list[i][:min_len]
multi_uv[i, :] = uv_list[i][:min_len]
log_f0 = smooth(np.median(multi_log_f0, axis=0), 5)
uv = (smooth(np.median(multi_uv, axis=0), 5) > 0.5).astype(np.float32)
f0 = np.exp(log_f0)
min_len = min(f0.shape[0], uv.shape[0])
return f0[:min_len], uv[:min_len], f0[:min_len] * uv[:min_len]
def get_energy(pcm_data, hop_length, win_length, n_fft):
D = _stft(pcm_data, hop_length, win_length, n_fft)
S, _ = librosa.magphase(D)
energy = np.sqrt(np.sum(S**2, axis=0))
return energy.reshape((-1, 1))
def align_length(in_data, tgt_data, basename=None):
if in_data is None or tgt_data is None:
logging.error('{}: Input data is None.'.format(basename))
return None
in_len = in_data.shape[0]
tgt_len = tgt_data.shape[0]
if abs(in_len - tgt_len) > 20:
logging.error(
'{}: Input data length mismatches with target data length too much.'
.format(basename))
return None
if in_len < tgt_len:
out_data = np.pad(
in_data, ((0, tgt_len - in_len), (0, 0)),
'constant',
constant_values=0.0)
else:
out_data = in_data[:tgt_len]
return out_data
def compute_mean(data_list, dims=80):
mean_vector = np.zeros((1, dims))
all_frame_number = 0
for data in tqdm(data_list):
if data is None:
continue
features = data.reshape((-1, dims))
current_frame_number = np.shape(features)[0]
mean_vector += np.sum(features[:, :], axis=0)
all_frame_number += current_frame_number
mean_vector /= float(all_frame_number)
return mean_vector
def compute_std(data_list, mean_vector, dims=80):
std_vector = np.zeros((1, dims))
all_frame_number = 0
for data in tqdm(data_list):
if data is None:
continue
features = data.reshape((-1, dims))
current_frame_number = np.shape(features)[0]
mean_matrix = np.tile(mean_vector, (current_frame_number, 1))
std_vector += np.sum((features[:, :] - mean_matrix)**2, axis=0)
all_frame_number += current_frame_number
std_vector /= float(all_frame_number)
std_vector = std_vector**0.5
return std_vector
def f0_norm_mean_std(x, mean, std):
zero_idxs = np.where(x == 0.0)[0]
x = (x - mean) / std
x[zero_idxs] = 0.0
return x
def norm_mean_std(x, mean, std):
x = (x - mean) / std
return x
def parse_interval_file(file_path, sampling_rate, hop_length):
with open(file_path, 'r') as f:
lines = f.readlines()
# second
frame_intervals = 1.0 * hop_length / sampling_rate
skip_lines = 12
dur_list = []
phone_list = []
line_index = skip_lines
while line_index < len(lines):
phone_begin = float(lines[line_index])
phone_end = float(lines[line_index + 1])
phone = lines[line_index + 2].strip()[1:-1]
dur_list.append(
int(round((phone_end - phone_begin) / frame_intervals)))
phone_list.append(phone)
line_index += 3
if len(dur_list) == 0 or len(phone_list) == 0:
return None
return np.array(dur_list), phone_list
def average_by_duration(x, durs):
if x is None or durs is None:
return None
durs_cum = np.cumsum(np.pad(durs, (1, 0), 'constant'))
# average over each symbol's duration
x_symbol = np.zeros((durs.shape[0], ), dtype=np.float32)
for idx, start, end in zip(
range(durs.shape[0]), durs_cum[:-1], durs_cum[1:]):
values = x[start:end][np.where(x[start:end] != 0.0)[0]]
x_symbol[idx] = np.mean(values) if len(values) > 0 else 0.0
return x_symbol.astype(np.float32)
def encode_16bits(x):
if x.min() > -1.0 and x.max() < 1.0:
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
else:
return x
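An illustrative sketch of frame-level F0 extraction followed by phone-level averaging; the wav path and the duration values are assumptions:
import numpy as np
from scipy.io import wavfile
# Hypothetical usage: extract F0 from 16-bit PCM and average it over phone durations.
sr, pcm = wavfile.read('/path/to/sample.wav')        # int16 samples
result = get_pitch(pcm, sampling_rate=sr, hop_length=240)
if result is not None:
    f0, uv, f0uv = result                            # each of shape (frames, 1)
    durations = np.array([12, 20, 15, 8])            # frames per phone, e.g. from an interval file
    phone_f0 = average_by_duration(f0uv, durations)  # one averaged F0 value per phone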

View File

@@ -0,0 +1,177 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import codecs
import os
import sys
import time
import yaml
from modelscope import __version__
from modelscope.models.audio.tts.kantts.datasets.dataset import (AmDataset,
VocDataset)
from modelscope.utils.logger import get_logger
from .audio_processor.audio_processor import AudioProcessor
from .fp_processor import FpProcessor, is_fp_line
from .languages import languages
from .script_convertor.text_script_convertor import TextScriptConvertor
ROOT_PATH = os.path.dirname(os.path.dirname(
os.path.abspath(__file__))) # NOQA: E402
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
logging = get_logger()
def gen_metafile(
voice_output_dir,
fp_enable=False,
badlist=None,
split_ratio=0.98,
):
voc_train_meta = os.path.join(voice_output_dir, 'train.lst')
voc_valid_meta = os.path.join(voice_output_dir, 'valid.lst')
if not os.path.exists(voc_train_meta) or not os.path.exists(
voc_valid_meta):
VocDataset.gen_metafile(
os.path.join(voice_output_dir, 'wav'),
voice_output_dir,
split_ratio,
)
logging.info('Voc metafile generated.')
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
am_train_meta = os.path.join(voice_output_dir, 'am_train.lst')
am_valid_meta = os.path.join(voice_output_dir, 'am_valid.lst')
if not os.path.exists(am_train_meta) or not os.path.exists(am_valid_meta):
AmDataset.gen_metafile(
raw_metafile,
voice_output_dir,
am_train_meta,
am_valid_meta,
badlist,
split_ratio,
)
logging.info('AM metafile generated.')
if fp_enable:
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
am_train_meta = os.path.join(voice_output_dir, 'am_fpadd_train.lst')
am_valid_meta = os.path.join(voice_output_dir, 'am_fpadd_valid.lst')
if not os.path.exists(am_train_meta) or not os.path.exists(
am_valid_meta):
AmDataset.gen_metafile(
fpadd_metafile,
voice_output_dir,
am_train_meta,
am_valid_meta,
badlist,
split_ratio,
)
logging.info('AM fpaddmetafile generated.')
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
am_train_meta = os.path.join(voice_output_dir, 'am_fprm_train.lst')
am_valid_meta = os.path.join(voice_output_dir, 'am_fprm_valid.lst')
if not os.path.exists(am_train_meta) or not os.path.exists(
am_valid_meta):
AmDataset.gen_metafile(
fprm_metafile,
voice_output_dir,
am_train_meta,
am_valid_meta,
badlist,
split_ratio,
)
logging.info('AM fprmmetafile generated.')
def process_data(
voice_input_dir,
voice_output_dir,
language_dir,
audio_config,
speaker_name=None,
targetLang='PinYin',
skip_script=False,
):
foreignLang = 'EnUS'
emo_tag_path = None
phoneset_path = os.path.join(language_dir, targetLang,
languages[targetLang]['phoneset_path'])
posset_path = os.path.join(language_dir, targetLang,
languages[targetLang]['posset_path'])
f2t_map_path = os.path.join(language_dir, targetLang,
languages[targetLang]['f2t_map_path'])
s2p_map_path = os.path.join(language_dir, targetLang,
languages[targetLang]['s2p_map_path'])
logging.info(f'phoneset_path={phoneset_path}')
if speaker_name is None:
speaker_name = os.path.basename(voice_input_dir)
if audio_config is not None:
with open(audio_config, 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime())
config['modelscope_version'] = __version__
with open(os.path.join(voice_output_dir, 'audio_config.yaml'), 'w') as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
if skip_script:
logging.info('Skip script conversion')
raw_metafile = None
# Script processor
if not skip_script:
tsc = TextScriptConvertor(
phoneset_path,
posset_path,
targetLang,
foreignLang,
f2t_map_path,
s2p_map_path,
emo_tag_path,
speaker_name,
)
tsc.process(
os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
os.path.join(voice_output_dir, 'Script.xml'),
os.path.join(voice_output_dir, 'raw_metafile.txt'),
)
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
# FP processor
with codecs.open(prosody, 'r', 'utf-8') as f:
lines = f.readlines()
fp_enable = is_fp_line(lines[1])
if fp_enable:
FP = FpProcessor()
FP.process(
voice_output_dir,
prosody,
raw_metafile,
)
logging.info('Processing fp done.')
# Audio processor
ap = AudioProcessor(config['audio_config'])
ap.process(
voice_input_dir,
voice_output_dir,
raw_metafile,
)
logging.info('Processing done.')
# Generate Voc&AM metafile
gen_metafile(voice_output_dir, fp_enable, ap.badcase_list)
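An illustrative call of the data-processing entry point above; the directory names, config path, and speaker are assumptions, not values from the commit:
import os
# Hypothetical usage: prepare a single-speaker voice for KAN-TTS training.
os.makedirs('/data/features/speaker_a', exist_ok=True)   # process_data assumes the output dir exists
process_data(
    voice_input_dir='/data/voices/speaker_a',     # expects wav/, interval/, prosody/prosody.txt
    voice_output_dir='/data/features/speaker_a',  # Script.xml, raw_metafile.txt, mel/, f0/, ... written here
    language_dir='/data/languages',               # per-language PhoneSet.xml, PosSet.xml, mapping files
    audio_config='/data/configs/audio_config.yaml',
    speaker_name='speaker_a',
    targetLang='PinYin',
)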

View File

@@ -0,0 +1,128 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import random
from modelscope.utils.logger import get_logger
logging = get_logger()
def is_fp_line(line):
fp_category_list = ['FP', 'I', 'N', 'Q']
elements = line.strip().split(' ')
res = True
for ele in elements:
if ele not in fp_category_list:
res = False
break
return res
class FpProcessor:
def __init__(self):
self.res = []
def addfp(self, voice_output_dir, prosody, raw_metafile_lines):
fp_category_list = ['FP', 'I', 'N']
f = open(prosody)
prosody_lines = f.readlines()
f.close()
idx = ''
fp = ''
fp_label_dict = {}
for i in range(len(prosody_lines)):
if i % 5 == 0:
idx = prosody_lines[i].strip().split('\t')[0]
elif i % 5 == 1: # according to prosody.txt
fp = prosody_lines[i].strip().split('\t')[0].split(' ')
for label in fp:
if label not in fp_category_list:
logging.warning('fp label not in fp_category_list')
break
fp_label_dict[idx] = fp
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
f_out = open(fpadd_metafile, 'w')
for line in raw_metafile_lines:
tokens = line.strip().split('\t')
if len(tokens) == 2:
uttname = tokens[0]
symbol_sequences = tokens[1].split(' ')
error_flag = False
idx = 0
out_str = uttname + '\t'
for this_symbol_sequence in symbol_sequences:
emotion = this_symbol_sequence.split('$')[4]
this_symbol_sequence = this_symbol_sequence.replace(
emotion, 'emotion_neutral')
if idx < len(fp_label_dict[uttname]):
if fp_label_dict[uttname][idx] == 'FP':
if 'none' not in this_symbol_sequence:
this_symbol_sequence = this_symbol_sequence.replace(
'emotion_neutral', 'emotion_disgust')
syllable_label = this_symbol_sequence.split('$')[2]
if syllable_label == 's_both' or syllable_label == 's_end':
idx += 1
elif idx > len(fp_label_dict[uttname]):
logging.warning(uttname + ' not match')
error_flag = True
out_str = out_str + this_symbol_sequence + ' '
if idx != len(fp_label_dict[uttname]):
logging.warning('{}: fp label length mismatch, consumed {} of {}'.format(
uttname, idx, len(fp_label_dict[uttname])))
if not error_flag:
f_out.write(out_str.strip() + '\n')
f_out.close()
return fpadd_metafile
def removefp(self, voice_output_dir, fpadd_metafile, raw_metafile_lines):
f = open(fpadd_metafile)
fpadd_metafile_lines = f.readlines()
f.close()
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
f_out = open(fprm_metafile, 'w')
for i in range(len(raw_metafile_lines)):
tokens = raw_metafile_lines[i].strip().split('\t')
symbol_sequences = tokens[1].split(' ')
fpadd_tokens = fpadd_metafile_lines[i].strip().split('\t')
fpadd_symbol_sequences = fpadd_tokens[1].split(' ')
error_flag = False
out_str = tokens[0] + '\t'
idx = 0
length = len(symbol_sequences)
while idx < length:
if '$emotion_disgust' in fpadd_symbol_sequences[idx]:
if idx + 1 < length and 'none' in fpadd_symbol_sequences[
idx + 1]:
idx = idx + 2
else:
idx = idx + 1
continue
out_str = out_str + symbol_sequences[idx] + ' '
idx = idx + 1
if not error_flag:
f_out.write(out_str.strip() + '\n')
f_out.close()
def process(self, voice_output_dir, prosody, raw_metafile):
with open(raw_metafile, 'r') as f:
lines = f.readlines()
random.shuffle(lines)
fpadd_metafile = self.addfp(voice_output_dir, prosody, lines)
self.removefp(voice_output_dir, fpadd_metafile, lines)
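
A minimal standalone sketch of driving the FP stage above; the directory layout is an illustrative assumption and mirrors the pipeline shown earlier:

processor = FpProcessor()
processor.process(
    voice_output_dir='work/voice_output',             # already contains raw_metafile.txt
    prosody='work/voice_input/prosody/prosody.txt',   # interleaved text / FP-label file
    raw_metafile='work/voice_output/raw_metafile.txt',
)
# process() shuffles the metafile lines, then writes fpadd_metafile.txt and
# fprm_metafile.txt into voice_output_dir.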

View File

@@ -0,0 +1,46 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
languages = {
'PinYin': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'ZhHK': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'WuuShanghai': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'Sichuan': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': 'En2ChPhoneMap.txt',
's2p_map_path': 'py2phoneMap.txt',
'tonelist_path': 'tonelist.txt',
},
'EnGB': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': '',
's2p_map_path': '',
'tonelist_path': 'tonelist.txt',
},
'EnUS': {
'phoneset_path': 'PhoneSet.xml',
'posset_path': 'PosSet.xml',
'f2t_map_path': '',
's2p_map_path': '',
'tonelist_path': 'tonelist.txt',
}
}
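
The table above only stores file names; a small hedged helper like the following (the resource_root layout and the helper name are assumptions, not part of this commit) shows how they are typically resolved to absolute paths:

import os

def resolve_language_resources(resource_root, lang='PinYin'):
    # Join each configured file name with the per-language resource directory;
    # empty entries (e.g. EnUS has no f2t/s2p maps) resolve to None.
    return {
        key: os.path.join(resource_root, lang, name) if name else None
        for key, name in languages[lang].items()
    }

# resolve_language_resources('/data/tts_resources')['phoneset_path']
# -> '/data/tts_resources/PinYin/PhoneSet.xml'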

View File

@@ -0,0 +1,242 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from enum import Enum
class Tone(Enum):
UnAssigned = -1
NoneTone = 0
YinPing = 1 # ZhHK: YinPingYinRu EnUS: primary stress
YangPing = 2 # ZhHK: YinShang EnUS: secondary stress
ShangSheng = 3 # ZhHK: YinQuZhongRu
QuSheng = 4 # ZhHK: YangPing
QingSheng = 5 # ZhHK: YangShang
YangQuYangRu = 6 # ZhHK: YangQuYangRu
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(Tone, cls).__new__(cls, in_str)
if in_str in ['UnAssigned', '-1']:
return Tone.UnAssigned
elif in_str in ['NoneTone', '0']:
return Tone.NoneTone
elif in_str in ['YinPing', '1']:
return Tone.YinPing
elif in_str in ['YangPing', '2']:
return Tone.YangPing
elif in_str in ['ShangSheng', '3']:
return Tone.ShangSheng
elif in_str in ['QuSheng', '4']:
return Tone.QuSheng
elif in_str in ['QingSheng', '5']:
return Tone.QingSheng
elif in_str in ['YangQuYangRu', '6']:
return Tone.YangQuYangRu
else:
return Tone.NoneTone
class BreakLevel(Enum):
UnAssigned = -1
L0 = 0
L1 = 1
L2 = 2
L3 = 3
L4 = 4
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(BreakLevel, cls).__new__(cls, in_str)
if in_str in ['UnAssigned', '-1']:
return BreakLevel.UnAssigned
elif in_str in ['L0', '0']:
return BreakLevel.L0
elif in_str in ['L1', '1']:
return BreakLevel.L1
elif in_str in ['L2', '2']:
return BreakLevel.L2
elif in_str in ['L3', '3']:
return BreakLevel.L3
elif in_str in ['L4', '4']:
return BreakLevel.L4
else:
return BreakLevel.UnAssigned
class SentencePurpose(Enum):
Declarative = 0
Interrogative = 1
Exclamatory = 2
Imperative = 3
class Language(Enum):
Neutral = 0
EnUS = 1033
EnGB = 2057
ZhCN = 2052
PinYin = 2053
WuuShanghai = 2054
Sichuan = 2055
ZhHK = 3076
ZhEn = ZhCN | EnUS
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(Language, cls).__new__(cls, in_str)
if in_str in ['Neutral', '0']:
return Language.Neutral
elif in_str in ['EnUS', '1033']:
return Language.EnUS
elif in_str in ['EnGB', '2057']:
return Language.EnGB
elif in_str in ['ZhCN', '2052']:
return Language.ZhCN
elif in_str in ['PinYin', '2053']:
return Language.PinYin
elif in_str in ['WuuShanghai', '2054']:
return Language.WuuShanghai
elif in_str in ['Sichuan', '2055']:
return Language.Sichuan
elif in_str in ['ZhHK', '3076']:
return Language.ZhHK
elif in_str in ['ZhEn', '2052|1033']:
return Language.ZhEn
else:
return Language.Neutral
"""
Phone Types
"""
class PhoneCVType(Enum):
NULL = -1
Consonant = 1
Vowel = 2
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneCVType, cls).__new__(cls, in_str)
if in_str in ['consonant', 'Consonant']:
return PhoneCVType.Consonant
elif in_str in ['vowel', 'Vowel']:
return PhoneCVType.Vowel
else:
return PhoneCVType.NULL
class PhoneIFType(Enum):
NULL = -1
Initial = 1
Final = 2
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneIFType, cls).__new__(cls, in_str)
if in_str in ['initial', 'Initial']:
return PhoneIFType.Initial
elif in_str in ['final', 'Final']:
return PhoneIFType.Final
else:
return PhoneIFType.NULL
class PhoneUVType(Enum):
NULL = -1
Voiced = 1
UnVoiced = 2
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneUVType, cls).__new__(cls, in_str)
if in_str in ['voiced', 'Voiced']:
return PhoneUVType.Voiced
elif in_str in ['unvoiced', 'UnVoiced']:
return PhoneUVType.UnVoiced
else:
return PhoneUVType.NULL
class PhoneAPType(Enum):
NULL = -1
DoubleLips = 1
LipTooth = 2
FrontTongue = 3
CentralTongue = 4
BackTongue = 5
Dorsal = 6
Velar = 7
Low = 8
Middle = 9
High = 10
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneAPType, cls).__new__(cls, in_str)
if in_str in ['doublelips', 'DoubleLips']:
return PhoneAPType.DoubleLips
elif in_str in ['liptooth', 'LipTooth']:
return PhoneAPType.LipTooth
elif in_str in ['fronttongue', 'FrontTongue']:
return PhoneAPType.FrontTongue
elif in_str in ['centraltongue', 'CentralTongue']:
return PhoneAPType.CentralTongue
elif in_str in ['backtongue', 'BackTongue']:
return PhoneAPType.BackTongue
elif in_str in ['dorsal', 'Dorsal']:
return PhoneAPType.Dorsal
elif in_str in ['velar', 'Velar']:
return PhoneAPType.Velar
elif in_str in ['low', 'Low']:
return PhoneAPType.Low
elif in_str in ['middle', 'Middle']:
return PhoneAPType.Middle
elif in_str in ['high', 'High']:
return PhoneAPType.High
else:
return PhoneAPType.NULL
class PhoneAMType(Enum):
NULL = -1
Stop = 1
Affricate = 2
Fricative = 3
Nasal = 4
Lateral = 5
Open = 6
Close = 7
@classmethod
def parse(cls, in_str):
if not isinstance(in_str, str):
return super(PhoneAMType, cls).__new__(cls, in_str)
if in_str in ['stop', 'Stop']:
return PhoneAMType.Stop
elif in_str in ['affricate', 'Affricate']:
return PhoneAMType.Affricate
elif in_str in ['fricative', 'Fricative']:
return PhoneAMType.Fricative
elif in_str in ['nasal', 'Nasal']:
return PhoneAMType.Nasal
elif in_str in ['lateral', 'Lateral']:
return PhoneAMType.Lateral
elif in_str in ['open', 'Open']:
return PhoneAMType.Open
elif in_str in ['close', 'Close']:
return PhoneAMType.Close
else:
return PhoneAMType.NULL
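
For illustration, the parse helpers above accept either the member name or its numeric string and fall back to a neutral member on unknown input:

assert Tone.parse('3') is Tone.ShangSheng
assert BreakLevel.parse('L2') is BreakLevel.L2
assert Language.parse('PinYin') is Language.PinYin
assert PhoneCVType.parse('vowel') is PhoneCVType.Vowel
# Unknown strings do not raise; they map to the neutral/none member:
assert Language.parse('Klingon') is Language.Neutral
assert Tone.parse('not-a-tone') is Tone.NoneTone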

View File

@@ -0,0 +1,48 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .core_types import (PhoneAMType, PhoneAPType, PhoneCVType, PhoneIFType,
PhoneUVType)
from .xml_obj import XmlObj
class Phone(XmlObj):
def __init__(self):
self.m_id = None
self.m_name = None
self.m_cv_type = PhoneCVType.NULL
self.m_if_type = PhoneIFType.NULL
self.m_uv_type = PhoneUVType.NULL
self.m_ap_type = PhoneAPType.NULL
self.m_am_type = PhoneAMType.NULL
self.m_bnd = False
def __str__(self):
return self.m_name
def save(self):
pass
def load(self, phone_node):
ns = '{http://schemas.alibaba-inc.com/tts}'
id_node = phone_node.find(ns + 'id')
self.m_id = int(id_node.text)
name_node = phone_node.find(ns + 'name')
self.m_name = name_node.text
cv_node = phone_node.find(ns + 'cv')
self.m_cv_type = PhoneCVType.parse(cv_node.text)
if_node = phone_node.find(ns + 'if')
self.m_if_type = PhoneIFType.parse(if_node.text)
uv_node = phone_node.find(ns + 'uv')
self.m_uv_type = PhoneUVType.parse(uv_node.text)
ap_node = phone_node.find(ns + 'ap')
self.m_ap_type = PhoneAPType.parse(ap_node.text)
am_node = phone_node.find(ns + 'am')
self.m_am_type = PhoneAMType.parse(am_node.text)

View File

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from modelscope.utils.logger import get_logger
from .phone import Phone
from .xml_obj import XmlObj
logging = get_logger()
class PhoneSet(XmlObj):
def __init__(self, phoneset_path):
self.m_phone_list = []
self.m_id_map = {}
self.m_name_map = {}
self.load(phoneset_path)
def load(self, file_path):
# alibaba tts xml namespace
ns = '{http://schemas.alibaba-inc.com/tts}'
phoneset_root = ET.parse(file_path).getroot()
for phone_node in phoneset_root.findall(ns + 'phone'):
phone = Phone()
phone.load(phone_node)
self.m_phone_list.append(phone)
if phone.m_id in self.m_id_map:
logging.error('PhoneSet.Load: duplicate id: %d', phone.m_id)
self.m_id_map[phone.m_id] = phone
if phone.m_name in self.m_name_map:
logging.error('PhoneSet.Load: duplicate name: %s',
phone.m_name)
self.m_name_map[phone.m_name] = phone
def save(self):
pass
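
A hedged sketch of the PhoneSet.xml shape the loader expects (node names follow the find() calls in Phone.load; the root tag, path and concrete values are illustrative assumptions):

# <phoneset xmlns="http://schemas.alibaba-inc.com/tts">
#   <phone>
#     <id>12</id> <name>ao</name> <cv>vowel</cv> <if>final</if>
#     <uv>voiced</uv> <ap>backtongue</ap> <am>open</am>
#   </phone>
#   ...
# </phoneset>
phoneset = PhoneSet('resources/PinYin/PhoneSet.xml')   # path is an assumption
phone = phoneset.m_name_map['ao']                      # lookup by phone name
# phone.m_cv_type is PhoneCVType.Vowel, phone.m_if_type is PhoneIFType.Final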

View File

@@ -0,0 +1,43 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .xml_obj import XmlObj
class Pos(XmlObj):
def __init__(self):
self.m_id = None
self.m_name = None
self.m_desc = None
self.m_level = 1
self.m_parent = None
self.m_sub_pos_list = []
def __str__(self):
return self.m_name
def save(self):
pass
def load(self, pos_node):
ns = '{http://schemas.alibaba-inc.com/tts}'
id_node = pos_node.find(ns + 'id')
self.m_id = int(id_node.text)
name_node = pos_node.find(ns + 'name')
self.m_name = name_node.text
desc_node = pos_node.find(ns + 'desc')
self.m_desc = desc_node.text
sub_node = pos_node.find(ns + 'sub')
if sub_node is not None:
for sub_pos_node in sub_node.findall(ns + 'pos'):
sub_pos = Pos()
sub_pos.load(sub_pos_node)
sub_pos.m_parent = self
sub_pos.m_level = self.m_level + 1
self.m_sub_pos_list.append(sub_pos)
return

View File

@@ -0,0 +1,50 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import xml.etree.ElementTree as ET
from .pos import Pos
from .xml_obj import XmlObj
class PosSet(XmlObj):
def __init__(self, posset_path):
self.m_pos_list = []
self.m_id_map = {}
self.m_name_map = {}
self.load(posset_path)
def load(self, file_path):
# alibaba tts xml namespace
ns = '{http://schemas.alibaba-inc.com/tts}'
posset_root = ET.parse(file_path).getroot()
for pos_node in posset_root.findall(ns + 'pos'):
pos = Pos()
pos.load(pos_node)
self.m_pos_list.append(pos)
if pos.m_id in self.m_id_map:
logging.error('PosSet.Load: duplicate id: %d', pos.m_id)
self.m_id_map[pos.m_id] = pos
if pos.m_name in self.m_name_map:
logging.error('PosSet.Load: duplicate name: %s',
pos.m_name)
self.m_name_map[pos.m_name] = pos
if len(pos.m_sub_pos_list) > 0:
for sub_pos in pos.m_sub_pos_list:
self.m_pos_list.append(sub_pos)
if sub_pos.m_id in self.m_id_map:
logging.error('PosSet.Load: duplicate id: %d',
sub_pos.m_id)
self.m_id_map[sub_pos.m_id] = sub_pos
if sub_pos.m_name in self.m_name_map:
logging.error('PosSet.Load: duplicate name: %s',
sub_pos.m_name)
self.m_name_map[sub_pos.m_name] = sub_pos
def save(self):
pass

View File

@@ -0,0 +1,35 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from xml.dom import minidom
from .xml_obj import XmlObj
class Script(XmlObj):
def __init__(self, phoneset, posset):
self.m_phoneset = phoneset
self.m_posset = posset
self.m_items = []
def save(self, outputXMLPath):
root = ET.Element('script')
root.set('uttcount', str(len(self.m_items)))
root.set('xmlns', 'http://schemas.alibaba-inc.com/tts')
for item in self.m_items:
item.save(root)
xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(
indent=' ', encoding='utf-8')
with open(outputXMLPath, 'wb') as f:
f.write(xmlstr)
def save_meta_file(self):
meta_lines = []
for item in self.m_items:
meta_lines.append(item.save_metafile())
return meta_lines

View File

@@ -0,0 +1,40 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .xml_obj import XmlObj
class ScriptItem(XmlObj):
def __init__(self, phoneset, posset):
if phoneset is None or posset is None:
raise Exception('ScriptItem.__init__: phoneset or posset is None')
self.m_phoneset = phoneset
self.m_posset = posset
self.m_id = None
self.m_text = ''
self.m_scriptSentence_list = []
self.m_status = None
def load(self):
pass
def save(self, parent_node):
utterance_node = ET.SubElement(parent_node, 'utterance')
utterance_node.set('id', self.m_id)
text_node = ET.SubElement(utterance_node, 'text')
text_node.text = self.m_text
for sentence in self.m_scriptSentence_list:
sentence.save(utterance_node)
def save_metafile(self):
meta_line = self.m_id + '\t'
for sentence in self.m_scriptSentence_list:
meta_line += sentence.save_metafile()
return meta_line

View File

@@ -0,0 +1,185 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .xml_obj import XmlObj
class WrittenSentence(XmlObj):
def __init__(self, posset):
self.m_written_word_list = []
self.m_written_mark_list = []
self.m_posset = posset
self.m_align_list = []
self.m_alignCursor = 0
self.m_accompanyIndex = 0
self.m_sequence = ''
self.m_text = ''
def add_host(self, writtenWord):
self.m_written_word_list.append(writtenWord)
self.m_align_list.append(self.m_alignCursor)
def load_host(self):
pass
def save_host(self):
pass
def add_accompany(self, writtenMark):
self.m_written_mark_list.append(writtenMark)
self.m_alignCursor += 1
self.m_accompanyIndex += 1
def save_accompany(self):
pass
def load_accompany(self):
pass
# Get the mark span corresponding to specific spoken word
def get_accompany_span(self, host_index):
if host_index == -1:
return (0, self.m_align_list[0])
accompany_begin = self.m_align_list[host_index]
accompany_end = (
self.m_align_list[host_index + 1]
if host_index + 1 < len(self.m_written_word_list) else len(
self.m_written_mark_list))
return (accompany_begin, accompany_end)
def get_elements(self):
accompany_begin, accompany_end = self.get_accompany_span(-1)
res_lst = [
self.m_written_mark_list[i]
for i in range(accompany_begin, accompany_end)
]
for j in range(len(self.m_written_word_list)):
accompany_begin, accompany_end = self.get_accompany_span(j)
res_lst.extend([self.m_written_word_list[j]])
res_lst.extend([
self.m_written_mark_list[i]
for i in range(accompany_begin, accompany_end)
])
return res_lst
def build_sequence(self):
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
def build_text(self):
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
class SpokenSentence(XmlObj):
def __init__(self, phoneset):
self.m_spoken_word_list = []
self.m_spoken_mark_list = []
self.m_phoneset = phoneset
self.m_align_list = []
self.m_alignCursor = 0
self.m_accompanyIndex = 0
self.m_sequence = ''
self.m_text = ''
def __len__(self):
return len(self.m_spoken_word_list)
def add_host(self, spokenWord):
self.m_spoken_word_list.append(spokenWord)
self.m_align_list.append(self.m_alignCursor)
def save_host(self):
pass
def load_host(self):
pass
def add_accompany(self, spokenMark):
self.m_spoken_mark_list.append(spokenMark)
self.m_alignCursor += 1
self.m_accompanyIndex += 1
def save_accompany(self):
pass
# Get the mark span corresponding to specific spoken word
def get_accompany_span(self, host_index):
if host_index == -1:
return (0, self.m_align_list[0])
accompany_begin = self.m_align_list[host_index]
accompany_end = (
self.m_align_list[host_index + 1]
if host_index + 1 < len(self.m_spoken_word_list) else len(
self.m_spoken_mark_list))
return (accompany_begin, accompany_end)
def get_elements(self):
accompany_begin, accompany_end = self.get_accompany_span(-1)
res_lst = [
self.m_spoken_mark_list[i]
for i in range(accompany_begin, accompany_end)
]
for j in range(len(self.m_spoken_word_list)):
accompany_begin, accompany_end = self.get_accompany_span(j)
res_lst.extend([self.m_spoken_word_list[j]])
res_lst.extend([
self.m_spoken_mark_list[i]
for i in range(accompany_begin, accompany_end)
])
return res_lst
def load_accompany(self):
pass
def build_sequence(self):
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
def build_text(self):
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
def save(self, parent_node):
spoken_node = ET.SubElement(parent_node, 'spoken')
spoken_node.set('wordcount', str(len(self.m_spoken_word_list)))
text_node = ET.SubElement(spoken_node, 'text')
text_node.text = self.m_sequence
for word in self.m_spoken_word_list:
word.save(spoken_node)
def save_metafile(self):
meta_line_list = [
word.save_metafile() for word in self.m_spoken_word_list
]
return ' '.join(meta_line_list)
class ScriptSentence(XmlObj):
def __init__(self, phoneset, posset):
self.m_phoneset = phoneset
self.m_posset = posset
self.m_writtenSentence = WrittenSentence(posset)
self.m_spokenSentence = SpokenSentence(phoneset)
self.m_text = ''
def save(self, parent_node):
if len(self.m_spokenSentence) > 0:
self.m_spokenSentence.save(parent_node)
def save_metafile(self):
if len(self.m_spokenSentence) > 0:
return self.m_spokenSentence.save_metafile()
else:
return ''

View File

@@ -0,0 +1,120 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .core_types import Language
from .syllable import SyllableList
from .xml_obj import XmlObj
class WrittenWord(XmlObj):
def __init__(self):
self.m_name = None
self.m_POS = None
def __str__(self):
return self.m_name
def load(self):
pass
def save(self):
pass
class WrittenMark(XmlObj):
def __init__(self):
self.m_punctuation = None
def __str__(self):
return self.m_punctuation
def load(self):
pass
def save(self):
pass
class SpokenWord(XmlObj):
def __init__(self):
self.m_name = None
self.m_language = None
self.m_syllable_list = []
self.m_breakText = '1'
self.m_POS = '0'
def __str__(self):
return self.m_name
def load(self):
pass
def save(self, parent_node):
word_node = ET.SubElement(parent_node, 'word')
name_node = ET.SubElement(word_node, 'name')
name_node.text = self.m_name
if (len(self.m_syllable_list) > 0
and self.m_syllable_list[0].m_language != Language.Neutral):
language_node = ET.SubElement(word_node, 'lang')
language_node.text = self.m_syllable_list[0].m_language.name
SyllableList(self.m_syllable_list).save(word_node)
break_node = ET.SubElement(word_node, 'break')
break_node.text = self.m_breakText
POS_node = ET.SubElement(word_node, 'POS')
POS_node.text = self.m_POS
return
def save_metafile(self):
word_phone_cnt = sum(
[syllable.phone_count() for syllable in self.m_syllable_list])
word_syllable_cnt = len(self.m_syllable_list)
single_syllable_word = word_syllable_cnt == 1
meta_line_list = []
for idx, syll in enumerate(self.m_syllable_list):
if word_phone_cnt == 1:
word_pos = 'word_both'
elif idx == 0:
word_pos = 'word_begin'
elif idx == len(self.m_syllable_list) - 1:
word_pos = 'word_end'
else:
word_pos = 'word_middle'
meta_line_list.append(
syll.save_metafile(
word_pos, single_syllable_word=single_syllable_word))
if self.m_breakText != '0' and self.m_breakText is not None:
meta_line_list.append('{{#{}$tone_none$s_none$word_none}}'.format(
self.m_breakText))
return ' '.join(meta_line_list)
class SpokenMark(XmlObj):
def __init__(self):
self.m_breakLevel = None
def break_level2text(self):
return '#' + str(self.m_breakLevel.value)
def __str__(self):
return self.break_level2text()
def load(self):
pass
def save(self):
pass

View File

@@ -0,0 +1,112 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import xml.etree.ElementTree as ET
from .xml_obj import XmlObj
class Syllable(XmlObj):
def __init__(self):
self.m_phone_list = []
self.m_tone = None
self.m_language = None
self.m_breaklevel = None
def pronunciation_text(self):
return ' '.join([str(phone) for phone in self.m_phone_list])
def phone_count(self):
return len(self.m_phone_list)
def tone_text(self):
return str(self.m_tone.value)
def save(self):
pass
def load(self):
pass
def get_phone_meta(self,
phone_name,
word_pos,
syll_pos,
tone_text,
single_syllable_word=False):
# Special case: word with single syllable, the last phone's word_pos should be "word_end"
if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word:
word_pos = 'word_end'
elif word_pos == 'word_begin' and syll_pos not in [
's_begin',
's_both',
]: # FIXME: keep accord with Engine logic
word_pos = 'word_middle'
elif word_pos == 'word_end' and syll_pos not in ['s_end', 's_both']:
word_pos = 'word_middle'
else:
pass
return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos,
word_pos)
def save_metafile(self, word_pos, single_syllable_word=False):
syllable_phone_cnt = len(self.m_phone_list)
meta_line_list = []
for idx, phone in enumerate(self.m_phone_list):
if syllable_phone_cnt == 1:
syll_pos = 's_both'
elif idx == 0:
syll_pos = 's_begin'
elif idx == len(self.m_phone_list) - 1:
syll_pos = 's_end'
else:
syll_pos = 's_middle'
meta_line_list.append(
self.get_phone_meta(
phone,
word_pos,
syll_pos,
self.tone_text(),
single_syllable_word=single_syllable_word,
))
return ' '.join(meta_line_list)
class SyllableList(XmlObj):
def __init__(self, syllables):
self.m_syllable_list = syllables
def __len__(self):
return len(self.m_syllable_list)
def __getitem__(self, index):
return self.m_syllable_list[index]
def pronunciation_text(self):
return ' - '.join([
syllable.pronunciation_text() for syllable in self.m_syllable_list
])
def tone_text(self):
return ''.join(
[syllable.tone_text() for syllable in self.m_syllable_list])
def save(self, parent_node):
syllable_node = ET.SubElement(parent_node, 'syllable')
syllable_node.set('syllcount', str(len(self.m_syllable_list)))
phone_node = ET.SubElement(syllable_node, 'phone')
phone_node.text = self.pronunciation_text()
tone_node = ET.SubElement(syllable_node, 'tone')
tone_node.text = self.tone_text()
return
def load(self):
pass
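
A hand-built illustration of the metafile token format produced by save_metafile above; Tone is assumed importable from the sibling core_types module, and the phone names are plain strings for brevity:

from .core_types import Tone  # assumption: same package layout as syllable_formatter.py

syll = Syllable()
syll.m_phone_list = ['h', 'ao']
syll.m_tone = Tone.ShangSheng   # tone 3
print(syll.save_metafile('word_both'))
# -> '{h$tone3$s_begin$word_both} {ao$tone3$s_end$word_both}'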

View File

@@ -0,0 +1,322 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from modelscope.utils.logger import get_logger
from .core_types import Language, PhoneCVType, Tone
from .syllable import Syllable
from .utils import NgBreakPattern
logging = get_logger()
class DefaultSyllableFormatter:
def __init__(self):
return
def format(self, phoneset, pronText, syllable_list):
logging.warning('Using DefaultSyllableFormatter dry run: %s', pronText)
return True
RegexNg2en = re.compile(NgBreakPattern)
RegexQingSheng = re.compile(r'([1-5]5)')
RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
class ZhCNSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def normalize_pron(self, pronText):
# Replace Qing Sheng
newPron = pronText.replace('6', '2')
newPron = re.sub(RegexQingSheng, '5', newPron)
# FIXME(Jin): ng case overrides newPron
match = RegexNg2en.search(newPron)
if match:
newPron = 'en' + match.group('break')
return newPron
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('ZhCNSyllableFormatter.Format: invalid input')
return False
pronText = self.normalize_pron(pronText)
if pronText in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pronText].split(' ')
if len(phone_list) == 3:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(
pronText[-1]) # FIXME(Jin): assume tone is the last char
syll.m_language = Language.ZhCN
syllable_list.append(syll)
return True
else:
logging.error(
'ZhCNSyllableFormatter.Format: invalid pronText: %s',
pronText)
return False
else:
logging.error(
'ZhCNSyllableFormatter.Format: syllable to phone map missing key: %s',
pronText,
)
return False
class PinYinSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def normalize_pron(self, pronText):
newPron = pronText.replace('6', '2')
newPron = re.sub(RegexQingSheng, '5', newPron)
# FIXME(Jin): ng case overrides newPron
match = RegexNg2en.search(newPron)
if match:
newPron = 'en' + match.group('break')
return newPron
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('PinYinSyllableFormatter.Format: invalid input')
return False
pronText = self.normalize_pron(pronText)
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'PinYinSyllableFormatter.Format: pronunciation is not valid: %s',
pronText,
)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.PinYin
syllable_list.append(syll)
return True
else:
logging.error(
'PinYinSyllableFormatter.Format: invalid phone: %s', pron)
return False
else:
logging.error(
'PinYinSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class ZhHKSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('ZhHKSyllableFormatter.Format: invalid input')
return False
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'ZhHKSyllableFormatter.Format: pronunciation is not valid: %s',
pronText)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.ZhHK
syllable_list.append(syll)
return True
else:
logging.error(
'ZhHKSyllableFormatter.Format: invalid phone: %s', pron)
return False
else:
logging.error(
'ZhHKSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class WuuShanghaiSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('WuuShanghaiSyllableFormatter.Format: invalid input')
return False
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'WuuShanghaiSyllableFormatter.Format: pronunciation is not valid: %s',
pronText,
)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.WuuShanghai
syllable_list.append(syll)
return True
else:
logging.error(
'WuuShanghaiSyllableFormatter.Format: invalid phone: %s',
pron)
return False
else:
logging.error(
'WuuShanghaiSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class SichuanSyllableFormatter:
def __init__(self, sy2ph_map):
self.m_sy2ph_map = sy2ph_map
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('SichuanSyllableFormatter.Format: invalid input')
return False
match = RegexPron.search(pronText)
if match:
pron = match.group('Pron')
tone = match.group('Tone')
else:
logging.error(
'SichuanSyllableFormatter.Format: pronunciation is not valid: %s',
pronText,
)
return False
if pron in self.m_sy2ph_map:
phone_list = self.m_sy2ph_map[pron].split(' ')
if len(phone_list) in [1, 2]:
syll = Syllable()
for phone in phone_list:
syll.m_phone_list.append(phone)
syll.m_tone = Tone.parse(tone)
syll.m_language = Language.Sichuan
syllable_list.append(syll)
return True
else:
logging.error(
'SichuanSyllableFormatter.Format: invalid phone: %s', pron)
return False
else:
logging.error(
'SichuanSyllableFormatter.Format: syllable to phone map missing key: %s',
pron,
)
return False
class EnXXSyllableFormatter:
def __init__(self, language):
self.m_f2t_map = None
self.m_language = language
def normalize_pron(self, pronText):
newPron = pronText.replace('#', '.')
newPron = (
newPron.replace('03',
'0').replace('13',
'1').replace('23',
'2').replace('3', ''))
newPron = newPron.replace('2', '0')
return newPron
def format(self, phoneset, pronText, syllable_list):
if phoneset is None or syllable_list is None or pronText is None:
logging.error('EnXXSyllableFormatter.Format: invalid input')
return False
pronText = self.normalize_pron(pronText)
syllables = [ele.strip() for ele in pronText.split('.')]
for i in range(len(syllables)):
syll = Syllable()
syll.m_language = self.m_language
syll.m_tone = Tone.parse('0')
phones = re.split(r'[\s]+', syllables[i])
for j in range(len(phones)):
phoneName = phones[j].lower()
toneName = '0'
if '0' in phoneName or '1' in phoneName or '2' in phoneName:
toneName = phoneName[-1]
phoneName = phoneName[:-1]
phoneName_lst = None
if self.m_f2t_map is not None:
phoneName_lst = self.m_f2t_map.get(phoneName, None)
if phoneName_lst is None:
phoneName_lst = [phoneName]
for new_phoneName in phoneName_lst:
phone_obj = phoneset.m_name_map.get(new_phoneName, None)
if phone_obj is None:
logging.error(
'EnXXSyllableFormatter.Format: phone %s not found',
new_phoneName,
)
return False
phone_obj.m_name = new_phoneName
syll.m_phone_list.append(phone_obj)
if phone_obj.m_cv_type == PhoneCVType.Vowel:
syll.m_tone = Tone.parse(toneName)
if j == len(phones) - 1:
phone_obj.m_bnd = True
syllable_list.append(syll)
return True
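
A toy sketch of the PinYin path above; the syllable-to-phone map here is a one-entry assumption (the real one is loaded from py2phoneMap.txt):

formatter = PinYinSyllableFormatter({'hao': 'h ao'})
print(formatter.normalize_pron('hao35'))   # neutral tone '35' collapses to '5' -> 'hao5'

syllables = []
ok = formatter.format(phoneset=object(), pronText='hao3', syllable_list=syllables)
# ok is True; syllables[0] carries phones ['h', 'ao'], Tone.ShangSheng, Language.PinYin.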

View File

@@ -0,0 +1,112 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import codecs
import re
import unicodedata
WordPattern = r'((?P<Word>\w+)(\(\w+\))?)'
BreakPattern = r'(?P<Break>(\*?#(?P<BreakLevel>[0-4])))'
MarkPattern = r'(?P<Mark>[、,。!?:“”《》·])'
POSPattern = r'(?P<POS>(\*?\|(?P<POSClass>[1-9])))'
PhraseTonePattern = r'(?P<PhraseTone>(\*?%([L|H])))'
NgBreakPattern = r'^ng(?P<break>\d)'
RegexWord = re.compile(WordPattern + r'\s*')
RegexBreak = re.compile(BreakPattern + r'\s*')
RegexID = re.compile(r'^(?P<ID>[a-zA-Z\-_0-9\.]+)\s*')
RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format(
WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern))
RegexForeignLang = re.compile(r'[A-Z@]')
RegexSpace = re.compile(r'^\s*')
RegexNeutralTone = re.compile(r'[1-5]5')
def do_character_normalization(line):
return unicodedata.normalize('NFKC', line)
def do_prosody_text_normalization(line):
tokens = line.split('\t')
text = tokens[1]
# Remove punctuations
text = text.replace(u'、', ' ')
text = text.replace(u'，', ' ')
text = text.replace(u'。', ' ')
text = text.replace(u'！', ' ')
text = text.replace(u'？', ' ')
text = text.replace(u'：', ' ')
text = text.replace(u'|', ' ')
text = text.replace(u'“', ' ')
text = text.replace(u'”', ' ')
text = text.replace(u'‘', ' ')
text = text.replace(u'’', ' ')
text = text.replace(u'（', ' ')
text = text.replace(u'）', ' ')
text = text.replace('.', ' ')
text = text.replace('!', ' ')
text = text.replace('?', ' ')
text = text.replace('(', ' ')
text = text.replace(')', ' ')
text = text.replace('[', ' ')
text = text.replace(']', ' ')
text = text.replace('{', ' ')
text = text.replace('}', ' ')
text = text.replace('~', ' ')
text = text.replace(':', ' ')
text = text.replace(';', ' ')
text = text.replace('+', ' ')
text = text.replace(',', ' ')
# text = text.replace('·', ' ')
text = text.replace('"', ' ')
text = text.replace(
'-',
'') # don't replace by space because compond word like two-year-old
text = text.replace(
"'", '') # don't replace by space because English word like that's
# Replace break
text = text.replace('/', '#2')
text = text.replace('%', '#3')
# Remove useless spaces surround #2 #3 #4
text = re.sub(r'(#\d)[ ]+', r'\1', text)
text = re.sub(r'[ ]+(#\d)', r'\1', text)
# Replace space by #1
text = re.sub('[ ]+', '#1', text)
# Remove break at the end of the text
text = re.sub(r'#\d$', '', text)
# Add #1 between target language and foreign language
text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r'\1#1\2', text)
text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r'\1#1\2', text)
return tokens[0] + '\t' + text
def is_fp_line(line):
fp_category_list = ['FP', 'I', 'N', 'Q']
elements = line.strip().split(' ')
res = True
for ele in elements:
if ele not in fp_category_list:
res = False
break
return res
def format_prosody(src_prosody):
formatted_lines = []
with codecs.open(src_prosody, 'r', 'utf-8') as f:
lines = f.readlines()
fp_enable = is_fp_line(lines[1])
for i in range(0, len(lines)):
line = do_character_normalization(lines[i])
if fp_enable:
if i % 5 == 1 or i % 5 == 2 or i % 5 == 3:
continue
if len(line.strip().split('\t')) == 2:
line = do_prosody_text_normalization(line)
formatted_lines.append(line)
return formatted_lines
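
For illustration, in the annotated prosody text '/' marks a #2 break and '%' a #3 break; a normalized line then looks like this (the utterance id is kept as-is):

line = 'A0001\tlet us go/ to the park%'
print(do_prosody_text_normalization(line))
# -> 'A0001\tlet#1us#1go#2to#1the#1park'  (spaces become #1, trailing break dropped)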

View File

@@ -0,0 +1,19 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
class XmlObj:
def __init__(self):
pass
def load(self):
pass
def save(self):
pass
def load_data(self):
pass
def save_data(self):
pass

View File

@@ -0,0 +1,463 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import argparse
import os
import re
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .core.core_types import BreakLevel, Language
from .core.phone_set import PhoneSet
from .core.pos_set import PosSet
from .core.script import Script
from .core.script_item import ScriptItem
from .core.script_sentence import ScriptSentence
from .core.script_word import SpokenMark, SpokenWord, WrittenMark, WrittenWord
from .core.utils import (RegexForeignLang, RegexID, RegexSentence,
format_prosody)
from .core.utils import RegexNeutralTone # isort:skip
from .core.syllable_formatter import ( # isort:skip
EnXXSyllableFormatter, PinYinSyllableFormatter, # isort:skip
SichuanSyllableFormatter, # isort:skip
WuuShanghaiSyllableFormatter, ZhCNSyllableFormatter, # isort:skip
ZhHKSyllableFormatter) # isort:skip
logging = get_logger()
class TextScriptConvertor:
def __init__(
self,
phoneset_path,
posset_path,
target_lang,
foreign_lang,
f2t_map_path,
s2p_map_path,
m_emo_tag_path,
m_speaker,
):
self.m_f2p_map = {}
self.m_s2p_map = {}
self.m_phoneset = PhoneSet(phoneset_path)
self.m_posset = PosSet(posset_path)
self.m_target_lang = Language.parse(target_lang)
self.m_foreign_lang = Language.parse(foreign_lang)
self.m_emo_tag_path = m_emo_tag_path
self.m_speaker = m_speaker
self.load_f2tmap(f2t_map_path)
self.load_s2pmap(s2p_map_path)
self.m_target_lang_syllable_formatter = self.init_syllable_formatter(
self.m_target_lang)
self.m_foreign_lang_syllable_formatter = self.init_syllable_formatter(
self.m_foreign_lang)
def parse_sentence(self, sentence, line_num):
script_item = ScriptItem(self.m_phoneset, self.m_posset)
script_sentence = ScriptSentence(self.m_phoneset, self.m_posset)
script_item.m_scriptSentence_list.append(script_sentence)
written_sentence = script_sentence.m_writtenSentence
spoken_sentence = script_sentence.m_spokenSentence
position = 0
sentence = sentence.strip()
# Get ID
match = re.search(RegexID, sentence)
if match is None:
logging.error(
'TextScriptConvertor.parse_sentence:invalid line: %s,\
line ID is needed',
line_num,
)
return None
else:
sentence_id = match.group('ID')
script_item.m_id = sentence_id
position += match.end()
prevSpokenWord = SpokenWord()
prevWord = False
lastBreak = False
for m in re.finditer(RegexSentence, sentence[position:]):
if m is None:
logging.error(
'TextScriptConvertor.parse_sentence:\
invalid line: %s, there is no matched pattern',
line_num,
)
return None
if m.group('Word') is not None:
wordName = m.group('Word')
written_word = WrittenWord()
written_word.m_name = wordName
written_sentence.add_host(written_word)
spoken_word = SpokenWord()
spoken_word.m_name = wordName
prevSpokenWord = spoken_word
prevWord = True
lastBreak = False
elif m.group('Break') is not None:
breakText = m.group('BreakLevel')
if len(breakText) == 0:
breakLevel = BreakLevel.L1
else:
breakLevel = BreakLevel.parse(breakText)
if prevWord:
prevSpokenWord.m_breakText = breakText
spoken_sentence.add_host(prevSpokenWord)
if breakLevel != BreakLevel.L1:
spokenMark = SpokenMark()
spokenMark.m_breakLevel = breakLevel
spoken_sentence.add_accompany(spokenMark)
lastBreak = True
elif m.group('PhraseTone') is not None:
pass
elif m.group('POS') is not None:
POSClass = m.group('POSClass')
if prevWord:
prevSpokenWord.m_pos = POSClass
prevWord = False
elif m.group('Mark') is not None:
markText = m.group('Mark')
writtenMark = WrittenMark()
writtenMark.m_punctuation = markText
written_sentence.add_accompany(writtenMark)
else:
logging.error(
'TextScriptConvertor.parse_sentence:\
invalid line: %s, matched pattern is unrecognized',
line_num,
)
return None
if not lastBreak:
prevSpokenWord.m_breakText = '4'
spoken_sentence.add_host(prevSpokenWord)
spoken_word_cnt = len(spoken_sentence.m_spoken_word_list)
spoken_mark_cnt = len(spoken_sentence.m_spoken_mark_list)
if (spoken_word_cnt > 0
and spoken_sentence.m_align_list[spoken_word_cnt - 1]
== spoken_mark_cnt):
spokenMark = SpokenMark()
spokenMark.m_breakLevel = BreakLevel.L4
spoken_sentence.add_accompany(spokenMark)
written_sentence.build_sequence()
spoken_sentence.build_sequence()
written_sentence.build_text()
spoken_sentence.build_text()
script_sentence.m_text = written_sentence.m_text
script_item.m_text = written_sentence.m_text
return script_item
def format_syllable(self, pron, syllable_list):
isForeign = RegexForeignLang.search(pron) is not None
if self.m_foreign_lang_syllable_formatter is not None and isForeign:
return self.m_foreign_lang_syllable_formatter.format(
self.m_phoneset, pron, syllable_list)
else:
return self.m_target_lang_syllable_formatter.format(
self.m_phoneset, pron, syllable_list)
def get_word_prons(self, pronText):
prons = pronText.split('/')
res = []
for pron in prons:
if re.search(RegexForeignLang, pron):
res.append(pron.strip())
else:
res.extend(pron.strip().split(' '))
return res
def is_erhuayin(self, pron):
pron = RegexNeutralTone.sub('5', pron)
pron = pron[:-1]
return pron[-1] == 'r' and pron != 'er'
def parse_pronunciation(self, script_item, pronunciation, line_num):
spoken_sentence = script_item.m_scriptSentence_list[0].m_spokenSentence
wordProns = self.get_word_prons(pronunciation)
wordIndex = 0
pronIndex = 0
succeed = True
while pronIndex < len(wordProns):
language = Language.Neutral
syllable_list = []
pron = wordProns[pronIndex].strip()
succeed = self.format_syllable(pron, syllable_list)
if not succeed:
logging.error(
'TextScriptConvertor.parse_pronunciation:\
invalid line: %s, error pronunciation: %s,\
syllable format error',
line_num,
pron,
)
return False
language = syllable_list[0].m_language
if wordIndex < len(spoken_sentence.m_spoken_word_list):
if language in [Language.EnGB, Language.EnUS]:
spoken_sentence.m_spoken_word_list[
wordIndex].m_syllable_list.extend(syllable_list)
wordIndex += 1
pronIndex += 1
elif language in [
Language.ZhCN,
Language.PinYin,
Language.ZhHK,
Language.WuuShanghai,
Language.Sichuan,
]:
charCount = len(
spoken_sentence.m_spoken_word_list[wordIndex].m_name)
if (language in [
Language.ZhCN, Language.PinYin, Language.Sichuan
] and self.is_erhuayin(pron) and '儿' in spoken_sentence.
m_spoken_word_list[wordIndex].m_name):
spoken_sentence.m_spoken_word_list[
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
wordIndex].m_name.replace('儿', '')
charCount -= 1
if charCount == 1:
spoken_sentence.m_spoken_word_list[
wordIndex].m_syllable_list.extend(syllable_list)
wordIndex += 1
pronIndex += 1
else:
# FIXME(Jin): Just skip the first char then match the rest char.
i = 1
while i >= 1 and i < charCount:
pronIndex += 1
if pronIndex < len(wordProns):
pron = wordProns[pronIndex].strip()
succeed = self.format_syllable(
pron, syllable_list)
if not succeed:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, syllable format error',
line_num,
pron,
)
return False
if (language in [
Language.ZhCN,
Language.PinYin,
Language.Sichuan,
] and self.is_erhuayin(pron)
and '儿' in spoken_sentence.
m_spoken_word_list[wordIndex].m_name):
spoken_sentence.m_spoken_word_list[
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
wordIndex].m_name.replace('儿', '')
charCount -= 1
else:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, Word count mismatch with Pron count',
line_num,
pron,
)
return False
i += 1
spoken_sentence.m_spoken_word_list[
wordIndex].m_syllable_list.extend(syllable_list)
wordIndex += 1
pronIndex += 1
else:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
unsupported language: %s',
line_num,
language.name,
)
return False
else:
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, word index is out of range',
line_num,
pron,
)
return False
if pronIndex != len(wordProns):
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, pron count mismatch with word count',
line_num,
pron,
)
return False
if wordIndex != len(spoken_sentence.m_spoken_word_list):
logging.error(
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
error pronunciation: %s, word count mismatch with word index',
line_num,
pron,
)
return False
return True
def load_f2tmap(self, file_path):
with open(file_path, 'r') as f:
for line in f.readlines():
line = line.strip()
elements = line.split('\t')
if len(elements) != 2:
logging.error(
'TextScriptConvertor.LoadF2TMap: invalid line: %s',
line)
continue
key = elements[0]
value = elements[1]
value_list = value.split(' ')
if key in self.m_f2p_map:
logging.error(
'TextScriptConvertor.LoadF2TMap: duplicate key: %s',
key)
self.m_f2p_map[key] = value_list
def load_s2pmap(self, file_path):
with open(file_path, 'r') as f:
for line in f.readlines():
line = line.strip()
elements = line.split('\t')
if len(elements) != 2:
logging.error(
'TextScriptConvertor.LoadS2PMap: invalid line: %s',
line)
continue
key = elements[0]
value = elements[1]
if key in self.m_s2p_map:
logging.error(
'TextScriptConvertor.LoadS2PMap: duplicate key: %s',
key)
self.m_s2p_map[key] = value
def init_syllable_formatter(self, targetLang):
if targetLang == Language.ZhCN:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: ZhCN syllable to phone map is empty'
)
return None
return ZhCNSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.PinYin:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: PinYin syllable to phone map is empty'
)
return None
return PinYinSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.ZhHK:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: ZhHK syllable to phone map is empty'
)
return None
return ZhHKSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.WuuShanghai:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: WuuShanghai syllable to phone map is empty'
)
return None
return WuuShanghaiSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.Sichuan:
if len(self.m_s2p_map) == 0:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: Sichuan syllable to phone map is empty'
)
return None
return SichuanSyllableFormatter(self.m_s2p_map)
elif targetLang == Language.EnGB:
formatter = EnXXSyllableFormatter(Language.EnGB)
if len(self.m_f2p_map) != 0:
formatter.m_f2t_map = self.m_f2p_map
return formatter
elif targetLang == Language.EnUS:
formatter = EnXXSyllableFormatter(Language.EnUS)
if len(self.m_f2p_map) != 0:
formatter.m_f2t_map = self.m_f2p_map
return formatter
else:
logging.error(
'TextScriptConvertor.InitSyllableFormatter: unsupported language: %s',
targetLang,
)
return None
def process(self, textScriptPath, outputXMLPath, outputMetafile):
script = Script(self.m_phoneset, self.m_posset)
formatted_lines = format_prosody(textScriptPath)
line_num = 0
for line in tqdm(formatted_lines):
if line_num % 2 == 0:
sentence = line.strip()
item = self.parse_sentence(sentence, line_num)
else:
if item is not None:
pronunciation = line.strip()
res = self.parse_pronunciation(item, pronunciation,
line_num)
if res:
script.m_items.append(item)
line_num += 1
script.save(outputXMLPath)
logging.info('TextScriptConvertor.process:\nSave script to: %s',
outputXMLPath)
meta_lines = script.save_meta_file()
emo = 'emotion_neutral'
speaker = self.m_speaker
meta_lines_tagged = []
for line in meta_lines:
line_id, line_text = line.split('\t')
syll_items = line_text.split(' ')
syll_items_tagged = []
for syll_item in syll_items:
syll_item_tagged = syll_item[:-1] + '$' + emo + '$' + speaker + '}'
syll_items_tagged.append(syll_item_tagged)
meta_lines_tagged.append(line_id + '\t'
+ ' '.join(syll_items_tagged))
with open(outputMetafile, 'w') as f:
for line in meta_lines_tagged:
f.write(line + '\n')
logging.info('TextScriptConvertor.process:\nSave metafile to: %s',
outputMetafile)
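
A hedged end-to-end sketch of this convertor; the resource file names follow the languages table in this package, while the concrete paths and speaker name are assumptions:

convertor = TextScriptConvertor(
    phoneset_path='resources/PinYin/PhoneSet.xml',
    posset_path='resources/PinYin/PosSet.xml',
    target_lang='PinYin',
    foreign_lang='EnUS',
    f2t_map_path='resources/PinYin/En2ChPhoneMap.txt',
    s2p_map_path='resources/PinYin/py2phoneMap.txt',
    m_emo_tag_path=None,
    m_speaker='F7',
)
convertor.process(
    textScriptPath='voice_input/prosody/prosody.txt',
    outputXMLPath='voice_output/Script.xml',
    outputMetafile='voice_output/raw_metafile.txt',
)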

View File

@@ -0,0 +1,562 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.nn.functional as F
from modelscope.models.audio.tts.kantts.models.utils import \
get_mask_from_lengths
from modelscope.models.audio.tts.kantts.utils.audio_torch import (
MelSpectrogram, stft)
class MelReconLoss(torch.nn.Module):
def __init__(self, loss_type='mae'):
super(MelReconLoss, self).__init__()
self.loss_type = loss_type
if loss_type == 'mae':
self.criterion = torch.nn.L1Loss(reduction='none')
elif loss_type == 'mse':
self.criterion = torch.nn.MSELoss(reduction='none')
else:
raise ValueError('Unknown loss type: {}'.format(loss_type))
def forward(self,
output_lengths,
mel_targets,
dec_outputs,
postnet_outputs=None):
output_masks = get_mask_from_lengths(
output_lengths, max_len=mel_targets.size(1))
output_masks = ~output_masks
valid_outputs = output_masks.sum()
mel_loss_ = torch.sum(
self.criterion(mel_targets, dec_outputs)
* output_masks.unsqueeze(-1)) / (
valid_outputs * mel_targets.size(-1))
if postnet_outputs is not None:
mel_loss = torch.sum(
self.criterion(mel_targets, postnet_outputs)
* output_masks.unsqueeze(-1)) / (
valid_outputs * mel_targets.size(-1))
else:
mel_loss = 0.0
return mel_loss_, mel_loss
class ProsodyReconLoss(torch.nn.Module):
def __init__(self, loss_type='mae'):
super(ProsodyReconLoss, self).__init__()
self.loss_type = loss_type
if loss_type == 'mae':
self.criterion = torch.nn.L1Loss(reduction='none')
elif loss_type == 'mse':
self.criterion = torch.nn.MSELoss(reduction='none')
else:
raise ValueError('Unknown loss type: {}'.format(loss_type))
def forward(
self,
input_lengths,
duration_targets,
pitch_targets,
energy_targets,
log_duration_predictions,
pitch_predictions,
energy_predictions,
):
input_masks = get_mask_from_lengths(
input_lengths, max_len=duration_targets.size(1))
input_masks = ~input_masks
valid_inputs = input_masks.sum()
dur_loss = (
torch.sum(
self.criterion(
torch.log(duration_targets.float() + 1),
log_duration_predictions) * input_masks) / valid_inputs)
pitch_loss = (
torch.sum(
self.criterion(pitch_targets, pitch_predictions) * input_masks)
/ valid_inputs)
energy_loss = (
torch.sum(
self.criterion(energy_targets, energy_predictions)
* input_masks) / valid_inputs)
return dur_loss, pitch_loss, energy_loss
class FpCELoss(torch.nn.Module):
def __init__(self, loss_type='ce', weight=[1, 4, 4, 8]):
super(FpCELoss, self).__init__()
self.loss_type = loss_type
weight_ce = torch.FloatTensor(weight).cuda()
self.criterion = torch.nn.CrossEntropyLoss(
weight=weight_ce, reduction='none')
def forward(self, input_lengths, fp_pd, fp_label):
input_masks = get_mask_from_lengths(
input_lengths, max_len=fp_label.size(1))
input_masks = ~input_masks
valid_inputs = input_masks.sum()
fp_loss = (
torch.sum(
self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks)
/ valid_inputs)
return fp_loss
class GeneratorAdversarialLoss(torch.nn.Module):
"""Generator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type='mse',
):
"""Initialize GeneratorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
if loss_type == 'mse':
self.criterion = self._mse_loss
else:
self.criterion = self._hinge_loss
def forward(self, outputs):
"""Calcualate generator adversarial loss.
Args:
outputs (Tensor or list): Discriminator outputs or list of
discriminator outputs.
Returns:
Tensor: Generator adversarial loss value.
"""
if isinstance(outputs, (tuple, list)):
adv_loss = 0.0
for i, outputs_ in enumerate(outputs):
adv_loss += self.criterion(outputs_)
if self.average_by_discriminators:
adv_loss /= i + 1
else:
adv_loss = self.criterion(outputs)
return adv_loss
def _mse_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _hinge_loss(self, x):
return -x.mean()
class DiscriminatorAdversarialLoss(torch.nn.Module):
"""Discriminator adversarial loss module."""
def __init__(
self,
average_by_discriminators=True,
loss_type='mse',
):
"""Initialize DiscriminatorAversarialLoss module."""
super().__init__()
self.average_by_discriminators = average_by_discriminators
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
if loss_type == 'mse':
self.fake_criterion = self._mse_fake_loss
self.real_criterion = self._mse_real_loss
else:
self.fake_criterion = self._hinge_fake_loss
self.real_criterion = self._hinge_real_loss
def forward(self, outputs_hat, outputs):
"""Calcualate discriminator adversarial loss.
Args:
outputs_hat (Tensor or list): Discriminator outputs or list of
discriminator outputs calculated from generator outputs.
outputs (Tensor or list): Discriminator outputs or list of
discriminator outputs calculated from groundtruth.
Returns:
Tensor: Discriminator real loss value.
Tensor: Discriminator fake loss value.
"""
if isinstance(outputs, (tuple, list)):
real_loss = 0.0
fake_loss = 0.0
for i, (outputs_hat_,
outputs_) in enumerate(zip(outputs_hat, outputs)):
if isinstance(outputs_hat_, (tuple, list)):
# NOTE(kan-bayashi): case including feature maps
outputs_hat_ = outputs_hat_[-1]
outputs_ = outputs_[-1]
real_loss += self.real_criterion(outputs_)
fake_loss += self.fake_criterion(outputs_hat_)
if self.average_by_discriminators:
fake_loss /= i + 1
real_loss /= i + 1
else:
real_loss = self.real_criterion(outputs)
fake_loss = self.fake_criterion(outputs_hat)
return real_loss, fake_loss
def _mse_real_loss(self, x):
return F.mse_loss(x, x.new_ones(x.size()))
def _mse_fake_loss(self, x):
return F.mse_loss(x, x.new_zeros(x.size()))
def _hinge_real_loss(self, x):
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
def _hinge_fake_loss(self, x):
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
class FeatureMatchLoss(torch.nn.Module):
"""Feature matching loss module."""
def __init__(
self,
average_by_layers=True,
average_by_discriminators=True,
):
"""Initialize FeatureMatchLoss module."""
super().__init__()
self.average_by_layers = average_by_layers
self.average_by_discriminators = average_by_discriminators
def forward(self, feats_hat, feats):
"""Calcualate feature matching loss.
Args:
feats_hat (list): List of list of discriminator outputs
calculated from generator outputs.
feats (list): List of list of discriminator outputs
calculated from groundtruth.
Returns:
Tensor: Feature matching loss value.
"""
feat_match_loss = 0.0
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
feat_match_loss_ = 0.0
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
if self.average_by_layers:
feat_match_loss_ /= j + 1
feat_match_loss += feat_match_loss_
if self.average_by_discriminators:
feat_match_loss /= i + 1
return feat_match_loss
class MelSpectrogramLoss(torch.nn.Module):
"""Mel-spectrogram loss."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window='hann',
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0,
):
"""Initialize Mel-spectrogram loss."""
super().__init__()
self.mel_spectrogram = MelSpectrogram(
fs=fs,
fft_size=fft_size,
hop_size=hop_size,
win_length=win_length,
window=window,
num_mels=num_mels,
fmin=fmin,
fmax=fmax,
center=center,
normalized=normalized,
onesided=onesided,
eps=eps,
log_base=log_base,
)
def forward(self, y_hat, y):
"""Calculate Mel-spectrogram loss.
Args:
y_hat (Tensor): Generated single tensor (B, 1, T).
y (Tensor): Groundtruth single tensor (B, 1, T).
Returns:
Tensor: Mel-spectrogram loss value.
"""
mel_hat = self.mel_spectrogram(y_hat)
mel = self.mel_spectrogram(y)
mel_loss = F.l1_loss(mel_hat, mel)
return mel_loss
class SpectralConvergenceLoss(torch.nn.Module):
"""Spectral convergence loss module."""
def __init__(self):
"""Initilize spectral convergence loss module."""
super(SpectralConvergenceLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Spectral convergence loss value.
"""
return torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro')
class LogSTFTMagnitudeLoss(torch.nn.Module):
"""Log STFT magnitude loss module."""
def __init__(self):
"""Initilize los STFT magnitude loss module."""
super(LogSTFTMagnitudeLoss, self).__init__()
def forward(self, x_mag, y_mag):
"""Calculate forward propagation.
Args:
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
Returns:
Tensor: Log STFT magnitude loss value.
"""
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
class STFTLoss(torch.nn.Module):
"""STFT loss module."""
def __init__(self,
fft_size=1024,
shift_size=120,
win_length=600,
window='hann_window'):
"""Initialize STFT loss module."""
super(STFTLoss, self).__init__()
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.spectral_convergence_loss = SpectralConvergenceLoss()
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
# NOTE(kan-bayashi): Use register_buffer to fix #223
self.register_buffer('window', getattr(torch, window)(win_length))
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T).
y (Tensor): Groundtruth signal (B, T).
Returns:
Tensor: Spectral convergence loss value.
Tensor: Log STFT magnitude loss value.
"""
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
self.window)
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
self.window)
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
return sc_loss, mag_loss
class MultiResolutionSTFTLoss(torch.nn.Module):
"""Multi resolution STFT loss module."""
def __init__(
self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window='hann_window',
):
"""Initialize Multi resolution STFT loss module.
Args:
fft_sizes (list): List of FFT sizes.
hop_sizes (list): List of hop sizes.
win_lengths (list): List of window lengths.
window (str): Window function type.
"""
super(MultiResolutionSTFTLoss, self).__init__()
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
self.stft_losses = torch.nn.ModuleList()
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
def forward(self, x, y):
"""Calculate forward propagation.
Args:
x (Tensor): Predicted signal (B, T) or (B, #subband, T).
y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
Returns:
Tensor: Multi resolution spectral convergence loss value.
Tensor: Multi resolution log STFT magnitude loss value.
"""
if len(x.shape) == 3:
x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
sc_loss = 0.0
mag_loss = 0.0
for f in self.stft_losses:
sc_l, mag_l = f(x, y)
sc_loss += sc_l
mag_loss += mag_l
sc_loss /= len(self.stft_losses)
mag_loss /= len(self.stft_losses)
return sc_loss, mag_loss
class SeqCELoss(torch.nn.Module):
def __init__(self, loss_type='ce'):
super(SeqCELoss, self).__init__()
self.loss_type = loss_type
self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
def forward(self, logits, targets, masks):
loss = self.criterion(logits.contiguous().view(-1, logits.size(-1)),
targets.contiguous().view(-1))
preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
masks = masks.contiguous().view(-1)
loss = (loss * masks).sum() / masks.sum()
err = torch.sum((preds != targets.view(-1)) * masks) / masks.sum()
return loss, err
class AttentionBinarizationLoss(torch.nn.Module):
def __init__(self, start_epoch=0, warmup_epoch=100):
super(AttentionBinarizationLoss, self).__init__()
self.start_epoch = start_epoch
self.warmup_epoch = warmup_epoch
def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
log_sum = torch.log(
torch.clamp(soft_attention[hard_attention == 1], min=eps)).sum()
kl_loss = -log_sum / hard_attention.sum()
if epoch < self.start_epoch:
warmup_ratio = 0
else:
warmup_ratio = min(1.0,
(epoch - self.start_epoch) / self.warmup_epoch)
return kl_loss * warmup_ratio
class AttentionCTCLoss(torch.nn.Module):
def __init__(self, blank_logprob=-1):
super(AttentionCTCLoss, self).__init__()
self.log_softmax = torch.nn.LogSoftmax(dim=3)
self.blank_logprob = blank_logprob
self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
def forward(self, attn_logprob, in_lens, out_lens):
key_lens = in_lens
query_lens = out_lens
attn_logprob_padded = F.pad(
input=attn_logprob,
pad=(1, 0, 0, 0, 0, 0, 0, 0),
value=self.blank_logprob)
cost_total = 0.0
for bid in range(attn_logprob.shape[0]):
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)
curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]
+ 1]
curr_logprob = self.log_softmax(curr_logprob[None])[0]
ctc_cost = self.CTCLoss(
curr_logprob,
target_seq,
input_lengths=query_lens[bid:bid + 1],
target_lengths=key_lens[bid:bid + 1],
)
cost_total += ctc_cost
cost = cost_total / attn_logprob.shape[0]
return cost
loss_dict = {
'generator_adv_loss': GeneratorAdversarialLoss,
'discriminator_adv_loss': DiscriminatorAdversarialLoss,
'stft_loss': MultiResolutionSTFTLoss,
'mel_loss': MelSpectrogramLoss,
'subband_stft_loss': MultiResolutionSTFTLoss,
'feat_match_loss': FeatureMatchLoss,
'MelReconLoss': MelReconLoss,
'ProsodyReconLoss': ProsodyReconLoss,
'SeqCELoss': SeqCELoss,
'AttentionBinarizationLoss': AttentionBinarizationLoss,
'AttentionCTCLoss': AttentionCTCLoss,
'FpCELoss': FpCELoss,
}
def criterion_builder(config, device='cpu'):
"""Criterion builder.
Args:
config (dict): Config dictionary.
Returns:
criterion (dict): Loss dictionary
"""
criterion = {}
for key, value in config['Loss'].items():
if key in loss_dict:
if value['enable']:
criterion[key] = loss_dict[key](
**value.get('params', {})).to(device)
setattr(criterion[key], 'weights', value.get('weights', 1.0))
else:
raise NotImplementedError('{} is not implemented'.format(key))
return criterion
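A hedged sketch of the config fragment criterion_builder expects; the keys mirror loss_dict above, while the concrete values are placeholders rather than the shipped training configuration.

config = {
    'Loss': {
        'stft_loss': {
            'enable': True,
            'params': {'fft_sizes': [1024, 2048, 512]},
            'weights': 1.0,
        },
        'MelReconLoss': {'enable': False},
    }
}
criterion = criterion_builder(config, device='cpu')
# Only enabled losses are instantiated; each module carries its weight attribute.
print(list(criterion.keys()), criterion['stft_loss'].weights)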

View File

@@ -0,0 +1,44 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
class FindLR(_LRScheduler):
"""
Inspired by fast.ai: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
"""
def __init__(self, optimizer, max_steps, max_lr=10):
self.max_steps = max_steps
self.max_lr = max_lr
super().__init__(optimizer)
def get_lr(self):
return [
base_lr * ((self.max_lr / base_lr)**(
self.last_epoch / # noqa W504
(self.max_steps - 1))) for base_lr in self.base_lrs
]
class NoamLR(_LRScheduler):
"""
Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
to the inverse square root of the step number, scaled by the inverse square root of the
dimensionality of the model. Time will tell if this is just madness or it's actually important.
Parameters
----------
warmup_steps: ``int``, required.
The number of steps to linearly increase the learning rate.
"""
def __init__(self, optimizer, warmup_steps):
self.warmup_steps = warmup_steps
super().__init__(optimizer)
def get_lr(self):
last_epoch = max(1, self.last_epoch)
scale = self.warmup_steps**0.5 * min(
last_epoch**(-0.5), last_epoch * self.warmup_steps**(-1.5))
return [base_lr * scale for base_lr in self.base_lrs]
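A short sketch of the Noam schedule above; the optimizer and warmup length are arbitrary and only illustrate the warmup-then-decay behaviour.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = NoamLR(optimizer, warmup_steps=4000)
for step in range(8000):
    optimizer.step()
    scheduler.step()
# The learning rate climbs linearly to 1e-3 around step 4000,
# then decays proportionally to step ** -0.5.
print(optimizer.param_groups[0]['lr'])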

File diff suppressed because it is too large

View File

@@ -0,0 +1,188 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from distutils.version import LooseVersion
import librosa
import torch
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
if is_pytorch_17plus:
x_stft = torch.stft(
x, fft_size, hop_size, win_length, window, return_complex=False)
else:
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
real = x_stft[..., 0]
imag = x_stft[..., 1]
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
return 20 * torch.log10(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
return torch.pow(10.0, x * 0.05) / C
def spectral_normalize_torch(
magnitudes,
min_level_db=-100.0,
ref_level_db=20.0,
norm_abs_value=4.0,
symmetric=True,
):
output = dynamic_range_compression_torch(magnitudes) - ref_level_db
if symmetric:
return torch.clamp(
2 * norm_abs_value * ((output - min_level_db) / # noqa W504
(-min_level_db)) - norm_abs_value,
min=-norm_abs_value,
max=norm_abs_value)
else:
return torch.clamp(
norm_abs_value * ((output - min_level_db) / (-min_level_db)),
min=0.0,
max=norm_abs_value)
def spectral_de_normalize_torch(
magnitudes,
min_level_db=-100.0,
ref_level_db=20.0,
norm_abs_value=4.0,
symmetric=True,
):
if symmetric:
magnitudes = torch.clamp(
magnitudes, min=-norm_abs_value, max=norm_abs_value)
magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / (
2 * norm_abs_value) + min_level_db
else:
magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
magnitudes = (magnitudes) * (-min_level_db) / (
norm_abs_value) + min_level_db
output = dynamic_range_decompression_torch(magnitudes + ref_level_db)
return output
class MelSpectrogram(torch.nn.Module):
"""Calculate Mel-spectrogram."""
def __init__(
self,
fs=22050,
fft_size=1024,
hop_size=256,
win_length=None,
window='hann',
num_mels=80,
fmin=80,
fmax=7600,
center=True,
normalized=False,
onesided=True,
eps=1e-10,
log_base=10.0,
pad_mode='constant',
):
"""Initialize MelSpectrogram module."""
super().__init__()
self.fft_size = fft_size
if win_length is None:
self.win_length = fft_size
else:
self.win_length = win_length
self.hop_size = hop_size
self.center = center
self.normalized = normalized
self.onesided = onesided
if window is not None and not hasattr(torch, f'{window}_window'):
raise ValueError(f'{window} window is not implemented')
self.window = window
self.eps = eps
self.pad_mode = pad_mode
fmin = 0 if fmin is None else fmin
fmax = fs / 2 if fmax is None else fmax
melmat = librosa.filters.mel(
sr=fs,
n_fft=fft_size,
n_mels=num_mels,
fmin=fmin,
fmax=fmax,
)
self.register_buffer('melmat', torch.from_numpy(melmat.T).float())
self.stft_params = {
'n_fft': self.fft_size,
'win_length': self.win_length,
'hop_length': self.hop_size,
'center': self.center,
'normalized': self.normalized,
'onesided': self.onesided,
'pad_mode': self.pad_mode,
}
if is_pytorch_17plus:
self.stft_params['return_complex'] = False
self.log_base = log_base
if self.log_base is None:
self.log = torch.log
elif self.log_base == 2.0:
self.log = torch.log2
elif self.log_base == 10.0:
self.log = torch.log10
else:
raise ValueError(f'log_base: {log_base} is not supported.')
def forward(self, x):
"""Calculate Mel-spectrogram.
Args:
x (Tensor): Input waveform tensor (B, T) or (B, 1, T).
Returns:
Tensor: Mel-spectrogram (B, #mels, #frames).
"""
if x.dim() == 3:
# (B, C, T) -> (B*C, T)
x = x.reshape(-1, x.size(2))
if self.window is not None:
window_func = getattr(torch, f'{self.window}_window')
window = window_func(
self.win_length, dtype=x.dtype, device=x.device)
else:
window = None
x_stft = torch.stft(x, window=window, **self.stft_params)
# (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
x_stft = x_stft.transpose(1, 2)
x_power = x_stft[..., 0]**2 + x_stft[..., 1]**2
x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))
x_mel = torch.matmul(x_amp, self.melmat)
x_mel = torch.clamp(x_mel, min=self.eps)
x_mel = spectral_normalize_torch(x_mel)
# return self.log(x_mel).transpose(1, 2)
return x_mel.transpose(1, 2)
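A quick shape check for the MelSpectrogram module above, using one second of random noise in place of real audio.

import torch

mel_fn = MelSpectrogram(fs=22050, fft_size=1024, hop_size=256, num_mels=80)
wav = torch.randn(2, 22050)    # (B, T) placeholder waveform
mel = mel_fn(wav)
# Normalized log-mel features shaped (B, #mels, #frames), here (2, 80, 87).
print(mel.shape)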

View File

@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import ttsfrd
def text_to_mit_symbols(texts, resources_dir, speaker):
fe = ttsfrd.TtsFrontendEngine()
fe.initialize(resources_dir)
fe.set_lang_type('Zh-CN')
symbols_lst = []
for idx, text in enumerate(texts):
text = text.strip()
res = fe.gen_tacotron_symbols(text)
res = res.replace('F7', speaker)
sentences = res.split('\n')
for sentence in sentences:
arr = sentence.split('\t')
# skip the empty line
if len(arr) != 2:
continue
sub_index, symbols = sentence.split('\t')
symbol_str = '{}_{}\t{}\n'.format(idx, sub_index, symbols)
symbols_lst.append(symbol_str)
return symbols_lst
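A hedged usage sketch for the frontend helper above; it requires the ttsfrd package with a matching resource directory, and the path here is only a placeholder.

# Each returned element is formatted as '<text_idx>_<sub_idx>\t<symbols>\n'.
texts = ['今天天气不错', '早上好']
symbols = text_to_mit_symbols(texts, '/path/to/ttsfrd_resources', 'F7')
for line in symbols:
    utt_id, symbol_seq = line.strip().split('\t')
    print(utt_id, symbol_seq[:40])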

View File

@@ -21,24 +21,21 @@ _whitespace_re = re.compile(r'\s+')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [
(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
for x in [('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'), ]] # yapf:disable
for x in [('mrs', 'misess'), ('mr', 'mister'), (
'dr', 'doctor'), ('st', 'saint'), ('co', 'company'), (
'jr',
'junior'), ('maj', 'major'), ('gen', 'general'), (
'drs', 'doctors'), ('rev', 'reverend'), (
'lt',
'lieutenant'), ('hon', 'honorable'), (
'sgt',
'sergeant'), ('capt', 'captain'), (
'esq',
'esquire'), ('ltd',
'limited'), ('col',
'colonel'), ('ft',
'fort')]
]
def expand_abbreviations(text):
@@ -64,14 +61,14 @@ def convert_to_ascii(text):
def basic_cleaners(text):
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
text = lowercase(text)
text = collapse_whitespace(text)
return text
def transliteration_cleaners(text):
'''Pipeline for non-English text that transliterates to ASCII.'''
"""Pipeline for non-English text that transliterates to ASCII."""
text = convert_to_ascii(text)
text = lowercase(text)
text = collapse_whitespace(text)
@@ -79,7 +76,7 @@ def transliteration_cleaners(text):
def english_cleaners(text):
'''Pipeline for English text, including number and abbreviation expansion.'''
"""Pipeline for English text, including number and abbreviation expansion."""
text = convert_to_ascii(text)
text = lowercase(text)
text = expand_numbers(text)

View File

@@ -0,0 +1,37 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
emotion_types = [
'emotion_none',
'emotion_neutral',
'emotion_angry',
'emotion_disgust',
'emotion_fear',
'emotion_happy',
'emotion_sad',
'emotion_surprise',
'emotion_calm',
'emotion_gentle',
'emotion_relax',
'emotion_lyrical',
'emotion_serious',
'emotion_disgruntled',
'emotion_satisfied',
'emotion_disappointed',
'emotion_excited',
'emotion_anxiety',
'emotion_jealousy',
'emotion_hate',
'emotion_pity',
'emotion_pleasure',
'emotion_arousal',
'emotion_dominance',
'emotion_placeholder1',
'emotion_placeholder2',
'emotion_placeholder3',
'emotion_placeholder4',
'emotion_placeholder5',
'emotion_placeholder6',
'emotion_placeholder7',
'emotion_placeholder8',
'emotion_placeholder9',
]

View File

@@ -0,0 +1,88 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import xml.etree.ElementTree as ET
from modelscope.models.audio.tts.kantts.preprocess.languages import languages
from modelscope.utils.logger import get_logger
logging = get_logger()
syllable_flags = [
's_begin',
's_end',
's_none',
's_both',
's_middle',
]
word_segments = [
'word_begin',
'word_end',
'word_middle',
'word_both',
'word_none',
]
def parse_phoneset(phoneset_file):
"""Parse a phoneset file and return a list of symbols.
Args:
phoneset_file (str): Path to the phoneset file.
Returns:
list: A list of phones.
"""
ns = '{http://schemas.alibaba-inc.com/tts}'
phone_lst = []
phoneset_root = ET.parse(phoneset_file).getroot()
for phone_node in phoneset_root.findall(ns + 'phone'):
phone_lst.append(phone_node.find(ns + 'name').text)
for i in range(1, 5):
phone_lst.append('#{}'.format(i))
return phone_lst
def parse_tonelist(tonelist_file):
"""Parse a tonelist file and return a list of tones.
Args:
tonelist_file (str): Path to the tonelist file.
Returns:
dict: A dictionary of tones.
"""
tone_lst = []
with open(tonelist_file, 'r') as f:
lines = f.readlines()
for line in lines:
tone = line.strip()
if tone != '':
tone_lst.append('tone{}'.format(tone))
else:
tone_lst.append('tone_none')
return tone_lst
def get_language_symbols(language, language_dir):
"""Get symbols of a language.
Args:
language (str): Language name.
"""
language_dict = languages.get(language, None)
if language_dict is None:
logging.error('Language %s not supported. Using PinYin as default',
language)
language_dict = languages['PinYin']
language = 'PinYin'
language_dir = os.path.join(language_dir, language)
phoneset_file = os.path.join(language_dir, language_dict['phoneset_path'])
tonelist_file = os.path.join(language_dir, language_dict['tonelist_path'])
phones = parse_phoneset(phoneset_file)
tones = parse_tonelist(tonelist_file)
return phones, tones, syllable_flags, word_segments
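A small sketch of the tonelist format consumed by parse_tonelist above: one tone id per line, with an empty line mapped to 'tone_none'. The temporary file is purely illustrative.

import tempfile

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False) as f:
    f.write('1\n2\n3\n4\n5\n\n')    # five tones plus one empty line
    tonelist_path = f.name

print(parse_tonelist(tonelist_path))
# ['tone1', 'tone2', 'tone3', 'tone4', 'tone5', 'tone_none']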

View File

@@ -1,15 +1,15 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import abc
import codecs
import os
import re
import shutil
import json
import numpy as np
from . import cleaners as cleaners
from .emotion_types import emotion_types
from .lang_symbols import get_language_symbols
# Regular expression matching text enclosed in curly braces:
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
@@ -19,18 +19,37 @@ def _clean_text(text, cleaner_names):
for name in cleaner_names:
cleaner = getattr(cleaners, name)
if not cleaner:
raise Exception(
'modelscope error: configuration cleaner unknown: %s' % name)
raise Exception('Unknown cleaner: %s' % name)
text = cleaner(text)
return text
def get_fpdict(config):
# emotion_neutral (F7) can be replaced by other emotion (speaker) types from the corresponding list in the config file.
en_sy = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {en_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
a_sy = '{ga$tone5$s_begin$word_begin$emotion_neutral$F7} {a_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
e_sy = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {e_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
ling_unit = KanTtsLinguisticUnit(config)
en_lings = ling_unit.encode_symbol_sequence(en_sy)
a_lings = ling_unit.encode_symbol_sequence(a_sy)
e_lings = ling_unit.encode_symbol_sequence(e_sy)
en_ling = np.stack(en_lings, axis=1)[:3, :4]
a_ling = np.stack(a_lings, axis=1)[:3, :4]
e_ling = np.stack(e_lings, axis=1)[:3, :4]
fp_dict = {1: a_ling, 2: en_ling, 3: e_ling}
return fp_dict
class LinguisticBaseUnit(abc.ABC):
def set_config_params(self, config_params):
self.config_params = config_params
def save(self, config, config_name, path):
"""Save config to file"""
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
@@ -39,22 +58,37 @@ class LinguisticBaseUnit(abc.ABC):
class KanTtsLinguisticUnit(LinguisticBaseUnit):
def __init__(self, config, path, has_mask=True):
def __init__(self, config, lang_dir=None):
super(KanTtsLinguisticUnit, self).__init__()
# special symbol
self._pad = '_'
self._eos = '~'
self._mask = '@[MASK]'
self._has_mask = has_mask
self._unit_config = config
self._path = path
self.unit_config = config['linguistic_unit']
self.has_mask = self.unit_config.get('has_mask', True)
self.lang_type = self.unit_config.get('language', 'PinYin')
(
self.lang_phones,
self.lang_tones,
self.lang_syllable_flags,
self.lang_word_segments,
) = get_language_symbols(self.lang_type, lang_dir)
self._cleaner_names = [
x.strip() for x in self._unit_config['cleaners'].split(',')
x.strip() for x in self.unit_config['cleaners'].split(',')
]
self._lfeat_type_list = self._unit_config['lfeat_type_list'].strip(
).split(',')
_lfeat_type_list = self.unit_config['lfeat_type_list'].strip().split(
',')
self._lfeat_type_list = _lfeat_type_list
self.fp_enable = config['Model']['KanTtsSAMBERT']['params'].get(
'FP', False)
if self.fp_enable:
self._fpadd_lfeat_type_list = [
_lfeat_type_list[0], _lfeat_type_list[4]
]
self.build()
@@ -79,19 +113,13 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
# sy sub-unit
_characters = ''
_ch_symbols = []
sy_path = os.path.join(self._path, self._unit_config['sy'])
f = codecs.open(sy_path, 'r')
for line in f:
line = line.strip('\r\n')
_ch_symbols.append(line)
_arpabet = ['@' + s for s in _ch_symbols]
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
# _arpabet = ['@' + s for s in cmudict.valid_symbols]
_arpabet = ['@' + s for s in self.lang_phones]
# Export all symbols:
self.sy = list(_characters) + _arpabet + [self._pad, self._eos]
if self._has_mask:
if self.has_mask:
self.sy.append(self._mask)
self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
@@ -101,17 +129,10 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
# tone sub-unit
_characters = ''
_ch_tones = []
tone_path = os.path.join(self._path, self._unit_config['tone'])
f = codecs.open(tone_path, 'r')
for line in f:
line = line.strip('\r\n')
_ch_tones.append(line)
# Export all tones:
self.tone = list(_characters) + _ch_tones + [self._pad, self._eos]
if self._has_mask:
self.tone = (
list(_characters) + self.lang_tones + [self._pad, self._eos])
if self.has_mask:
self.tone.append(self._mask)
self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
@@ -121,20 +142,11 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
# syllable flag sub-unit
_characters = ''
_ch_syllable_flags = []
sy_flag_path = os.path.join(self._path,
self._unit_config['syllable_flag'])
f = codecs.open(sy_flag_path, 'r')
for line in f:
line = line.strip('\r\n')
_ch_syllable_flags.append(line)
# Export all syllable_flags:
self.syllable_flag = list(_characters) + _ch_syllable_flags + [
self._pad, self._eos
]
if self._has_mask:
self.syllable_flag = (
list(_characters) + self.lang_syllable_flags
+ [self._pad, self._eos])
if self.has_mask:
self.syllable_flag.append(self._mask)
self._syllable_flag_to_id = {
s: i
@@ -150,19 +162,11 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
# word segment sub-unit
_characters = ''
_ch_word_segments = []
ws_path = os.path.join(self._path, self._unit_config['word_segment'])
f = codecs.open(ws_path, 'r')
for line in f:
line = line.strip('\r\n')
_ch_word_segments.append(line)
# Export all word_segments:
self.word_segment = list(_characters) + _ch_word_segments + [
self._pad, self._eos
]
if self._has_mask:
self.word_segment = (
list(_characters) + self.lang_word_segments
+ [self._pad, self._eos])
if self.has_mask:
self.word_segment.append(self._mask)
self._word_segment_to_id = {
s: i
@@ -179,19 +183,9 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
# emotion category sub-unit
_characters = ''
_ch_emo_types = []
emo_path = os.path.join(self._path,
self._unit_config['emo_category'])
f = codecs.open(emo_path, 'r')
for line in f:
line = line.strip('\r\n')
_ch_emo_types.append(line)
self.emo_category = list(_characters) + _ch_emo_types + [
self._pad, self._eos
]
if self._has_mask:
self.emo_category = (
list(_characters) + emotion_types + [self._pad, self._eos])
if self.has_mask:
self.emo_category.append(self._mask)
self._emo_category_to_id = {
s: i
@@ -208,20 +202,12 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
# speaker category sub-unit
_characters = ''
_ch_speakers = []
speaker_path = os.path.join(self._path,
self._unit_config['speaker_category'])
f = codecs.open(speaker_path, 'r')
for line in f:
line = line.strip('\r\n')
_ch_speakers.append(line)
_ch_speakers = self.unit_config['speaker_list'].strip().split(',')
# Export all speakers:
self.speaker = list(_characters) + _ch_speakers + [
self._pad, self._eos
]
if self._has_mask:
self.speaker = (
list(_characters) + _ch_speakers + [self._pad, self._eos])
if self.has_mask:
self.speaker.append(self._mask)
self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)}
self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)}
@@ -237,8 +223,9 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
'$')
index = 0
while index < len(lfeat_symbol_separate):
lfeat_symbol_separate[index] = lfeat_symbol_separate[
index] + this_lfeat_symbol[index] + ' '
lfeat_symbol_separate[index] = (
lfeat_symbol_separate[index] + this_lfeat_symbol[index]
+ ' ')
index = index + 1
input_and_label_data = []
@@ -271,9 +258,7 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
elif lfeat_type == 'speaker_category':
s = self.decode_speaker_category(sequence_item)
else:
raise Exception(
'modelscope error: configuration lfeat type(%s) unknown.'
% lfeat_type)
raise Exception('Unknown lfeat type: %s' % lfeat_type)
result.append('%s:%s' % (lfeat_type, s))
return result
@@ -285,8 +270,9 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
this_lfeat_symbol_format = ''
index = 0
while index < len(this_lfeat_symbol):
this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[
index] + '}' + ' '
this_lfeat_symbol_format = (
this_lfeat_symbol_format + '{' + this_lfeat_symbol[index]
+ '}' + ' ')
index = index + 1
sequence = self.encode_text(this_lfeat_symbol_format,
self._cleaner_names)
@@ -301,9 +287,7 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
elif lfeat_type == 'speaker_category':
sequence = self.encode_speaker_category(this_lfeat_symbol)
else:
raise Exception(
'modelscope error: configuration lfeat type(%s) unknown.'
% lfeat_type)
raise Exception('Unknown lfeat type: %s' % lfeat_type)
return sequence
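A hedged sketch of the configuration shape the refactored KanTtsLinguisticUnit reads: symbol inventories now come from the language resources under lang_dir instead of per-voice files, and speakers are listed inline. Every value below is a placeholder, not the shipped config.

config = {
    'linguistic_unit': {
        'cleaners': 'basic_cleaners',
        'lfeat_type_list': 'sy,tone,syllable_flag,word_segment,emo_category,speaker_category',
        'speaker_list': 'F7',
        'language': 'PinYin',
        'has_mask': True,
    },
    'Model': {'KanTtsSAMBERT': {'params': {'FP': False}}},
}
# With a lang_dir containing the PinYin phoneset and tonelist resources:
# ling_unit = KanTtsLinguisticUnit(config, lang_dir='/path/to/languages')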

View File

@@ -1,5 +1,4 @@
# The implementation is adopted from tacotron,
# made publicly available under the MIT License at https://github.com/keithito/tacotron
# Copyright (c) Alibaba, Inc. and its affiliates.
import re

View File

@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import logging
import subprocess
def logging_to_file(log_file):
logger = logging.getLogger()
handler = logging.FileHandler(log_file)
formatter = logging.Formatter(
'%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d:%H:%M:%S',
)
handler.setFormatter(formatter)
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def get_git_revision_short_hash():
return (subprocess.check_output(['git', 'rev-parse', '--short',
'HEAD']).decode('ascii').strip())
def get_git_revision_hash():
return subprocess.check_output(['git', 'rev-parse',
'HEAD']).decode('ascii').strip()
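A short sketch of the logging helper above; the log path is arbitrary, and the git helpers only work when run inside a git checkout.

import logging

logging_to_file('/tmp/kantts_train.log')
logging.info('commit: %s', get_git_revision_short_hash())
logging.info('training started')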

View File

@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import matplotlib
matplotlib.use('Agg')
try:
import matplotlib.pyplot as plt
except ImportError:
raise ImportError('Please install matplotlib.')
plt.set_loglevel('info')
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(
spectrogram, aspect='auto', origin='lower', interpolation='none')
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots()
im = ax.imshow(
alignment, aspect='auto', origin='lower', interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Input timestep'
if info is not None:
xlabel += '\t' + info
plt.xlabel(xlabel)
plt.ylabel('Output timestep')
fig.canvas.draw()
plt.close()
return fig
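A minimal sketch for the plotting helpers above; the spectrogram is random data and the output path is a placeholder.

import numpy as np

mel = np.random.rand(80, 200)    # (#mels, #frames) placeholder
fig = plot_spectrogram(mel)
fig.savefig('/tmp/mel_debug.png')    # the Figure stays usable after plt.close()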

View File

@@ -1,238 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import json
import numpy as np
import torch
from torch.utils.data import Dataset
from tqdm import tqdm
from modelscope.utils.logger import get_logger
from .units import KanTtsLinguisticUnit
logger = get_logger()
class KanTtsText2MelDataset(Dataset):
def __init__(self, metadata_filename, config_filename, cache=False):
super(KanTtsText2MelDataset, self).__init__()
self.cache = cache
with open(config_filename, encoding='utf-8') as f:
self._config = json.loads(f.read())
# Load metadata:
self._datadir = os.path.dirname(metadata_filename)
with open(metadata_filename, encoding='utf-8') as f:
self._metadata = [line.strip().split('|') for line in f]
self._length_lst = [int(x[2]) for x in self._metadata]
hours = sum(
self._length_lst) * self._config['audio']['frame_shift_ms'] / (
3600 * 1000)
logger.info('Loaded metadata for %d examples (%.2f hours)' %
(len(self._metadata), hours))
logger.info('Minimum length: %d, Maximum length: %d' %
(min(self._length_lst), max(self._length_lst)))
self.ling_unit = KanTtsLinguisticUnit(config_filename)
self.pad_executor = KanTtsText2MelPad()
self.r = self._config['am']['outputs_per_step']
self.num_mels = self._config['am']['num_mels']
if 'adv' in self._config:
self.feat_window = self._config['adv']['random_window']
else:
self.feat_window = None
logger.info(self.feat_window)
self.data_cache = [
self.cache_load(i) for i in tqdm(range(self.__len__()))
] if self.cache else []
def get_frames_lst(self):
return self._length_lst
def __getitem__(self, index):
if self.cache:
sample = self.data_cache[index]
return sample
return self.cache_load(index)
def cache_load(self, index):
sample = {}
meta = self._metadata[index]
sample['utt_id'] = meta[0]
sample['mel_target'] = np.load(os.path.join(
self._datadir, meta[1]))[:, :self.num_mels]
sample['output_length'] = len(sample['mel_target'])
lfeat_symbol = meta[3]
sample['ling'] = self.ling_unit.encode_symbol_sequence(lfeat_symbol)
sample['duration'] = np.load(os.path.join(self._datadir, meta[4]))
sample['pitch_contour'] = np.load(os.path.join(self._datadir, meta[5]))
sample['energy_contour'] = np.load(
os.path.join(self._datadir, meta[6]))
return sample
def __len__(self):
return len(self._metadata)
def collate_fn(self, batch):
data_dict = {}
max_input_length = max((len(x['ling'][0]) for x in batch))
# pure linguistic info: sy|tone|syllable_flag|word_segment
# sy
lfeat_type = self.ling_unit._lfeat_type_list[0]
inputs_sy = self.pad_executor._prepare_scalar_inputs(
[x['ling'][0] for x in batch], max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type]).long()
# tone
lfeat_type = self.ling_unit._lfeat_type_list[1]
inputs_tone = self.pad_executor._prepare_scalar_inputs(
[x['ling'][1] for x in batch], max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type]).long()
# syllable_flag
lfeat_type = self.ling_unit._lfeat_type_list[2]
inputs_syllable_flag = self.pad_executor._prepare_scalar_inputs(
[x['ling'][2] for x in batch], max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type]).long()
# word_segment
lfeat_type = self.ling_unit._lfeat_type_list[3]
inputs_ws = self.pad_executor._prepare_scalar_inputs(
[x['ling'][3] for x in batch], max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type]).long()
# emotion category
lfeat_type = self.ling_unit._lfeat_type_list[4]
data_dict['input_emotions'] = self.pad_executor._prepare_scalar_inputs(
[x['ling'][4] for x in batch], max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type]).long()
# speaker category
lfeat_type = self.ling_unit._lfeat_type_list[5]
data_dict['input_speakers'] = self.pad_executor._prepare_scalar_inputs(
[x['ling'][5] for x in batch], max_input_length,
self.ling_unit._sub_unit_pad[lfeat_type]).long()
data_dict['input_lings'] = torch.stack(
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
data_dict['valid_input_lengths'] = torch.as_tensor(
[len(x['ling'][0]) - 1 for x in batch], dtype=torch.long
) # The symbol sequence ends with a single '~' (EOS), so we use length - 1 here.
data_dict['valid_output_lengths'] = torch.as_tensor(
[x['output_length'] for x in batch], dtype=torch.long)
max_output_length = torch.max(data_dict['valid_output_lengths']).item()
max_output_round_length = self.pad_executor._round_up(
max_output_length, self.r)
if self.feat_window is not None:
active_feat_len = np.minimum(max_output_round_length,
self.feat_window)
if active_feat_len < self.feat_window:
max_output_round_length = self.pad_executor._round_up(
self.feat_window, self.r)
active_feat_len = self.feat_window
max_offsets = [x['output_length'] - active_feat_len for x in batch]
feat_offsets = [
np.random.randint(0, np.maximum(1, offset))
for offset in max_offsets
]
feat_offsets = torch.from_numpy(
np.asarray(feat_offsets, dtype=np.int32)).long()
data_dict['feat_offsets'] = feat_offsets
data_dict['mel_targets'] = self.pad_executor._prepare_targets(
[x['mel_target'] for x in batch], max_output_round_length, 0.0)
data_dict['durations'] = self.pad_executor._prepare_durations(
[x['duration'] for x in batch], max_input_length,
max_output_round_length)
data_dict['pitch_contours'] = self.pad_executor._prepare_scalar_inputs(
[x['pitch_contour'] for x in batch], max_input_length,
0.0).float()
data_dict[
'energy_contours'] = self.pad_executor._prepare_scalar_inputs(
[x['energy_contour'] for x in batch], max_input_length,
0.0).float()
data_dict['utt_ids'] = [x['utt_id'] for x in batch]
return data_dict
class KanTtsText2MelPad(object):
def __init__(self):
super(KanTtsText2MelPad, self).__init__()
pass
def _pad1D(self, x, length, pad):
return np.pad(
x, (0, length - x.shape[0]), mode='constant', constant_values=pad)
def _pad2D(self, x, length, pad):
return np.pad(
x, [(0, length - x.shape[0]), (0, 0)],
mode='constant',
constant_values=pad)
def _pad_durations(self, duration, max_in_len, max_out_len):
framenum = np.sum(duration)
symbolnum = duration.shape[0]
if framenum < max_out_len:
padframenum = max_out_len - framenum
duration = np.insert(
duration, symbolnum, values=padframenum, axis=0)
duration = np.insert(
duration,
symbolnum + 1,
values=[0] * (max_in_len - symbolnum - 1),
axis=0)
else:
if symbolnum < max_in_len:
duration = np.insert(
duration,
symbolnum,
values=[0] * (max_in_len - symbolnum),
axis=0)
return duration
def _round_up(self, x, multiple):
remainder = x % multiple
return x if remainder == 0 else x + multiple - remainder
def _prepare_scalar_inputs(self, inputs, max_len, pad):
return torch.from_numpy(
np.stack([self._pad1D(x, max_len, pad) for x in inputs]))
def _prepare_targets(self, targets, max_len, pad):
return torch.from_numpy(
np.stack([self._pad2D(t, max_len, pad) for t in targets])).float()
def _prepare_durations(self, durations, max_in_len, max_out_len):
return torch.from_numpy(
np.stack([
self._pad_durations(t, max_in_len, max_out_len)
for t in durations
])).long()

View File

@@ -1,131 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import random
import torch
from torch import distributed as dist
from torch.utils.data import Sampler
class LenSortGroupPoolSampler(Sampler):
def __init__(self, data_source, length_lst, group_size):
super(LenSortGroupPoolSampler, self).__init__(data_source)
self.data_source = data_source
self.length_lst = length_lst
self.group_size = group_size
self.num = len(self.length_lst)
self.buckets = self.num // group_size
def __iter__(self):
def getkey(item):
return item[1]
random_lst = torch.randperm(self.num).tolist()
random_len_lst = [(i, self.length_lst[i]) for i in random_lst]
# Bucket examples based on similar output sequence length for efficiency:
groups = [
random_len_lst[i:i + self.group_size]
for i in range(0, self.num, self.group_size)
]
if (self.num % self.group_size):
groups.append(random_len_lst[self.buckets * self.group_size:-1])
indices = []
for group in groups:
group.sort(key=getkey, reverse=True)
for item in group:
indices.append(item[0])
return iter(indices)
def __len__(self):
return len(self.data_source)
class DistributedLenSortGroupPoolSampler(Sampler):
def __init__(self,
dataset,
length_lst,
group_size,
num_replicas=None,
rank=None,
shuffle=True):
super(DistributedLenSortGroupPoolSampler, self).__init__(dataset)
if num_replicas is None:
if not dist.is_available():
raise RuntimeError(
'modelscope error: Requires distributed package to be available'
)
num_replicas = dist.get_world_size()
if rank is None:
if not dist.is_available():
raise RuntimeError(
'modelscope error: Requires distributed package to be available'
)
rank = dist.get_rank()
self.dataset = dataset
self.length_lst = length_lst
self.group_size = group_size
self.num_replicas = num_replicas
self.rank = rank
self.epoch = 0
self.num_samples = int(
math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
self.total_size = self.num_samples * self.num_replicas
self.buckets = self.num_samples // group_size
self.shuffle = shuffle
def __iter__(self):
def getkey(item):
return item[1]
# deterministically shuffle based on epoch
g = torch.Generator()
g.manual_seed(self.epoch)
if self.shuffle:
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
indices = list(range(len(self.dataset)))
# add extra samples to make it evenly divisible
indices += indices[:(self.total_size - len(indices))]
assert len(indices) == self.total_size
# subsample
indices = indices[self.rank:self.total_size:self.num_replicas]
assert len(indices) == self.num_samples
random_len_lst = [(i, self.length_lst[i]) for i in indices]
# Bucket examples based on similar output sequence length for efficiency:
groups = [
random_len_lst[i:i + self.group_size]
for i in range(0, self.num_samples, self.group_size)
]
if (self.num_samples % self.group_size):
groups.append(random_len_lst[self.buckets * self.group_size:-1])
new_indices = []
for group in groups:
group.sort(key=getkey, reverse=True)
for item in group:
new_indices.append(item[0])
return iter(new_indices)
def __len__(self):
return self.num_samples
def set_epoch(self, epoch):
self.epoch = epoch

View File

@@ -1,3 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .ling_unit import * # noqa F403

View File

@@ -1,3 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .hifigan import * # noqa F403

View File

@@ -1,238 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from https://github.com/jik876/hifi-gan
from distutils.version import LooseVersion
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
from modelscope.models.audio.tts.models.utils import get_padding, init_weights
from modelscope.utils.logger import get_logger
logger = get_logger()
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
if is_pytorch_17plus:
x_stft = torch.stft(
x, fft_size, hop_size, win_length, window, return_complex=False)
else:
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
real = x_stft[..., 0]
imag = x_stft[..., 1]
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
LRELU_SLOPE = 0.1
def get_padding_casual(kernel_size, dilation=1):
return int(kernel_size * dilation - dilation)
class Conv1dCasual(torch.nn.Module):
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride=1,
padding=0,
dilation=1,
groups=1,
bias=True,
padding_mode='zeros'):
super(Conv1dCasual, self).__init__()
self.pad = padding
self.conv1d = weight_norm(
Conv1d(
in_channels,
out_channels,
kernel_size,
stride,
padding=0,
dilation=dilation,
groups=groups,
bias=bias,
padding_mode=padding_mode))
self.conv1d.apply(init_weights)
def forward(self, x): # bdt
# F.pad pads dimensions starting from the last one; here self.pad zeros are added on the left of the time axis for causal convolution.
x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant')
x = self.conv1d(x)
return x
def remove_weight_norm(self):
remove_weight_norm(self.conv1d)
class ConvTranspose1dCausal(torch.nn.Module):
"""CausalConvTranspose1d module with customized initialization."""
def __init__(self,
in_channels,
out_channels,
kernel_size,
stride,
padding=0):
"""Initialize CausalConvTranspose1d module."""
super(ConvTranspose1dCausal, self).__init__()
self.deconv = weight_norm(
ConvTranspose1d(in_channels, out_channels, kernel_size, stride))
self.stride = stride
self.deconv.apply(init_weights)
self.pad = kernel_size - stride
def forward(self, x):
"""Calculate forward propagation.
Args:
x (Tensor): Input tensor (B, in_channels, T_in).
Returns:
Tensor: Output tensor (B, out_channels, T_out).
"""
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
return self.deconv(x)[:, :, :-self.pad]
def remove_weight_norm(self):
remove_weight_norm(self.deconv)
class ResBlock1(torch.nn.Module):
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
self.h = h
self.convs1 = nn.ModuleList([
Conv1dCasual(
channels,
channels,
kernel_size,
1,
dilation=dilation[i],
padding=get_padding_casual(kernel_size, dilation[i]))
for i in range(len(dilation))
])
self.convs2 = nn.ModuleList([
Conv1dCasual(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding_casual(kernel_size, 1))
for i in range(len(dilation))
])
def forward(self, x):
for c1, c2 in zip(self.convs1, self.convs2):
xt = F.leaky_relu(x, LRELU_SLOPE)
xt = c1(xt)
xt = F.leaky_relu(xt, LRELU_SLOPE)
xt = c2(xt)
x = xt + x
return x
def remove_weight_norm(self):
for layer in self.convs1:
layer.remove_weight_norm()
for layer in self.convs2:
layer.remove_weight_norm()
class Generator(torch.nn.Module):
def __init__(self, h):
super(Generator, self).__init__()
self.h = h
self.num_kernels = len(h.resblock_kernel_sizes)
self.num_upsamples = len(h.upsample_rates)
logger.info('num_kernels={}, num_upsamples={}'.format(
self.num_kernels, self.num_upsamples))
self.conv_pre = Conv1dCasual(
80, h.upsample_initial_channel, 7, 1, padding=7 - 1)
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
self.ups = nn.ModuleList()
self.repeat_ups = nn.ModuleList()
for i, (u, k) in enumerate(
zip(h.upsample_rates, h.upsample_kernel_sizes)):
upsample = nn.Sequential(
nn.Upsample(mode='nearest', scale_factor=u),
nn.LeakyReLU(LRELU_SLOPE),
Conv1dCasual(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
kernel_size=7,
stride=1,
padding=7 - 1))
self.repeat_ups.append(upsample)
self.ups.append(
ConvTranspose1dCausal(
h.upsample_initial_channel // (2**i),
h.upsample_initial_channel // (2**(i + 1)),
k,
u,
padding=(k - u) // 2))
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2**(i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock(h, ch, k, d))
self.conv_post = Conv1dCasual(ch, 1, 7, 1, padding=7 - 1)
def forward(self, x):
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = torch.sin(x) + x
# transconv
x1 = F.leaky_relu(x, LRELU_SLOPE)
x1 = self.ups[i](x1)
# repeat
x2 = self.repeat_ups[i](x)
x = x1 + x2
xs = None
for j in range(self.num_kernels):
if xs is None:
xs = self.resblocks[i * self.num_kernels + j](x)
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.conv_post(x)
x = torch.tanh(x)
return x
def remove_weight_norm(self):
logger.info('Removing weight norm...')
for layer in self.ups:
layer.remove_weight_norm()
for layer in self.repeat_ups:
layer[-1].remove_weight_norm()
for layer in self.resblocks:
layer.remove_weight_norm()
self.conv_pre.remove_weight_norm()
self.conv_post.remove_weight_norm()

View File

@@ -1,3 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .kantts_sambert import * # noqa F403

View File

@@ -1,718 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from modelscope.models.audio.tts.models.utils import get_mask_from_lengths
from .adaptors import (LengthRegulator, VarFsmnRnnNARPredictor,
VarRnnARPredictor)
from .base import FFTBlock, PNCABlock, Prenet
from .fsmn import FsmnEncoderV2
from .positions import DurSinusoidalPositionEncoder, SinusoidalPositionEncoder
class SelfAttentionEncoder(nn.Module):
def __init__(self, n_layer, d_in, d_model, n_head, d_head, d_inner,
dropout, dropout_att, dropout_relu, position_encoder):
super(SelfAttentionEncoder, self).__init__()
self.d_in = d_in
self.d_model = d_model
self.dropout = dropout
d_in_lst = [d_in] + [d_model] * (n_layer - 1)
self.fft = nn.ModuleList([
FFTBlock(d, d_model, n_head, d_head, d_inner, (3, 1), dropout,
dropout_att, dropout_relu) for d in d_in_lst
])
self.ln = nn.LayerNorm(d_model, eps=1e-6)
self.position_enc = position_encoder
def forward(self, input, mask=None, return_attns=False):
input *= self.d_model**0.5
if (isinstance(self.position_enc, SinusoidalPositionEncoder)):
input = self.position_enc(input)
else:
raise NotImplementedError('modelscope error: position_enc invalid')
input = F.dropout(input, p=self.dropout, training=self.training)
enc_slf_attn_list = []
max_len = input.size(1)
if mask is not None:
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
else:
slf_attn_mask = None
enc_output = input
for id, layer in enumerate(self.fft):
enc_output, enc_slf_attn = layer(
enc_output, mask=mask, slf_attn_mask=slf_attn_mask)
if return_attns:
enc_slf_attn_list += [enc_slf_attn]
enc_output = self.ln(enc_output)
return enc_output, enc_slf_attn_list
class HybridAttentionDecoder(nn.Module):
def __init__(self, d_in, prenet_units, n_layer, d_model, d_mem, n_head,
d_head, d_inner, dropout, dropout_att, dropout_relu, d_out):
super(HybridAttentionDecoder, self).__init__()
self.d_model = d_model
self.dropout = dropout
self.prenet = Prenet(d_in, prenet_units, d_model)
self.dec_in_proj = nn.Linear(d_model + d_mem, d_model)
self.pnca = nn.ModuleList([
PNCABlock(d_model, d_mem, n_head, d_head, d_inner, (1, 1), dropout,
dropout_att, dropout_relu) for _ in range(n_layer)
])
self.ln = nn.LayerNorm(d_model, eps=1e-6)
self.dec_out_proj = nn.Linear(d_model, d_out)
def reset_state(self):
for layer in self.pnca:
layer.reset_state()
def get_pnca_attn_mask(self,
device,
max_len,
x_band_width,
h_band_width,
mask=None):
if mask is not None:
pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
else:
pnca_attn_mask = None
range_ = torch.arange(max_len).to(device)
x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :]
x_end = (range_ + 1)[None, None, :]
h_start = range_[None, None, :]
h_end = torch.clamp_max(range_ + h_band_width + 1,
max_len + 1)[None, None, :]
pnca_x_attn_mask = ~((x_start <= range_[None, :, None])
& (x_end > range_[None, :, None])).transpose(1, 2) # yapf:disable
pnca_h_attn_mask = ~((h_start <= range_[None, :, None])
& (h_end > range_[None, :, None])).transpose(1, 2) # yapf:disable
if pnca_attn_mask is not None:
pnca_x_attn_mask = (pnca_x_attn_mask | pnca_attn_mask)
pnca_h_attn_mask = (pnca_h_attn_mask | pnca_attn_mask)
pnca_x_attn_mask = pnca_x_attn_mask.masked_fill(
pnca_attn_mask.transpose(1, 2), False)
pnca_h_attn_mask = pnca_h_attn_mask.masked_fill(
pnca_attn_mask.transpose(1, 2), False)
return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask
# must call reset_state before
def forward(self,
input,
memory,
x_band_width,
h_band_width,
mask=None,
return_attns=False):
input = self.prenet(input)
input = torch.cat([memory, input], dim=-1)
input = self.dec_in_proj(input)
if mask is not None:
input = input.masked_fill(mask.unsqueeze(-1), 0)
input *= self.d_model**0.5
input = F.dropout(input, p=self.dropout, training=self.training)
max_len = input.size(1)
pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
input.device, max_len, x_band_width, h_band_width, mask)
dec_pnca_attn_x_list = []
dec_pnca_attn_h_list = []
dec_output = input
for id, layer in enumerate(self.pnca):
dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
dec_output,
memory,
mask=mask,
pnca_x_attn_mask=pnca_x_attn_mask,
pnca_h_attn_mask=pnca_h_attn_mask)
if return_attns:
dec_pnca_attn_x_list += [dec_pnca_attn_x]
dec_pnca_attn_h_list += [dec_pnca_attn_h]
dec_output = self.ln(dec_output)
dec_output = self.dec_out_proj(dec_output)
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
# must call reset_state before when step == 0
def infer(self,
step,
input,
memory,
x_band_width,
h_band_width,
mask=None,
return_attns=False):
max_len = memory.size(1)
input = self.prenet(input)
input = torch.cat([memory[:, step:step + 1, :], input], dim=-1)
input = self.dec_in_proj(input)
input *= self.d_model**0.5
input = F.dropout(input, p=self.dropout, training=self.training)
pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
input.device, max_len, x_band_width, h_band_width, mask)
dec_pnca_attn_x_list = []
dec_pnca_attn_h_list = []
dec_output = input
for id, layer in enumerate(self.pnca):
if mask is not None:
mask_step = mask[:, step:step + 1]
else:
mask_step = None
dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
dec_output,
memory,
mask=mask_step,
pnca_x_attn_mask=pnca_x_attn_mask[:,
step:step + 1, :(step + 1)],
pnca_h_attn_mask=pnca_h_attn_mask[:, step:step + 1, :])
if return_attns:
dec_pnca_attn_x_list += [dec_pnca_attn_x]
dec_pnca_attn_h_list += [dec_pnca_attn_h]
dec_output = self.ln(dec_output)
dec_output = self.dec_out_proj(dec_output)
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
class TextFftEncoder(nn.Module):
def __init__(self, config, ling_unit_size):
super(TextFftEncoder, self).__init__()
# linguistic unit lookup table
nb_ling_sy = ling_unit_size['sy']
nb_ling_tone = ling_unit_size['tone']
nb_ling_syllable_flag = ling_unit_size['syllable_flag']
nb_ling_ws = ling_unit_size['word_segment']
max_len = config['am']['max_len']
d_emb = config['am']['embedding_dim']
nb_layers = config['am']['encoder_num_layers']
nb_heads = config['am']['encoder_num_heads']
d_model = config['am']['encoder_num_units']
d_head = d_model // nb_heads
d_inner = config['am']['encoder_ffn_inner_dim']
dropout = config['am']['encoder_dropout']
dropout_attn = config['am']['encoder_attention_dropout']
dropout_relu = config['am']['encoder_relu_dropout']
d_proj = config['am']['encoder_projection_units']
self.d_model = d_model
self.sy_emb = nn.Embedding(nb_ling_sy, d_emb)
self.tone_emb = nn.Embedding(nb_ling_tone, d_emb)
self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb)
self.ws_emb = nn.Embedding(nb_ling_ws, d_emb)
position_enc = SinusoidalPositionEncoder(max_len, d_emb)
self.ling_enc = SelfAttentionEncoder(nb_layers, d_emb, d_model,
nb_heads, d_head, d_inner,
dropout, dropout_attn,
dropout_relu, position_enc)
self.ling_proj = nn.Linear(d_model, d_proj, bias=False)
def forward(self, inputs_ling, masks=None, return_attns=False):
# Parse inputs_ling_seq
inputs_sy = inputs_ling[:, :, 0]
inputs_tone = inputs_ling[:, :, 1]
inputs_syllable_flag = inputs_ling[:, :, 2]
inputs_ws = inputs_ling[:, :, 3]
# Lookup table
sy_embedding = self.sy_emb(inputs_sy)
tone_embedding = self.tone_emb(inputs_tone)
syllable_flag_embedding = self.syllable_flag_emb(inputs_syllable_flag)
ws_embedding = self.ws_emb(inputs_ws)
ling_embedding = sy_embedding + tone_embedding + syllable_flag_embedding + ws_embedding
enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks,
return_attns)
enc_output = self.ling_proj(enc_output)
return enc_output, enc_slf_attn_list
class VarianceAdaptor(nn.Module):
def __init__(self, config):
super(VarianceAdaptor, self).__init__()
input_dim = config['am']['encoder_projection_units'] + config['am'][
'emotion_units'] + config['am']['speaker_units']
filter_size = config['am']['predictor_filter_size']
fsmn_num_layers = config['am']['predictor_fsmn_num_layers']
num_memory_units = config['am']['predictor_num_memory_units']
ffn_inner_dim = config['am']['predictor_ffn_inner_dim']
dropout = config['am']['predictor_dropout']
shift = config['am']['predictor_shift']
lstm_units = config['am']['predictor_lstm_units']
dur_pred_prenet_units = config['am']['dur_pred_prenet_units']
dur_pred_lstm_units = config['am']['dur_pred_lstm_units']
self.pitch_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim, dropout,
shift, lstm_units)
self.energy_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size,
fsmn_num_layers,
num_memory_units,
ffn_inner_dim, dropout,
shift, lstm_units)
self.duration_predictor = VarRnnARPredictor(input_dim,
dur_pred_prenet_units,
dur_pred_lstm_units)
self.length_regulator = LengthRegulator(
config['am']['outputs_per_step'])
self.dur_position_encoder = DurSinusoidalPositionEncoder(
config['am']['encoder_projection_units'],
config['am']['outputs_per_step'])
self.pitch_emb = nn.Conv1d(
1,
config['am']['encoder_projection_units'],
kernel_size=9,
padding=4)
self.energy_emb = nn.Conv1d(
1,
config['am']['encoder_projection_units'],
kernel_size=9,
padding=4)
def forward(self,
inputs_text_embedding,
inputs_emo_embedding,
inputs_spk_embedding,
masks=None,
output_masks=None,
duration_targets=None,
pitch_targets=None,
energy_targets=None):
batch_size = inputs_text_embedding.size(0)
variance_predictor_inputs = torch.cat([
inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding
], dim=-1) # yapf:disable
pitch_predictions = self.pitch_predictor(variance_predictor_inputs,
masks)
energy_predictions = self.energy_predictor(variance_predictor_inputs,
masks)
if pitch_targets is not None:
pitch_embeddings = self.pitch_emb(
pitch_targets.unsqueeze(1)).transpose(1, 2)
else:
pitch_embeddings = self.pitch_emb(
pitch_predictions.unsqueeze(1)).transpose(1, 2)
if energy_targets is not None:
energy_embeddings = self.energy_emb(
energy_targets.unsqueeze(1)).transpose(1, 2)
else:
energy_embeddings = self.energy_emb(
energy_predictions.unsqueeze(1)).transpose(1, 2)
inputs_text_embedding_aug = inputs_text_embedding + pitch_embeddings + energy_embeddings
duration_predictor_cond = torch.cat([
inputs_text_embedding_aug, inputs_spk_embedding,
inputs_emo_embedding
], dim=-1) # yapf:disable
if duration_targets is not None:
duration_predictor_go_frame = torch.zeros(batch_size, 1).to(
inputs_text_embedding.device)
duration_predictor_input = torch.cat([
duration_predictor_go_frame, duration_targets[:, :-1].float()
], dim=-1) # yapf:disable
duration_predictor_input = torch.log(duration_predictor_input + 1)
log_duration_predictions, _ = self.duration_predictor(
duration_predictor_input.unsqueeze(-1),
duration_predictor_cond,
masks=masks)
duration_predictions = torch.exp(log_duration_predictions) - 1
else:
log_duration_predictions = self.duration_predictor.infer(
duration_predictor_cond, masks=masks)
duration_predictions = torch.exp(log_duration_predictions) - 1
if duration_targets is not None:
LR_text_outputs, LR_length_rounded = self.length_regulator(
inputs_text_embedding_aug,
duration_targets,
masks=output_masks)
LR_position_embeddings = self.dur_position_encoder(
duration_targets, masks=output_masks)
LR_emo_outputs, _ = self.length_regulator(
inputs_emo_embedding, duration_targets, masks=output_masks)
LR_spk_outputs, _ = self.length_regulator(
inputs_spk_embedding, duration_targets, masks=output_masks)
else:
LR_text_outputs, LR_length_rounded = self.length_regulator(
inputs_text_embedding_aug,
duration_predictions,
masks=output_masks)
LR_position_embeddings = self.dur_position_encoder(
duration_predictions, masks=output_masks)
LR_emo_outputs, _ = self.length_regulator(
inputs_emo_embedding, duration_predictions, masks=output_masks)
LR_spk_outputs, _ = self.length_regulator(
inputs_spk_embedding, duration_predictions, masks=output_masks)
LR_text_outputs = LR_text_outputs + LR_position_embeddings
return (LR_text_outputs, LR_emo_outputs, LR_spk_outputs,
LR_length_rounded, log_duration_predictions, pitch_predictions,
energy_predictions)
class MelPNCADecoder(nn.Module):
def __init__(self, config):
super(MelPNCADecoder, self).__init__()
prenet_units = config['am']['decoder_prenet_units']
nb_layers = config['am']['decoder_num_layers']
nb_heads = config['am']['decoder_num_heads']
d_model = config['am']['decoder_num_units']
d_head = d_model // nb_heads
d_inner = config['am']['decoder_ffn_inner_dim']
dropout = config['am']['decoder_dropout']
dropout_attn = config['am']['decoder_attention_dropout']
dropout_relu = config['am']['decoder_relu_dropout']
outputs_per_step = config['am']['outputs_per_step']
d_mem = config['am'][
'encoder_projection_units'] * outputs_per_step + config['am'][
'emotion_units'] + config['am']['speaker_units']
d_mel = config['am']['num_mels']
self.d_mel = d_mel
self.r = outputs_per_step
self.nb_layers = nb_layers
self.mel_dec = HybridAttentionDecoder(d_mel, prenet_units, nb_layers,
d_model, d_mem, nb_heads, d_head,
d_inner, dropout, dropout_attn,
dropout_relu,
d_mel * outputs_per_step)
def forward(self,
memory,
x_band_width,
h_band_width,
target=None,
mask=None,
return_attns=False):
batch_size = memory.size(0)
go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device)
if target is not None:
self.mel_dec.reset_state()
input = target[:, self.r - 1::self.r, :]
input = torch.cat([go_frame, input], dim=1)[:, :-1, :]
dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec(
input,
memory,
x_band_width,
h_band_width,
mask=mask,
return_attns=return_attns)
else:
dec_output = []
dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)]
dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)]
self.mel_dec.reset_state()
input = go_frame
for step in range(memory.size(1)):
dec_output_step, dec_pnca_attn_x_step, dec_pnca_attn_h_step = self.mel_dec.infer(
step,
input,
memory,
x_band_width,
h_band_width,
mask=mask,
return_attns=return_attns)
input = dec_output_step[:, :, -self.d_mel:]
dec_output.append(dec_output_step)
for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)):
left = memory.size(1) - pnca_x_attn.size(-1)
if (left > 0):
padding = torch.zeros(
(pnca_x_attn.size(0), 1, left)).to(pnca_x_attn)
pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1)
dec_pnca_attn_x_list[layer_id].append(pnca_x_attn)
dec_pnca_attn_h_list[layer_id].append(pnca_h_attn)
dec_output = torch.cat(dec_output, dim=1)
for layer_id in range(self.nb_layers):
dec_pnca_attn_x_list[layer_id] = torch.cat(
dec_pnca_attn_x_list[layer_id], dim=1)
dec_pnca_attn_h_list[layer_id] = torch.cat(
dec_pnca_attn_h_list[layer_id], dim=1)
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
class PostNet(nn.Module):
def __init__(self, config):
super(PostNet, self).__init__()
self.filter_size = config['am']['postnet_filter_size']
self.fsmn_num_layers = config['am']['postnet_fsmn_num_layers']
self.num_memory_units = config['am']['postnet_num_memory_units']
self.ffn_inner_dim = config['am']['postnet_ffn_inner_dim']
self.dropout = config['am']['postnet_dropout']
self.shift = config['am']['postnet_shift']
self.lstm_units = config['am']['postnet_lstm_units']
self.num_mels = config['am']['num_mels']
self.fsmn = FsmnEncoderV2(self.filter_size, self.fsmn_num_layers,
self.num_mels, self.num_memory_units,
self.ffn_inner_dim, self.dropout, self.shift)
self.lstm = nn.LSTM(
self.num_memory_units,
self.lstm_units,
num_layers=1,
batch_first=True)
self.fc = nn.Linear(self.lstm_units, self.num_mels)
def forward(self, x, mask=None):
postnet_fsmn_output = self.fsmn(x, mask)
# The input could also be a packed variable-length sequence; we omit that here for simplicity, relying on the mask and the uni-directional LSTM.
postnet_lstm_output, _ = self.lstm(postnet_fsmn_output)
mel_residual_output = self.fc(postnet_lstm_output)
return mel_residual_output
def mel_recon_loss_fn(output_lengths,
mel_targets,
dec_outputs,
postnet_outputs=None):
mae_loss = nn.L1Loss(reduction='none')
output_masks = get_mask_from_lengths(
output_lengths, max_len=mel_targets.size(1))
output_masks = ~output_masks
valid_outputs = output_masks.sum()
mel_loss_ = torch.sum(
mae_loss(mel_targets, dec_outputs) * output_masks.unsqueeze(-1)) / (
valid_outputs * mel_targets.size(-1))
if postnet_outputs is not None:
mel_loss = torch.sum(
mae_loss(mel_targets, postnet_outputs)
* output_masks.unsqueeze(-1)) / (
valid_outputs * mel_targets.size(-1))
else:
mel_loss = 0.0
return mel_loss_, mel_loss
def prosody_recon_loss_fn(input_lengths, duration_targets, pitch_targets,
energy_targets, log_duration_predictions,
pitch_predictions, energy_predictions):
mae_loss = nn.L1Loss(reduction='none')
input_masks = get_mask_from_lengths(
input_lengths, max_len=duration_targets.size(1))
input_masks = ~input_masks
valid_inputs = input_masks.sum()
dur_loss = torch.sum(
mae_loss(
torch.log(duration_targets.float() + 1), log_duration_predictions)
* input_masks) / valid_inputs
pitch_loss = torch.sum(
mae_loss(pitch_targets, pitch_predictions)
* input_masks) / valid_inputs
energy_loss = torch.sum(
mae_loss(energy_targets, energy_predictions)
* input_masks) / valid_inputs
return dur_loss, pitch_loss, energy_loss
class KanTtsSAMBERT(nn.Module):
def __init__(self, config, ling_unit_size):
super(KanTtsSAMBERT, self).__init__()
self.text_encoder = TextFftEncoder(config, ling_unit_size)
self.spk_tokenizer = nn.Embedding(ling_unit_size['speaker'],
config['am']['speaker_units'])
self.emo_tokenizer = nn.Embedding(ling_unit_size['emotion'],
config['am']['emotion_units'])
self.variance_adaptor = VarianceAdaptor(config)
self.mel_decoder = MelPNCADecoder(config)
self.mel_postnet = PostNet(config)
def get_lfr_mask_from_lengths(self, lengths, max_len):
batch_size = lengths.size(0)
# pad each length up to a multiple of outputs_per_step (r) so it matches the low-frame-rate decoder
padded_lr_lengths = torch.zeros_like(lengths)
for i in range(batch_size):
len_item = int(lengths[i].item())
padding = self.mel_decoder.r - len_item % self.mel_decoder.r
if (padding < self.mel_decoder.r):
padded_lr_lengths[i] = (len_item
+ padding) // self.mel_decoder.r
else:
padded_lr_lengths[i] = len_item // self.mel_decoder.r
return get_mask_from_lengths(
padded_lr_lengths, max_len=max_len // self.mel_decoder.r)
def forward(self,
inputs_ling,
inputs_emotion,
inputs_speaker,
input_lengths,
output_lengths=None,
mel_targets=None,
duration_targets=None,
pitch_targets=None,
energy_targets=None):
batch_size = inputs_ling.size(0)
input_masks = get_mask_from_lengths(
input_lengths, max_len=inputs_ling.size(1))
text_hid, enc_sla_attn_lst = self.text_encoder(
inputs_ling, input_masks, return_attns=True)
emo_hid = self.emo_tokenizer(inputs_emotion)
spk_hid = self.spk_tokenizer(inputs_speaker)
if output_lengths is not None:
output_masks = get_mask_from_lengths(
output_lengths, max_len=mel_targets.size(1))
else:
output_masks = None
(LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded,
log_duration_predictions, pitch_predictions,
energy_predictions) = self.variance_adaptor(
text_hid,
emo_hid,
spk_hid,
masks=input_masks,
output_masks=output_masks,
duration_targets=duration_targets,
pitch_targets=pitch_targets,
energy_targets=energy_targets)
if output_lengths is not None:
lfr_masks = self.get_lfr_mask_from_lengths(
output_lengths, max_len=LR_text_outputs.size(1))
else:
output_masks = get_mask_from_lengths(
LR_length_rounded, max_len=LR_text_outputs.size(1))
lfr_masks = None
# LFR with the factor of outputs_per_step
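# Group r consecutive frames per decoder step: text features are reshaped from
# (B, T, D) to (B, T / r, r * D), while speaker/emotion embeddings keep a single
# copy per step.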
LFR_text_inputs = LR_text_outputs.contiguous().view(
batch_size, -1, self.mel_decoder.r * text_hid.shape[-1])
LFR_emo_inputs = LR_emo_outputs.contiguous().view(
batch_size, -1,
self.mel_decoder.r * emo_hid.shape[-1])[:, :, :emo_hid.shape[-1]]
LFR_spk_inputs = LR_spk_outputs.contiguous().view(
batch_size, -1,
self.mel_decoder.r * spk_hid.shape[-1])[:, :, :spk_hid.shape[-1]]
memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs],
dim=-1)
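# Bound the PNCA attention band width by the longest token duration measured in
# decoder steps (max frame duration divided by outputs_per_step, rounded).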
if duration_targets is not None:
x_band_width = int(
duration_targets.float().masked_fill(input_masks, 0).max()
/ self.mel_decoder.r + 0.5)
h_band_width = x_band_width
else:
x_band_width = int((torch.exp(log_duration_predictions) - 1).max()
/ self.mel_decoder.r + 0.5)
h_band_width = x_band_width
dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder(
memory,
x_band_width,
h_band_width,
target=mel_targets,
mask=lfr_masks,
return_attns=True)
# De-LFR with the factor of outputs_per_step
dec_outputs = dec_outputs.contiguous().view(batch_size, -1,
self.mel_decoder.d_mel)
if output_masks is not None:
dec_outputs = dec_outputs.masked_fill(
output_masks.unsqueeze(-1), 0)
postnet_outputs = self.mel_postnet(dec_outputs,
output_masks) + dec_outputs
if output_masks is not None:
postnet_outputs = postnet_outputs.masked_fill(
output_masks.unsqueeze(-1), 0)
res = {
'x_band_width': x_band_width,
'h_band_width': h_band_width,
'enc_slf_attn_lst': enc_sla_attn_lst,
'pnca_x_attn_lst': pnca_x_attn_lst,
'pnca_h_attn_lst': pnca_h_attn_lst,
'dec_outputs': dec_outputs,
'postnet_outputs': postnet_outputs,
'LR_length_rounded': LR_length_rounded,
'log_duration_predictions': log_duration_predictions,
'pitch_predictions': pitch_predictions,
'energy_predictions': energy_predictions
}
res['LR_text_outputs'] = LR_text_outputs
res['LR_emo_outputs'] = LR_emo_outputs
res['LR_spk_outputs'] = LR_spk_outputs
return res

View File

@@ -1,3 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from .utils import * # noqa F403

View File

@@ -1,136 +0,0 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import glob
import os
import shutil
import matplotlib
import matplotlib.pylab as plt
import torch
matplotlib.use('Agg')
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
def build_env(config, config_name, path):
t_path = os.path.join(path, config_name)
if config != t_path:
os.makedirs(path, exist_ok=True)
shutil.copyfile(config, os.path.join(path, config_name))
def plot_spectrogram(spectrogram):
fig, ax = plt.subplots(figsize=(10, 2))
im = ax.imshow(
spectrogram, aspect='auto', origin='lower', interpolation='none')
plt.colorbar(im, ax=ax)
fig.canvas.draw()
plt.close()
return fig
def plot_alignment(alignment, info=None):
fig, ax = plt.subplots()
im = ax.imshow(
alignment, aspect='auto', origin='lower', interpolation='none')
fig.colorbar(im, ax=ax)
xlabel = 'Input timestep'
if info is not None:
xlabel += '\t' + info
plt.xlabel(xlabel)
plt.ylabel('Output timestep')
fig.canvas.draw()
plt.close()
return fig
def load_checkpoint(filepath, device):
assert os.path.isfile(filepath)
checkpoint_dict = torch.load(filepath, map_location=device)
return checkpoint_dict
def save_checkpoint(filepath, obj):
torch.save(obj, filepath)
def scan_checkpoint(cp_dir, prefix):
pattern = os.path.join(cp_dir, prefix + '????????.pkl')
cp_list = glob.glob(pattern)
if len(cp_list) == 0:
return None
return sorted(cp_list)[-1]
def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
class ValueWindow():
def __init__(self, window_size=100):
self._window_size = window_size
self._values = []
def append(self, x):
self._values = self._values[-(self._window_size - 1):] + [x]
@property
def sum(self):
return sum(self._values)
@property
def count(self):
return len(self._values)
@property
def average(self):
return self.sum / max(1, self.count)
def reset(self):
self._values = []
def get_model_size(model):
param_num = sum([p.numel() for p in model.parameters() if p.requires_grad])
param_size = param_num * 4 / 1024 / 1024
return param_size
def get_grad_norm(model):
total_norm = 0
params = [
p for p in model.parameters() if p.grad is not None and p.requires_grad
]
for p in params:
param_norm = p.grad.detach().data.norm(2)
total_norm += param_norm.item()**2
total_norm = total_norm**0.5
return total_norm
def init_weights(m, mean=0.0, std=0.01):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
m.weight.data.normal_(mean, std)
def get_mask_from_lengths(lengths, max_len=None):
batch_size = lengths.shape[0]
if max_len is None:
max_len = torch.max(lengths).item()
ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
-1).to(lengths.device)
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
return mask

View File

@@ -2,24 +2,31 @@
from __future__ import (absolute_import, division, print_function,
unicode_literals)
import datetime
import os
import shutil
import zipfile
import json
import numpy as np
import yaml
from modelscope.metainfo import Models
from modelscope.models.base import Model
from modelscope.models.builder import MODELS
from modelscope.utils.audio.audio_utils import TtsTrainType
from modelscope.utils.audio.tts_exceptions import (
TtsFrontendInitializeFailedException,
TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException,
TtsTrainingCfgNotExistsException, TtsVoiceNotExistsException)
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from .voice import Voice
__all__ = ['SambertHifigan']
logger = get_logger()
@MODELS.register_module(
Tasks.text_to_speech, module_name=Models.sambert_hifigan)
@@ -27,18 +34,12 @@ class SambertHifigan(Model):
def __init__(self, model_dir, *args, **kwargs):
super().__init__(model_dir, *args, **kwargs)
if 'am' not in kwargs:
raise TtsModelConfigurationException(
'modelscope error: configuration model field missing am!')
if 'vocoder' not in kwargs:
raise TtsModelConfigurationException(
'modelscope error: configuration model field missing vocoder!')
if 'lang_type' not in kwargs:
raise TtsModelConfigurationException(
'modelscope error: configuration model field missing lang_type!'
)
am_cfg = kwargs['am']
voc_cfg = kwargs['vocoder']
self.__model_dir = model_dir
self.__is_train = False
if 'is_train' in kwargs:
is_train = kwargs['is_train']
if isinstance(is_train, bool):
self.__is_train = is_train
# initialize frontend
import ttsfrd
frontend = ttsfrd.TtsFrontendEngine()
@@ -55,33 +56,155 @@ class SambertHifigan(Model):
'modelscope error: language type invalid: {}'.format(
kwargs['lang_type']))
self.__frontend = frontend
zip_file = os.path.join(model_dir, 'voices.zip')
self.__voice_path = os.path.join(model_dir, 'voices')
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
zip_ref.extractall(model_dir)
voice_cfg_path = os.path.join(self.__voice_path, 'voices.json')
with open(voice_cfg_path, 'r', encoding='utf-8') as f:
voice_cfg = json.load(f)
if 'voices' not in voice_cfg:
raise TtsModelConfigurationException(
'modelscope error: voices invalid')
self.__voice = {}
for name in voice_cfg['voices']:
voice_path = os.path.join(self.__voice_path, name)
if not os.path.exists(voice_path):
continue
self.__voice[name] = Voice(name, voice_path, am_cfg, voc_cfg)
if voice_cfg['voices']:
self.__default_voice_name = voice_cfg['voices'][0]
self.__voices, self.__voice_cfg = self.load_voice(model_dir)
if len(self.__voices) == 0 or len(self.__voice_cfg) == 0:
raise TtsVoiceNotExistsException('modelscope error: voices empty')
if self.__voice_cfg['voices']:
self.__default_voice_name = self.__voice_cfg['voices'][0]
else:
raise TtsVoiceNotExistsException(
'modelscope error: voices is empty in voices.json')
def load_voice(self, model_dir):
voices = {}
voices_path = os.path.join(model_dir, 'voices')
voices_json_path = os.path.join(voices_path, 'voices.json')
if not os.path.exists(voices_path) or not os.path.exists(
voices_json_path):
return voices, []
with open(voices_json_path, 'r', encoding='utf-8') as f:
voice_cfg = json.load(f)
if 'voices' not in voice_cfg:
return voices, []
for name in voice_cfg['voices']:
voice_path = os.path.join(voices_path, name)
if not os.path.exists(voice_path):
continue
voices[name] = Voice(name, voice_path)
return voices, voice_cfg
def save_voices(self):
voices_json_path = os.path.join(self.__model_dir, 'voices',
'voices.json')
if os.path.exists(voices_json_path):
os.remove(voices_json_path)
save_voices = {}
save_voices['voices'] = []
for k in self.__voices.keys():
save_voices['voices'].append(k)
with open(voices_json_path, 'w', encoding='utf-8') as f:
json.dump(save_voices, f)
def get_voices(self):
return self.__voices, self.__voice_cfg
def create_empty_voice(self, voice_name, audio_config, am_config_path,
voc_config_path):
voice_name_path = os.path.join(self.__model_dir, 'voices', voice_name)
if os.path.exists(voice_name_path):
shutil.rmtree(voice_name_path)
os.makedirs(voice_name_path, exist_ok=True)
if audio_config and os.path.exists(audio_config) and os.path.isfile(
audio_config):
shutil.copy(audio_config, voice_name_path)
voice_am_path = os.path.join(voice_name_path, 'am')
voice_voc_path = os.path.join(voice_name_path, 'voc')
if am_config_path and os.path.exists(
am_config_path) and os.path.isfile(am_config_path):
am_config_name = os.path.join(voice_am_path, 'config.yaml')
shutil.copy(am_config_path, am_config_name)
if voc_config_path and os.path.exists(
voc_config_path) and os.path.isfile(voc_config_path):
voc_config_name = os.path.join(voice_voc_path, 'config.yaml')
shutil.copy(voc_config_path, voc_config_name)
am_ckpt_path = os.path.join(voice_am_path, 'ckpt')
voc_ckpt_path = os.path.join(voice_voc_path, 'ckpt')
os.makedirs(am_ckpt_path, exist_ok=True)
os.makedirs(voc_ckpt_path, exist_ok=True)
self.__voices[voice_name] = Voice(
voice_name=voice_name,
voice_path=voice_name_path,
allow_empty=True)
def get_voice_audio_config_path(self, voice):
if voice not in self.__voices:
return ''
return self.__voices[voice].audio_config
def get_voice_lang_path(self, voice):
if voice not in self.__voices:
return ''
return self.__voices[voice].lang_dir
def __synthesis_one_sentences(self, voice_name, text):
if voice_name not in self.__voice:
if voice_name not in self.__voices:
raise TtsVoiceNotExistsException(
f'modelscope error: Voice {voice_name} not exists')
return self.__voice[voice_name].forward(text)
return self.__voices[voice_name].forward(text)
def train(self,
voice,
dirs,
train_type,
configs_path=None,
ignore_pretrain=False,
create_if_not_exists=False,
hparam=None):
work_dir = dirs['work_dir']
am_dir = dirs['am_tmp_dir']
voc_dir = dirs['voc_tmp_dir']
data_dir = dirs['data_dir']
if voice not in self.__voices:
if not create_if_not_exists:
raise TtsVoiceNotExistsException(
f'modelscope error: Voice {voice} does not exist')
am_config = configs_path.get('am_config', None)
voc_config = configs_path.get('voc_config', None)
if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type and not am_config:
raise TtsTrainingCfgNotExistsException(
'training new voice am with empty am_config')
if TtsTrainType.TRAIN_TYPE_VOC in train_type and not voc_config:
raise TtsTrainingCfgNotExistsException(
'training new voice voc with empty voc_config')
target_voice = self.__voices[voice]
am_config_path = target_voice.am_config
voc_config_path = target_voice.voc_config
if configs_path:
am_config = configs_path.get('am_config', None)
if am_config:
am_config_path = am_config
voc_config = configs_path.get('voc_config', None)
if voc_config:
voc_config_path = voc_config
logger.info('Start training....')
if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type:
logger.info('Start SAMBERT training...')
totaltime = datetime.datetime.now()
hparams = train_type[TtsTrainType.TRAIN_TYPE_SAMBERT]
target_voice.train_sambert(work_dir, am_dir, data_dir,
am_config_path, ignore_pretrain,
hparams)
totaltime = datetime.datetime.now() - totaltime
logger.info('SAMBERT training spent: {:.2f} hours\n'.format(
totaltime.total_seconds() / 3600.0))
else:
logger.info('skip SAMBERT training...')
if TtsTrainType.TRAIN_TYPE_VOC in train_type:
logger.info('Start HIFIGAN training...')
totaltime = datetime.datetime.now()
hparams = train_type[TtsTrainType.TRAIN_TYPE_VOC]
target_voice.train_hifigan(work_dir, voc_dir, data_dir,
voc_config_path, ignore_pretrain,
hparams)
totaltime = datetime.datetime.now() - totaltime
logger.info('HIFIGAN training spent: {:.2f} hours\n'.format(
totaltime.total_seconds() / 3600.0))
else:
logger.info('skip HIFIGAN training...')
def forward(self, text: str, voice_name: str = None):
voice = self.__default_voice_name

View File

@@ -2,75 +2,121 @@
import os
import pickle as pkl
import time
from collections import OrderedDict
from threading import Lock
import json
import numpy as np
import torch
import yaml
from torch.utils.data import DataLoader
from modelscope.utils.audio.tts_exceptions import \
TtsModelConfigurationException
from modelscope import __version__
from modelscope.utils.audio.tts_exceptions import (
TtsModelConfigurationException, TtsModelNotExistsException,
TtsTrainingInvalidModelException)
from modelscope.utils.constant import ModelFile, Tasks
from .models.datasets.units import KanTtsLinguisticUnit
from .models.models.hifigan import Generator
from .models.models.sambert import KanTtsSAMBERT
from .models.utils import (AttrDict, build_env, init_weights, load_checkpoint,
plot_spectrogram, save_checkpoint, scan_checkpoint)
from modelscope.utils.logger import get_logger
MAX_WAV_VALUE = 32768.0
from modelscope.models.audio.tts.kantts import (  # isort:skip
GAN_Trainer, Generator, KanTtsLinguisticUnit, Sambert_Trainer,
criterion_builder, get_am_datasets, get_voc_datasets, model_builder)
logger = get_logger()
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Voice:
def __init__(self, voice_name, voice_path, am_config, voc_config):
def __init__(self, voice_name, voice_path, allow_empty=False):
self.__voice_name = voice_name
self.__voice_path = voice_path
self.__am_config = AttrDict(**am_config)
self.__voc_config = AttrDict(**voc_config)
self.distributed = False
self.local_rank = 0
am_config_path = os.path.join(
os.path.join(voice_path, 'am'), 'config.yaml')
voc_config_path = os.path.join(
os.path.join(voice_path, 'voc'), 'config.yaml')
self.audio_config = os.path.join(voice_path, 'audio_config.yaml')
self.lang_dir = os.path.join(voice_path, 'dict')
self.am_config = am_config_path
self.voc_config = voc_config_path
am_ckpt = os.path.join(os.path.join(voice_path, 'am'), 'ckpt')
voc_ckpt = os.path.join(os.path.join(voice_path, 'voc'), 'ckpt')
self.__am_ckpts = self.scan_ckpt(am_ckpt)
self.__voc_ckpts = self.scan_ckpt(voc_ckpt)
if not os.path.exists(am_config_path):
raise TtsModelConfigurationException(
'modelscope error: am configuration not found')
if not os.path.exists(voc_config_path):
raise TtsModelConfigurationException(
'modelscope error: voc configuration not found')
if not allow_empty:
if len(self.__am_ckpts) == 0:
raise TtsModelNotExistsException(
'modelscope error: am model file not found')
if len(self.__voc_ckpts) == 0:
raise TtsModelNotExistsException(
'modelscope error: voc model file not found')
with open(am_config_path, 'r') as f:
self.__am_config = yaml.load(f, Loader=yaml.Loader)
with open(voc_config_path, 'r') as f:
self.__voc_config = yaml.load(f, Loader=yaml.Loader)
self.__model_loaded = False
self.__lock = Lock()
if 'am' not in self.__am_config:
raise TtsModelConfigurationException(
'modelscope error: am configuration invalid')
if 'linguistic_unit' not in self.__am_config:
raise TtsModelConfigurationException(
'modelscope error: am configuration invalid')
self.__am_lingustic_unit_config = self.__am_config['linguistic_unit']
self.__ling_unit = KanTtsLinguisticUnit(self.__am_config,
self.lang_dir)
self.__ling_unit_size = self.__ling_unit.get_unit_size()
self.__am_config['Model']['KanTtsSAMBERT']['params'].update(
self.__ling_unit_size)
if torch.cuda.is_available():
self.__device = torch.device('cuda')
else:
self.__device = torch.device('cpu')
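# Scan ckpt_path for files named checkpoint_<step>.pth and return an OrderedDict
# mapping step -> path, sorted ascending so the last entry is the newest checkpoint.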
def scan_ckpt(self, ckpt_path):
filelist = os.listdir(ckpt_path)
if len(filelist) == 0:
return {}
ckpts = {}
for filename in filelist:
# checkpoint_X.pth
if len(filename) - 15 <= 0:
continue
if filename[-4:] == '.pth' and filename[0:10] == 'checkpoint':
filename_prefix = filename.split('.')[0]
idx = int(filename_prefix.split('_')[-1])
path = os.path.join(ckpt_path, filename)
ckpts[idx] = path
od = OrderedDict(sorted(ckpts.items()))
return od
def __load_am(self):
local_am_ckpt_path = os.path.join(self.__voice_path, 'am')
self.__am_ckpt_path = os.path.join(local_am_ckpt_path,
ModelFile.TORCH_MODEL_BIN_FILE)
has_mask = True
if 'has_mask' in self.__am_lingustic_unit_config:
has_mask = self.__am_lingustic_unit_config.has_mask
self.__ling_unit = KanTtsLinguisticUnit(
self.__am_lingustic_unit_config, self.__voice_path, has_mask)
self.__am_net = KanTtsSAMBERT(self.__am_config,
self.__ling_unit.get_unit_size()).to(
self.__device)
state_dict_g = {}
try:
state_dict_g = load_checkpoint(self.__am_ckpt_path, self.__device)
except RuntimeError:
with open(self.__am_ckpt_path, 'rb') as f:
pth_var_dict = pkl.load(f)
state_dict_g['fsnet'] = {
k: torch.FloatTensor(v)
for k, v in pth_var_dict['fsnet'].items()
}
self.__am_net.load_state_dict(state_dict_g['fsnet'], strict=False)
self.__am_net.eval()
self.__am_model, _, _ = model_builder(self.__am_config, self.__device)
self.__am = self.__am_model['KanTtsSAMBERT']
state_dict = torch.load(self.__am_ckpts[next(
reversed(self.__am_ckpts))])
self.__am.load_state_dict(state_dict['model'], strict=False)
self.__am.eval()
def __load_vocoder(self):
local_voc_ckpy_path = os.path.join(self.__voice_path, 'vocoder')
self.__voc_ckpt_path = os.path.join(local_voc_ckpy_path,
ModelFile.TORCH_MODEL_BIN_FILE)
self.__generator = Generator(self.__voc_config).to(self.__device)
state_dict_g = load_checkpoint(self.__voc_ckpt_path, self.__device)
self.__generator.load_state_dict(state_dict_g['generator'])
self.__generator.eval()
self.__generator.remove_weight_norm()
self.__voc_model = Generator(
**self.__voc_config['Model']['Generator']['params'])
states = torch.load(self.__voc_ckpts[next(reversed(self.__voc_ckpts))])
self.__voc_model.load_state_dict(states['model']['generator'])
if self.__voc_config['Model']['Generator']['params'][
'out_channels'] > 1:
from .kantts.models.pqmf import PQMF
# multi-band vocoder: attach a PQMF synthesis filter to the generator
self.__voc_model.pqmf = PQMF()
self.__voc_model.remove_weight_norm()
self.__voc_model.eval().to(self.__device)
def __am_forward(self, symbol_seq):
with self.__lock:
@@ -92,43 +138,283 @@ class Voice:
self.__device).unsqueeze(0)
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to(
self.__device).unsqueeze(0)
inputs_len = torch.zeros(1).to(self.__device).long(
) + inputs_emo.size(1) - 1 # minus 1 for "~"
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
inputs_spk[:, :-1], inputs_len)
inputs_len = (torch.zeros(1).to(self.__device).long()
+ inputs_emo.size(1) - 1) # minus 1 for "~"
res = self.__am(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
inputs_spk[:, :-1], inputs_len)
postnet_outputs = res['postnet_outputs']
LR_length_rounded = res['LR_length_rounded']
valid_length = int(LR_length_rounded[0].item())
postnet_outputs = postnet_outputs[
0, :valid_length, :].cpu().numpy()
postnet_outputs = postnet_outputs[0, :valid_length, :].cpu()
return postnet_outputs
def __vocoder_forward(self, melspec):
dim0 = list(melspec.shape)[-1]
if dim0 != self.__voc_config.num_mels:
raise TtsVocoderMelspecShapeMismatchException(
'modelscope error: input melspec mismatch require {} but {}'.
format(self.__voc_config.num_mels, dim0))
with torch.no_grad():
x = melspec.T
x = torch.FloatTensor(x).to(self.__device)
if len(x.shape) == 2:
x = x.unsqueeze(0)
y_g_hat = self.__generator(x)
audio = y_g_hat.squeeze()
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype('int16')
return audio
x = melspec.to(self.__device)
x = x.transpose(1, 0).unsqueeze(0)
y = self.__voc_model(x)
if hasattr(self.__voc_model, 'pqmf'):
y = self.__voc_model.synthesis(y)
y = y.view(-1).cpu().numpy()
return y
def train_sambert(self,
work_dir,
stage_dir,
data_dir,
config_path,
ignore_pretrain=False,
hparams=dict()):
logger.info('TRAIN SAMBERT....')
if len(self.__am_ckpts) == 0:
raise TtsTrainingInvalidModelException(
'resume pretrain but model is empty')
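# Resume policy: use the latest checkpoint unless 'resume_from_steps' in hparams
# selects a specific step; 'train_steps' extends training beyond the resumed step.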
from_steps = hparams.get('resume_from_steps', -1)
if from_steps < 0:
from_latest = hparams.get('resume_from_latest', True)
else:
from_latest = hparams.get('resume_from_latest', False)
train_steps = hparams.get('train_steps', 0)
with open(self.audio_config, 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
with open(config_path, 'r') as f:
config.update(yaml.load(f, Loader=yaml.Loader))
config.update(hparams)
resume_from = None
if from_latest:
from_steps = next(reversed(self.__am_ckpts))
resume_from = self.__am_ckpts[from_steps]
if not os.path.exists(resume_from):
raise TtsTrainingInvalidModelException(
f'latest model:{resume_from} not exists')
else:
if from_steps not in self.__am_ckpts:
raise TtsTrainingInvalidModelException(
f'no such model from steps:{from_steps}')
else:
resume_from = self.__am_ckpts[from_steps]
if train_steps > 0:
train_max_steps = train_steps + from_steps
config['train_max_steps'] = train_max_steps
else:
train_max_steps = config['train_max_steps']
logger.info(f'TRAINING steps: {train_max_steps}')
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime())
config['modelscope_version'] = __version__
with open(os.path.join(stage_dir, 'config.yaml'), 'w') as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
for key, value in config.items():
logger.info(f'{key} = {value}')
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
meta_file = [
os.path.join(
d,
'raw_metafile.txt' if not fp_enable else 'fprm_metafile.txt')
for d in data_dir
]
train_dataset, valid_dataset = get_am_datasets(meta_file, data_dir,
self.lang_dir, config,
config['allow_cache'])
logger.info(f'The number of training files = {len(train_dataset)}.')
logger.info(f'The number of validation files = {len(valid_dataset)}.')
sampler = {'train': None, 'valid': None}
train_dataloader = DataLoader(
train_dataset,
shuffle=False if self.distributed else True,
collate_fn=train_dataset.collate_fn,
batch_size=config['batch_size'],
num_workers=config['num_workers'],
sampler=sampler['train'],
pin_memory=config['pin_memory'],
)
valid_dataloader = DataLoader(
valid_dataset,
shuffle=False if self.distributed else True,
collate_fn=valid_dataset.collate_fn,
batch_size=config['batch_size'],
num_workers=config['num_workers'],
sampler=sampler['valid'],
pin_memory=config['pin_memory'],
)
ling_unit_size = train_dataset.ling_unit.get_unit_size()
config['Model']['KanTtsSAMBERT']['params'].update(ling_unit_size)
model, optimizer, scheduler = model_builder(config, self.__device,
self.local_rank,
self.distributed)
criterion = criterion_builder(config, self.__device)
trainer = Sambert_Trainer(
config=config,
model=model,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=self.__device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
max_steps=train_max_steps,
save_dir=stage_dir,
save_interval=config['save_interval_steps'],
valid_interval=config['eval_interval_steps'],
log_interval=config['log_interval'],
grad_clip=config['grad_norm'],
)
if resume_from is not None:
trainer.load_checkpoint(resume_from, True, True)
logger.info(f'Successfully resumed from {resume_from}.')
try:
trainer.train()
except (Exception, KeyboardInterrupt) as e:
logger.error(e, exc_info=True)
trainer.save_checkpoint(
os.path.join(
os.path.join(stage_dir, 'ckpt'),
f'checkpoint-{trainer.steps}.pth'))
logger.info(
f'Successfully saved checkpoint @ {trainer.steps}steps.')
def train_hifigan(self,
work_dir,
stage_dir,
data_dir,
config_path,
ignore_pretrain=False,
hparams=dict()):
logger.info('TRAIN HIFIGAN....')
if len(self.__voc_ckpts) == 0:
raise TtsTrainingInvalidModelException(
'resume pretrain but model is empty')
from_steps = hparams.get('resume_from_steps', -1)
if from_steps < 0:
from_latest = hparams.get('resume_from_latest', True)
else:
from_latest = hparams.get('resume_from_latest', False)
train_steps = hparams.get('train_steps', 0)
with open(self.audio_config, 'r') as f:
config = yaml.load(f, Loader=yaml.Loader)
with open(config_path, 'r') as f:
config.update(yaml.load(f, Loader=yaml.Loader))
config.update(hparams)
resume_from = None
if from_latest:
from_steps = next(reversed(self.__voc_ckpts))
resume_from = self.__voc_ckpts[from_steps]
if not os.path.exists(resume_from):
raise TtsTrainingInvalidModelException(
f'latest model:{resume_from} not exists')
else:
if from_steps not in self.__voc_ckpts:
raise TtsTrainingInvalidModelException(
f'no such model from steps:{from_steps}')
else:
resume_from = self.__voc_ckpts[from_steps]
if train_steps > 0:
train_max_steps = train_steps
config['train_max_steps'] = train_max_steps
else:
train_max_steps = config['train_max_steps']
logger.info(f'TRAINING steps: {train_max_steps}')
logger.info(f'resume from: {resume_from}')
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
time.localtime())
config['modelscope_version'] = __version__
with open(os.path.join(stage_dir, 'config.yaml'), 'w') as f:
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
for key, value in config.items():
logger.info(f'{key} = {value}')
train_dataset, valid_dataset = get_voc_datasets(config, data_dir)
logger.info(f'The number of training files = {len(train_dataset)}.')
logger.info(f'The number of validation files = {len(valid_dataset)}.')
sampler = {'train': None, 'valid': None}
train_dataloader = DataLoader(
train_dataset,
shuffle=False if self.distributed else True,
collate_fn=train_dataset.collate_fn,
batch_size=config['batch_size'],
num_workers=config['num_workers'],
sampler=sampler['train'],
pin_memory=config['pin_memory'],
)
valid_dataloader = DataLoader(
valid_dataset,
shuffle=False if self.distributed else True,
collate_fn=valid_dataset.collate_fn,
batch_size=config['batch_size'],
num_workers=config['num_workers'],
sampler=sampler['valid'],
pin_memory=config['pin_memory'],
)
model, optimizer, scheduler = model_builder(config, self.__device,
self.local_rank,
self.distributed)
criterion = criterion_builder(config, self.__device)
trainer = GAN_Trainer(
config=config,
model=model,
optimizer=optimizer,
scheduler=scheduler,
criterion=criterion,
device=self.__device,
sampler=sampler,
train_loader=train_dataloader,
valid_loader=valid_dataloader,
max_steps=train_max_steps,
save_dir=stage_dir,
save_interval=config['save_interval_steps'],
valid_interval=config['eval_interval_steps'],
log_interval=config['log_interval_steps'],
)
if resume_from is not None:
trainer.load_checkpoint(resume_from)
logger.info(f'Successfully resumed from {resume_from}.')
try:
trainer.train()
except (Exception, KeyboardInterrupt) as e:
logger.error(e, exc_info=True)
trainer.save_checkpoint(
os.path.join(
os.path.join(stage_dir, 'ckpt'),
f'checkpoint-{trainer.steps}.pth'))
logger.info(
f'Successfully saved checkpoint @ {trainer.steps}steps.')
def forward(self, symbol_seq):
with self.__lock:
if not self.__model_loaded:
torch.manual_seed(self.__am_config['seed'])
if torch.cuda.is_available():
torch.cuda.manual_seed(self.__am_config['seed'])
self.__device = torch.device('cuda')
else:
self.__device = torch.device('cpu')
self.__load_am()
self.__load_vocoder()
self.__model_loaded = True

View File

@@ -14,6 +14,7 @@ if TYPE_CHECKING:
ImageInstanceSegmentationPreprocessor,
ImageDenoisePreprocessor)
from .kws import WavToLists
from .tts import KanttsDataPreprocessor
from .multi_modal import (OfaPreprocessor, MPlugPreprocessor)
from .nlp import (
DocumentSegmentationTransformersPreprocessor,
@@ -50,6 +51,7 @@ else:
'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor'
],
'kws': ['WavToLists'],
'tts': ['KanttsDataPreprocessor'],
'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'],
'nlp': [
'DocumentSegmentationTransformersPreprocessor',

View File

@@ -0,0 +1,60 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import Any, Dict, List, Union
from modelscope.metainfo import Preprocessors
from modelscope.models.audio.tts.kantts.preprocess.data_process import \
process_data
from modelscope.models.base import Model
from modelscope.utils.audio.tts_exceptions import (
TtsDataPreprocessorAudioConfigNotExistsException,
TtsDataPreprocessorDirNotExistsException)
from modelscope.utils.constant import Fields, Frameworks, Tasks
from .base import Preprocessor
from .builder import PREPROCESSORS
__all__ = ['KanttsDataPreprocessor']
@PREPROCESSORS.register_module(
group_key=Tasks.text_to_speech,
module_name=Preprocessors.kantts_data_preprocessor)
class KanttsDataPreprocessor(Preprocessor):
def __init__(self):
pass
def __call__(self,
data_dir,
output_dir,
lang_dir,
audio_config_path,
speaker_name='F7',
target_lang='PinYin',
skip_script=False):
self.do_data_process(data_dir, output_dir, lang_dir, audio_config_path,
speaker_name, target_lang, skip_script)
def do_data_process(self,
datadir,
outputdir,
langdir,
audio_config,
speaker_name='F7',
targetLang='PinYin',
skip_script=False):
if not os.path.exists(datadir):
raise TtsDataPreprocessorDirNotExistsException(
'Preprocessor: dataset dir not exists')
if not os.path.exists(outputdir):
raise TtsDataPreprocessorDirNotExistsException(
'Preprocessor: output dir not exists')
if not os.path.exists(audio_config):
raise TtsDataPreprocessorAudioConfigNotExistsException(
'Preprocessor: audio config not exists')
if not os.path.exists(langdir):
raise TtsDataPreprocessorDirNotExistsException(
'Preprocessor: language dir not exists')
process_data(datadir, outputdir, langdir, audio_config, speaker_name,
targetLang, skip_script)
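# Illustrative usage (a minimal sketch; the directory and config paths below are
# placeholders, not part of this commit):
#
#   from modelscope.metainfo import Preprocessors
#   from modelscope.preprocessors.builder import build_preprocessor
#   from modelscope.utils.constant import Tasks
#
#   preprocessor = build_preprocessor(
#       dict(type=Preprocessors.kantts_data_preprocessor), Tasks.text_to_speech)
#   preprocessor('/path/to/raw_dataset', '/path/to/output', '/path/to/lang_dir',
#                '/path/to/audio_config.yaml', speaker_name='F7',
#                target_lang='PinYin')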

View File

@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .audio.ans_trainer import ANSTrainer
from .audio import ANSTrainer, KanttsTrainer
from .base import DummyTrainer
from .builder import build_trainer
from .cv import (ImageInstanceSegmentationTrainer,
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
else:
_import_structure = {
'audio.ans_trainer': ['ANSTrainer'],
'audio': ['ANSTrainer', 'KanttsTrainer'],
'base': ['DummyTrainer'],
'builder': ['build_trainer'],
'cv': [

View File

@@ -0,0 +1,25 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import TYPE_CHECKING
from modelscope.utils.import_utils import LazyImportModule
if TYPE_CHECKING:
from .tts_trainer import KanttsTrainer
from .ans_trainer import ANSTrainer
else:
_import_structure = {
'tts_trainer': ['KanttsTrainer'],
'ans_trainer': ['ANSTrainer']
}
import sys
sys.modules[__name__] = LazyImportModule(
__name__,
globals()['__file__'],
_import_structure,
module_spec=__spec__,
extra_objects={},
)

View File

@@ -0,0 +1,235 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
from typing import Callable, Dict, List, Optional, Tuple, Union
import json
from modelscope.metainfo import Preprocessors, Trainers
from modelscope.models.audio.tts import SambertHifigan
from modelscope.msdatasets import MsDataset
from modelscope.preprocessors.builder import build_preprocessor
from modelscope.trainers.base import BaseTrainer
from modelscope.trainers.builder import TRAINERS
from modelscope.utils.audio.audio_utils import TtsTrainType
from modelscope.utils.audio.tts_exceptions import (
TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException,
TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException,
TtsTrainingWorkDirNotExistsException)
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
DEFAULT_DATASET_REVISION,
DEFAULT_MODEL_REVISION, ModelFile,
Tasks, TrainerStages)
from modelscope.utils.data_utils import to_device
from modelscope.utils.logger import get_logger
logger = get_logger()
@TRAINERS.register_module(module_name=Trainers.speech_kantts_trainer)
class KanttsTrainer(BaseTrainer):
DATA_DIR = 'data'
AM_TMP_DIR = 'tmp_am'
VOC_TMP_DIR = 'tmp_voc'
ORIG_MODEL_DIR = 'orig_model'
def __init__(self,
model: str,
work_dir: str = None,
speaker: str = 'F7',
lang_type: str = 'PinYin',
cfg_file: Optional[str] = None,
train_dataset: Optional[Union[MsDataset, str]] = None,
train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE,
train_dataset_revision: str = DEFAULT_DATASET_REVISION,
train_type: dict = {
TtsTrainType.TRAIN_TYPE_SAMBERT: {},
TtsTrainType.TRAIN_TYPE_VOC: {}
},
preprocess_skip_script=False,
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
**kwargs):
if not work_dir:
self.work_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.work_dir):
os.makedirs(self.work_dir)
else:
self.work_dir = work_dir
if not os.path.exists(self.work_dir):
raise TtsTrainingWorkDirNotExistsException(
f'{self.work_dir} not exists')
self.train_type = dict()
if isinstance(train_type, dict):
for k, v in train_type.items():
if (k == TtsTrainType.TRAIN_TYPE_SAMBERT
or k == TtsTrainType.TRAIN_TYPE_VOC
or k == TtsTrainType.TRAIN_TYPE_BERT):
self.train_type[k] = v
if len(self.train_type) == 0:
logger.info('train type empty, default to sambert and voc')
self.train_type[TtsTrainType.TRAIN_TYPE_SAMBERT] = {}
self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {}
logger.info(f'Set workdir to {self.work_dir}')
self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR)
self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR)
self.orig_model_dir = os.path.join(self.work_dir, self.ORIG_MODEL_DIR)
self.raw_dataset_path = ''
self.skip_script = preprocess_skip_script
self.audio_config_path = ''
self.lang_path = ''
self.am_config_path = ''
self.voc_config_path = ''
shutil.rmtree(self.data_dir, ignore_errors=True)
shutil.rmtree(self.am_tmp_dir, ignore_errors=True)
shutil.rmtree(self.voc_tmp_dir, ignore_errors=True)
shutil.rmtree(self.orig_model_dir, ignore_errors=True)
os.makedirs(self.data_dir)
os.makedirs(self.am_tmp_dir)
os.makedirs(self.voc_tmp_dir)
if train_dataset:
if isinstance(train_dataset, str):
logger.info(f'load {train_dataset_namespace}/{train_dataset}')
train_dataset = MsDataset.load(
dataset_name=train_dataset,
namespace=train_dataset_namespace,
version=train_dataset_revision)
logger.info(f'train dataset:{train_dataset.config_kwargs}')
self.raw_dataset_path = self.load_dataset_raw_path(train_dataset)
model_dir = None
if os.path.exists(model):
model_dir = model
else:
model_dir = self.get_or_download_model_dir(model, model_revision)
shutil.copytree(model_dir, self.orig_model_dir)
self.model_dir = self.orig_model_dir
if not cfg_file:
cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
self.parse_cfg(cfg_file)
if not os.path.exists(self.raw_dataset_path):
raise TtsTrainingDatasetInvalidException(
'dataset raw path not exists')
self.finetune_from_pretrain = False
self.speaker = speaker
self.lang_type = lang_type
self.model = None
self.device = kwargs.get('device', 'gpu')
self.model = self.get_model(self.model_dir, self.speaker,
self.lang_type)
if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
self.audio_data_preprocessor = build_preprocessor(
dict(type=Preprocessors.kantts_data_preprocessor),
Tasks.text_to_speech)
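# Overall flow: prepare_data() runs KanttsDataPreprocessor over the raw dataset,
# then train() delegates SAMBERT and HiFi-GAN fine-tuning to SambertHifigan.train()
# using the work/tmp/data directories created above.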
def parse_cfg(self, cfg_file):
cur_dir = os.path.dirname(cfg_file)
with open(cfg_file, 'r', encoding='utf-8') as f:
config = json.load(f)
if 'train' not in config:
raise TtsTrainingInvalidModelException(
'model not support finetune')
if 'audio_config' in config['train']:
audio_config = os.path.join(cur_dir,
config['train']['audio_config'])
if os.path.exists(audio_config):
self.audio_config_path = audio_config
if 'am_config' in config['train']:
am_config = os.path.join(cur_dir, config['train']['am_config'])
if os.path.exists(am_config):
self.am_config_path = am_config
if 'voc_config' in config['train']:
voc_config = os.path.join(cur_dir,
config['train']['voc_config'])
if os.path.exists(voc_config):
self.voc_config_path = voc_config
if 'language_path' in config['train']:
lang_path = os.path.join(cur_dir,
config['train']['language_path'])
if os.path.exists(lang_path):
self.lang_path = lang_path
if not self.raw_dataset_path:
if 'train_dataset' in config['train']:
dataset = config['train']['train_dataset']
if 'id' in dataset:
namespace = dataset.get('namespace',
DEFAULT_DATASET_NAMESPACE)
revision = dataset.get('revision',
DEFAULT_DATASET_REVISION)
ms = MsDataset.load(
dataset_name=dataset['id'],
namespace=namespace,
version=revision)
self.raw_dataset_path = self.load_dataset_raw_path(ms)
elif 'path' in dataset:
self.raw_dataset_path = dataset['path']
def load_dataset_raw_path(self, dataset: MsDataset):
if 'split_config' not in dataset.config_kwargs:
raise TtsTrainingDatasetInvalidException(
'split_config not found in config_kwargs')
if 'train' not in dataset.config_kwargs['split_config']:
raise TtsTrainingDatasetInvalidException(
'no train split in split_config')
return dataset.config_kwargs['split_config']['train']
def prepare_data(self):
if self.audio_data_preprocessor:
audio_config = self.audio_config_path
if not audio_config or not os.path.exists(audio_config):
audio_config = self.model.get_voice_audio_config_path(
self.speaker)
lang_path = self.lang_path
if not lang_path or not os.path.exists(lang_path):
lang_path = self.model.get_voice_lang_path(self.speaker)
self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir,
lang_path, audio_config, self.speaker,
self.lang_type, self.skip_script)
def prepare_text(self):
pass
def get_model(self, model_dir, speaker, lang_type):
model = SambertHifigan(
model_dir=self.model_dir, lang_type=self.lang_type, is_train=True)
return model
def train(self, *args, **kwargs):
if not self.model:
raise TtsTrainingInvalidModelException('model is none')
ignore_pretrain = False
if 'ignore_pretrain' in kwargs:
ignore_pretrain = kwargs['ignore_pretrain']
if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
self.prepare_data()
if TtsTrainType.TRAIN_TYPE_BERT in self.train_type:
self.prepare_text()
dir_dict = {
'work_dir': self.work_dir,
'am_tmp_dir': self.am_tmp_dir,
'voc_tmp_dir': self.voc_tmp_dir,
'data_dir': self.data_dir
}
config_dict = {
'am_config': self.am_config_path,
'voc_config': self.voc_config_path
}
self.model.train(self.speaker, dir_dict, self.train_type, config_dict,
ignore_pretrain)
def evaluate(self, checkpoint_path: str, *args,
**kwargs) -> Dict[str, float]:
return {}

View File

@@ -9,6 +9,12 @@ from modelscope.fileio.file import HTTPStorage
SEGMENT_LENGTH_TRAIN = 16000
class TtsTrainType(object):
TRAIN_TYPE_SAMBERT = 'train-type-sambert'
TRAIN_TYPE_BERT = 'train-type-bert'
TRAIN_TYPE_VOC = 'train-type-voc'
def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
"""
Dataset mapping function to split one audio into segments.

View File

@@ -18,6 +18,12 @@ class TtsModelConfigurationException(TtsException):
pass
class TtsModelNotExistsException(TtsException):
"""
If the TTS model does not exist, this exception will be raised.
"""
class TtsVoiceNotExistsException(TtsException):
"""
TTS voice not exists exception.
@@ -55,3 +61,58 @@ class TtsVocoderMelspecShapeMismatchException(TtsVocoderException):
"""
If vocoder's input melspec shape mismatch, this exception will be raised.
"""
class TtsDataPreprocessorException(TtsException):
"""
TTS data preprocessing exception.
"""
class TtsDataPreprocessorDirNotExistsException(TtsDataPreprocessorException):
"""
If any required directory does not exist, this exception will be raised.
"""
class TtsDataPreprocessorAudioConfigNotExistsException(
TtsDataPreprocessorException):
"""
If the audio config does not exist, this exception will be raised.
"""
class TtsTrainingException(TtsException):
"""
TTS training exception.
"""
class TtsTrainingHparamsInvalidException(TtsException):
"""
If the training hparams are invalid, this exception will be raised.
"""
class TtsTrainingWorkDirNotExistsException(TtsTrainingException):
"""
If the training work dir does not exist, this exception will be raised.
"""
class TtsTrainingCfgNotExistsException(TtsTrainingException):
"""
If the training config does not exist, this exception will be raised.
"""
class TtsTrainingDatasetInvalidException(TtsTrainingException):
"""
If the dataset is invalid, this exception will be raised.
"""
class TtsTrainingInvalidModelException(TtsTrainingException):
"""
If the model is invalid or does not exist, this exception will be raised.
"""

View File

@@ -1,22 +1,32 @@
easyasr>=0.0.2
espnet==202204
funasr>=0.1.4
greenlet>=1.1.2
h5py
inflect
jedi>=0.18.1
keras
kwsbp>=0.0.2
librosa
lxml
matplotlib
MinDAEC
msgpack>=1.0.4
nara_wpe
nltk
# tensorflow 1.15 requires numpy<=1.18
numpy<=1.18
parso>=0.8.3
pexpect>=4.8.0
pickleshare>=0.7.5
prompt-toolkit>=3.0.30
# protobuf versions beyond 3.20.0 are not compatible with TensorFlow 1.x and are therefore discouraged.
protobuf>3,<3.21.0
ptflops
ptyprocess>=0.7.0
py_sound_connect>=0.1
pygments>=2.12.0
pysptk>=0.1.15,<0.2.0
pytorch_wavelets
PyWavelets>=1.0.0
scikit-learn
@@ -24,5 +34,7 @@ SoundFile>0.10
sox
torchaudio
tqdm
traitlets>=5.3.0
ttsfrd>=0.0.3
unidecode
wcwidth>=0.2.5

View File

@@ -28,46 +28,69 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase,
self.task = Tasks.text_to_speech
self.zhcn_text = '今天北京天气怎么样'
self.en_text = 'How is the weather in Beijing?'
self.zhcn_voices = [
'zhitian_emo', 'zhizhe_emo', 'zhiyan_emo', 'zhibei_emo', 'zhcn'
]
self.zhcn_models = [
'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k',
'damo/speech_sambert-hifigan_tts_zhizhe_emo_zh-cn_16k',
'damo/speech_sambert-hifigan_tts_zhiyan_emo_zh-cn_16k',
'damo/speech_sambert-hifigan_tts_zhibei_emo_zh-cn_16k',
'damo/speech_sambert-hifigan_tts_zh-cn_16k'
]
self.en_voices = ['luca', 'luna', 'andy', 'annie', 'engb', 'enus']
self.en_models = [
'damo/speech_sambert-hifigan_tts_luca_en-gb_16k',
'damo/speech_sambert-hifigan_tts_luna_en-gb_16k',
'damo/speech_sambert-hifigan_tts_andy_en-us_16k',
'damo/speech_sambert-hifigan_tts_annie_en-us_16k',
'damo/speech_sambert-hifigan_tts_en-gb_16k',
'damo/speech_sambert-hifigan_tts_en-us_16k'
self.test_model_name = [
'pretrain_16k', 'pretrain_24k', 'zhitian_emo', 'zhizhe_emo',
'zhiyan_emo', 'zhibei_emo', 'zhcn_16k', 'luca', 'luna', 'andy',
'annie', 'engb_16k', 'enus_16k'
]
self.test_models = [{
'model':
'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k',
'text': self.zhcn_text
}, {
'model':
'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_24k',
'text': self.zhcn_text,
'sample_rate': 24000
}, {
'model': 'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k',
'text': self.zhcn_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_zhizhe_emo_zh-cn_16k',
'text': self.zhcn_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_zhiyan_emo_zh-cn_16k',
'text': self.zhcn_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_zhibei_emo_zh-cn_16k',
'text': self.zhcn_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_zh-cn_16k',
'text': self.zhcn_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_luca_en-gb_16k',
'text': self.en_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_luna_en-gb_16k',
'text': self.en_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_andy_en-us_16k',
'text': self.en_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_annie_en-us_16k',
'text': self.en_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_en-gb_16k',
'text': self.en_text
}, {
'model': 'damo/speech_sambert-hifigan_tts_en-us_16k',
'text': self.en_text
}]
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_pipeline(self):
for i in range(len(self.zhcn_voices)):
logger.info('test %s' % self.zhcn_voices[i])
for i in range(len(self.test_models)):
logger.info('test %s' % self.test_model_name[i])
sambert_hifigan_tts = pipeline(
task=self.task, model=self.zhcn_models[i])
task=self.task, model=self.test_models[i]['model'])
self.assertTrue(sambert_hifigan_tts is not None)
output = sambert_hifigan_tts(input=self.zhcn_text)
output = sambert_hifigan_tts(input=self.test_models[i]['text'])
self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
pcm = output[OutputKeys.OUTPUT_PCM]
write('output_%s.wav' % self.zhcn_voices[i], 16000, pcm)
for i in range(len(self.en_voices)):
logger.info('test %s' % self.en_voices[i])
sambert_hifigan_tts = pipeline(
task=self.task, model=self.en_models[i])
self.assertTrue(sambert_hifigan_tts is not None)
output = sambert_hifigan_tts(input=self.en_text)
self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
pcm = output[OutputKeys.OUTPUT_PCM]
write('output_%s.wav' % self.en_voices[i], 16000, pcm)
sr = 16000
if 'sample_rate' in self.test_models[i]:
sr = self.test_models[i]['sample_rate']
write('output_%s.wav' % self.test_model_name[i], sr, pcm)
@unittest.skip('demo compatibility test is only enabled on a needed-basis')
def test_demo_compatibility(self):

View File

@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import os
import shutil
import tempfile
import unittest
from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.audio.audio_utils import TtsTrainType
from modelscope.utils.constant import DownloadMode, Fields, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level
logger = get_logger()
class TestTtsTrainer(unittest.TestCase):
def setUp(self):
self.tmp_dir = tempfile.TemporaryDirectory().name
if not os.path.exists(self.tmp_dir):
os.makedirs(self.tmp_dir)
self.model_id = 'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k'
self.dataset_id = 'speech_kantts_opendata'
self.dataset_namespace = 'speech_tts'
self.train_info = {
TtsTrainType.TRAIN_TYPE_SAMBERT: {
'train_steps': 2,
'save_interval_steps': 1,
'eval_interval_steps': 1,
'log_interval': 1
},
TtsTrainType.TRAIN_TYPE_VOC: {
'train_steps': 2,
'save_interval_steps': 1,
'eval_interval_steps': 1,
'log_interval': 1
}
}
def tearDown(self):
shutil.rmtree(self.tmp_dir, ignore_errors=True)
super().tearDown()
@unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
def test_trainer(self):
kwargs = dict(
model=self.model_id,
work_dir=self.tmp_dir,
train_dataset=self.dataset_id,
train_dataset_namespace=self.dataset_namespace,
train_type=self.train_info)
trainer = build_trainer(
Trainers.speech_kantts_trainer, default_args=kwargs)
trainer.train()
tmp_am = os.path.join(self.tmp_dir, 'tmp_am', 'ckpt')
tmp_voc = os.path.join(self.tmp_dir, 'tmp_voc', 'ckpt')
assert os.path.exists(tmp_am)
assert os.path.exists(tmp_voc)
if __name__ == '__main__':
unittest.main()