Mirror of https://github.com/modelscope/modelscope.git (synced 2025-12-16 16:27:45 +01:00)
[to #42322933] support kantts infer and finetune
Link: https://code.alibaba-inc.com/Ali-MaaS/MaaS-lib/codereview/11111331#tab=detail
@@ -363,6 +363,7 @@ class Trainers(object):
    # audio trainers
    speech_frcrn_ans_cirm_16k = 'speech_frcrn_ans_cirm_16k'
    speech_dfsmn_kws_char_farfield = 'speech_dfsmn_kws_char_farfield'
    speech_kantts_trainer = 'speech-kantts-trainer'


class Preprocessors(object):
@@ -429,6 +430,7 @@ class Preprocessors(object):
    text_to_tacotron_symbols = 'text-to-tacotron-symbols'
    wav_to_lists = 'wav-to-lists'
    wav_to_scp = 'wav-to-scp'
    kantts_data_preprocessor = 'kantts-data-preprocessor'

    # multi-modal preprocessor
    ofa_tasks_preprocessor = 'ofa-tasks-preprocessor'

@@ -7,12 +7,8 @@ if TYPE_CHECKING:
    from .sambert_hifi import SambertHifigan

else:
    _import_structure = {
        'sambert_hifi': ['SambertHifigan'],
    }

    _import_structure = {'sambert_hifi': ['SambertHifigan']}
    import sys

    sys.modules[__name__] = LazyImportModule(
        __name__,
        globals()['__file__'],
8  modelscope/models/audio/tts/kantts/__init__.py  Normal file
@@ -0,0 +1,8 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .datasets.dataset import get_am_datasets, get_voc_datasets
from .models import model_builder
from .models.hifigan.hifigan import Generator
from .train.loss import criterion_builder
from .train.trainer import GAN_Trainer, Sambert_Trainer
from .utils.ling_unit.ling_unit import KanTtsLinguisticUnit
36  modelscope/models/audio/tts/kantts/datasets/data_types.py  Normal file
@@ -0,0 +1,36 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import numpy as np
from scipy.io import wavfile

DATA_TYPE_DICT = {
    'txt': {
        'load_func': np.loadtxt,
        'desc': 'plain txt file or readable by np.loadtxt',
    },
    'wav': {
        'load_func': lambda x: wavfile.read(x)[1],
        'desc': 'wav file or readable by soundfile.read',
    },
    'npy': {
        'load_func': np.load,
        'desc': 'any .npy format file',
    },
    # PCM data type can be loaded by binary format
    'bin_f32': {
        'load_func': lambda x: np.fromfile(x, dtype=np.float32),
        'desc': 'binary file with float32 format',
    },
    'bin_f64': {
        'load_func': lambda x: np.fromfile(x, dtype=np.float64),
        'desc': 'binary file with float64 format',
    },
    'bin_i32': {
        'load_func': lambda x: np.fromfile(x, dtype=np.int32),
        'desc': 'binary file with int32 format',
    },
    'bin_i16': {
        'load_func': lambda x: np.fromfile(x, dtype=np.int16),
        'desc': 'binary file with int16 format',
    },
}
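A note on how a loader table like the one above is typically consumed: each entry maps a short type tag to a loader callable, so feature files can be read generically. The helper below is only an illustrative sketch, not part of this commit; the function name load_feature and the example path are assumptions, with the tag taken from the file extension.

import os

def load_feature(path, data_type_dict=DATA_TYPE_DICT):
    # Look up the loader by file extension, e.g. '000001.npy' -> np.load.
    tag = os.path.splitext(path)[1].lstrip('.')
    if tag not in data_type_dict:
        raise ValueError('unsupported data type: {}'.format(tag))
    return data_type_dict[tag]['load_func'](path)

# mel = load_feature('training_stage/mel/000001.npy')  # hypothetical path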
989  modelscope/models/audio/tts/kantts/datasets/dataset.py  Normal file
@@ -0,0 +1,989 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import functools
import glob
import math
import os
import random
from multiprocessing import Manager

import librosa
import numpy as np
import torch
from scipy.stats import betabinom
from tqdm import tqdm

from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import (
    KanTtsLinguisticUnit, emotion_types)
from modelscope.utils.logger import get_logger

DATASET_RANDOM_SEED = 1234

logging = get_logger()


@functools.lru_cache(maxsize=256)
def beta_binomial_prior_distribution(phoneme_count, mel_count, scaling=1.0):
    P = phoneme_count
    M = mel_count
    x = np.arange(0, P)
    mel_text_probs = []
    for i in range(1, M + 1):
        a, b = scaling * i, scaling * (M + 1 - i)
        rv = betabinom(P, a, b)
        mel_i_prob = rv.pmf(x)
        mel_text_probs.append(mel_i_prob)
    return torch.tensor(np.array(mel_text_probs))

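beta_binomial_prior_distribution returns a (mel_count, phoneme_count) matrix whose i-th row is a beta-binomial distribution over phoneme positions, concentrating mass roughly along the diagonal; it is used further below as a soft text-to-mel alignment prior (attn_prior) when explicit durations are unavailable. A tiny illustrative call, with made-up sizes:

prior = beta_binomial_prior_distribution(phoneme_count=5, mel_count=20)
print(prior.shape)  # torch.Size([20, 5]); row i peaks near phoneme position i * 5 / 20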
class Padder(object):
|
||||
|
||||
def __init__(self):
|
||||
super(Padder, self).__init__()
|
||||
pass
|
||||
|
||||
def _pad1D(self, x, length, pad):
|
||||
return np.pad(
|
||||
x, (0, length - x.shape[0]), mode='constant', constant_values=pad)
|
||||
|
||||
def _pad2D(self, x, length, pad):
|
||||
return np.pad(
|
||||
x, [(0, length - x.shape[0]), (0, 0)],
|
||||
mode='constant',
|
||||
constant_values=pad)
|
||||
|
||||
def _pad_durations(self, duration, max_in_len, max_out_len):
|
||||
framenum = np.sum(duration)
|
||||
symbolnum = duration.shape[0]
|
||||
if framenum < max_out_len:
|
||||
padframenum = max_out_len - framenum
|
||||
duration = np.insert(
|
||||
duration, symbolnum, values=padframenum, axis=0)
|
||||
duration = np.insert(
|
||||
duration,
|
||||
symbolnum + 1,
|
||||
values=[0] * (max_in_len - symbolnum - 1),
|
||||
axis=0,
|
||||
)
|
||||
else:
|
||||
if symbolnum < max_in_len:
|
||||
duration = np.insert(
|
||||
duration,
|
||||
symbolnum,
|
||||
values=[0] * (max_in_len - symbolnum),
|
||||
axis=0)
|
||||
return duration
|
||||
|
||||
def _round_up(self, x, multiple):
|
||||
remainder = x % multiple
|
||||
return x if remainder == 0 else x + multiple - remainder
|
||||
|
||||
def _prepare_scalar_inputs(self, inputs, max_len, pad):
|
||||
return torch.from_numpy(
|
||||
np.stack([self._pad1D(x, max_len, pad) for x in inputs]))
|
||||
|
||||
def _prepare_targets(self, targets, max_len, pad):
|
||||
return torch.from_numpy(
|
||||
np.stack([self._pad2D(t, max_len, pad) for t in targets])).float()
|
||||
|
||||
def _prepare_durations(self, durations, max_in_len, max_out_len):
|
||||
return torch.from_numpy(
|
||||
np.stack([
|
||||
self._pad_durations(t, max_in_len, max_out_len)
|
||||
for t in durations
|
||||
])).long()
|
||||
|
||||
|
||||
class KanttsDataset(torch.utils.data.Dataset):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metafile,
|
||||
root_dir,
|
||||
):
|
||||
self.meta = []
|
||||
if not isinstance(metafile, list):
|
||||
metafile = [metafile]
|
||||
if not isinstance(root_dir, list):
|
||||
root_dir = [root_dir]
|
||||
|
||||
for meta_file, data_dir in zip(metafile, root_dir):
|
||||
if not os.path.exists(meta_file):
|
||||
logging.error('meta file not found: {}'.format(meta_file))
|
||||
raise ValueError(
|
||||
'[Dataset] meta file: {} not found'.format(meta_file))
|
||||
if not os.path.exists(data_dir):
|
||||
logging.error('data directory not found: {}'.format(data_dir))
|
||||
raise ValueError(
|
||||
'[Dataset] data dir: {} not found'.format(data_dir))
|
||||
self.meta.extend(self.load_meta(meta_file, data_dir))
|
||||
|
||||
def load_meta(self, meta_file, data_dir):
|
||||
pass
|
||||
|
||||
|
||||
class VocDataset(KanttsDataset):
|
||||
"""
|
||||
provide (mel, audio) data pair
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metafile,
|
||||
root_dir,
|
||||
config,
|
||||
):
|
||||
self.config = config
|
||||
self.sampling_rate = config['audio_config']['sampling_rate']
|
||||
self.n_fft = config['audio_config']['n_fft']
|
||||
self.hop_length = config['audio_config']['hop_length']
|
||||
self.batch_max_steps = config['batch_max_steps']
|
||||
self.batch_max_frames = self.batch_max_steps // self.hop_length
|
||||
self.aux_context_window = 0
|
||||
self.start_offset = self.aux_context_window
|
||||
self.end_offset = -(self.batch_max_frames + self.aux_context_window)
|
||||
self.nsf_enable = (
|
||||
config['Model']['Generator']['params'].get('nsf_params', None)
|
||||
is not None)
|
||||
|
||||
super().__init__(metafile, root_dir)
|
||||
|
||||
# Load from training data directory
|
||||
if len(self.meta) == 0 and isinstance(root_dir, str):
|
||||
wav_dir = os.path.join(root_dir, 'wav')
|
||||
mel_dir = os.path.join(root_dir, 'mel')
|
||||
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
|
||||
raise ValueError('wav or mel directory not found')
|
||||
self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
|
||||
elif len(self.meta) == 0 and isinstance(root_dir, list):
|
||||
for d in root_dir:
|
||||
wav_dir = os.path.join(d, 'wav')
|
||||
mel_dir = os.path.join(d, 'mel')
|
||||
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
|
||||
raise ValueError('wav or mel directory not found')
|
||||
self.meta.extend(self.load_meta_from_dir(wav_dir, mel_dir))
|
||||
|
||||
self.allow_cache = config['allow_cache']
|
||||
if self.allow_cache:
|
||||
self.manager = Manager()
|
||||
self.caches = self.manager.list()
|
||||
self.caches += [() for _ in range(len(self.meta))]
|
||||
|
||||
@staticmethod
|
||||
def gen_metafile(wav_dir, out_dir, split_ratio=0.98):
|
||||
wav_files = glob.glob(os.path.join(wav_dir, '*.wav'))
|
||||
frame_f0_dir = os.path.join(out_dir, 'frame_f0')
|
||||
frame_uv_dir = os.path.join(out_dir, 'frame_uv')
|
||||
mel_dir = os.path.join(out_dir, 'mel')
|
||||
random.seed(DATASET_RANDOM_SEED)
|
||||
random.shuffle(wav_files)
|
||||
num_train = int(len(wav_files) * split_ratio) - 1
|
||||
with open(os.path.join(out_dir, 'train.lst'), 'w') as f:
|
||||
for wav_file in wav_files[:num_train]:
|
||||
index = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
if (not os.path.exists(
|
||||
os.path.join(frame_f0_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(frame_uv_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(mel_dir, index + '.npy'))):
|
||||
continue
|
||||
f.write('{}\n'.format(index))
|
||||
|
||||
with open(os.path.join(out_dir, 'valid.lst'), 'w') as f:
|
||||
for wav_file in wav_files[num_train:]:
|
||||
index = os.path.splitext(os.path.basename(wav_file))[0]
|
||||
if (not os.path.exists(
|
||||
os.path.join(frame_f0_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(frame_uv_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(mel_dir, index + '.npy'))):
|
||||
continue
|
||||
f.write('{}\n'.format(index))
|
||||
|
||||
def load_meta(self, metafile, data_dir):
|
||||
with open(metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
wav_dir = os.path.join(data_dir, 'wav')
|
||||
mel_dir = os.path.join(data_dir, 'mel')
|
||||
frame_f0_dir = os.path.join(data_dir, 'frame_f0')
|
||||
frame_uv_dir = os.path.join(data_dir, 'frame_uv')
|
||||
if not os.path.exists(wav_dir) or not os.path.exists(mel_dir):
|
||||
raise ValueError('wav or mel directory not found')
|
||||
items = []
|
||||
logging.info('Loading metafile...')
|
||||
for name in tqdm(lines):
|
||||
name = name.strip()
|
||||
mel_file = os.path.join(mel_dir, name + '.npy')
|
||||
wav_file = os.path.join(wav_dir, name + '.wav')
|
||||
frame_f0_file = os.path.join(frame_f0_dir, name + '.npy')
|
||||
frame_uv_file = os.path.join(frame_uv_dir, name + '.npy')
|
||||
items.append((wav_file, mel_file, frame_f0_file, frame_uv_file))
|
||||
return items
|
||||
|
||||
def load_meta_from_dir(self, wav_dir, mel_dir):
|
||||
wav_files = glob.glob(os.path.join(wav_dir, '*.wav'))
|
||||
items = []
|
||||
for wav_file in wav_files:
|
||||
mel_file = os.path.join(mel_dir, os.path.basename(wav_file))
|
||||
if os.path.exists(mel_file):
|
||||
items.append((wav_file, mel_file))
|
||||
return items
|
||||
|
||||
def __len__(self):
|
||||
return len(self.meta)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.allow_cache and len(self.caches[idx]) != 0:
|
||||
return self.caches[idx]
|
||||
|
||||
wav_file, mel_file, frame_f0_file, frame_uv_file = self.meta[idx]
|
||||
|
||||
wav_data = librosa.core.load(wav_file, sr=self.sampling_rate)[0]
|
||||
mel_data = np.load(mel_file)
|
||||
|
||||
if self.nsf_enable:
|
||||
frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
|
||||
frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
|
||||
mel_data = np.concatenate((mel_data, frame_f0_data, frame_uv_data),
|
||||
axis=1)
|
||||
|
||||
# make sure the audio length and feature length are matched
|
||||
wav_data = np.pad(wav_data, (0, self.n_fft), mode='reflect')
|
||||
wav_data = wav_data[:len(mel_data) * self.hop_length]
|
||||
assert len(mel_data) * self.hop_length == len(wav_data)
|
||||
|
||||
if self.allow_cache:
|
||||
self.caches[idx] = (wav_data, mel_data)
|
||||
return (wav_data, mel_data)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
wav_data, mel_data = [item[0]
|
||||
for item in batch], [item[1] for item in batch]
|
||||
mel_lengths = [len(mel) for mel in mel_data]
|
||||
|
||||
start_frames = np.array([
|
||||
np.random.randint(self.start_offset, length + self.end_offset)
|
||||
for length in mel_lengths
|
||||
])
|
||||
|
||||
wav_start = start_frames * self.hop_length
|
||||
wav_end = wav_start + self.batch_max_steps
|
||||
|
||||
# aux window works as padding
|
||||
mel_start = start_frames - self.aux_context_window
|
||||
mel_end = mel_start + self.batch_max_frames + self.aux_context_window
|
||||
|
||||
wav_batch = [
|
||||
x[start:end] for x, start, end in zip(wav_data, wav_start, wav_end)
|
||||
]
|
||||
mel_batch = [
|
||||
c[start:end] for c, start, end in zip(mel_data, mel_start, mel_end)
|
||||
]
|
||||
|
||||
# (B, 1, T)
|
||||
wav_batch = torch.tensor(
|
||||
np.asarray(wav_batch), dtype=torch.float32).unsqueeze(1)
|
||||
# (B, C, T)
|
||||
mel_batch = torch.tensor(
|
||||
np.asarray(mel_batch), dtype=torch.float32).transpose(2, 1)
|
||||
return wav_batch, mel_batch
|
||||
|
||||
|
||||
def get_voc_datasets(
|
||||
config,
|
||||
root_dir,
|
||||
split_ratio=0.98,
|
||||
):
|
||||
if isinstance(root_dir, str):
|
||||
root_dir = [root_dir]
|
||||
train_meta_lst = []
|
||||
valid_meta_lst = []
|
||||
for data_dir in root_dir:
|
||||
train_meta = os.path.join(data_dir, 'train.lst')
|
||||
valid_meta = os.path.join(data_dir, 'valid.lst')
|
||||
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
|
||||
VocDataset.gen_metafile(
|
||||
os.path.join(data_dir, 'wav'), data_dir, split_ratio)
|
||||
train_meta_lst.append(train_meta)
|
||||
valid_meta_lst.append(valid_meta)
|
||||
train_dataset = VocDataset(
|
||||
train_meta_lst,
|
||||
root_dir,
|
||||
config,
|
||||
)
|
||||
|
||||
valid_dataset = VocDataset(
|
||||
valid_meta_lst,
|
||||
root_dir,
|
||||
config,
|
||||
)
|
||||
|
||||
return train_dataset, valid_dataset
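A usage sketch for the vocoder datasets returned above: wrap them in torch DataLoaders with the dataset's own collate_fn, which crops aligned (wav, mel) segments of batch_max_steps samples. The directory path, batch size and worker count below are illustrative assumptions, and config is assumed to be the vocoder config dict loaded elsewhere.

from torch.utils.data import DataLoader

train_dataset, valid_dataset = get_voc_datasets(config, 'training_stage/voc_data')
train_loader = DataLoader(
    train_dataset,
    batch_size=16,                        # illustrative
    shuffle=True,
    num_workers=4,                        # illustrative
    collate_fn=train_dataset.collate_fn,  # yields (wav_batch, mel_batch)
)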
|
||||
|
||||
|
||||
def get_fp_label(aug_ling_txt):
|
||||
token_lst = aug_ling_txt.split(' ')
|
||||
emo_lst = [token.strip('{}').split('$')[4] for token in token_lst]
|
||||
syllable_lst = [token.strip('{}').split('$')[0] for token in token_lst]
|
||||
|
||||
# EOS token append
|
||||
emo_lst.append(emotion_types[0])
|
||||
syllable_lst.append('EOS')
|
||||
|
||||
# According to the original emotion tag, set each token's fp label.
|
||||
if emo_lst[0] != emotion_types[3]:
|
||||
emo_lst[0] = emotion_types[0]
|
||||
emo_lst[1] = emotion_types[0]
|
||||
for i in range(len(emo_lst) - 2, 1, -1):
|
||||
if emo_lst[i] != emotion_types[3] and emo_lst[i
|
||||
- 1] != emotion_types[3]:
|
||||
emo_lst[i] = emotion_types[0]
|
||||
elif emo_lst[i] != emotion_types[3] and emo_lst[
|
||||
i - 1] == emotion_types[3]:
|
||||
emo_lst[i] = emotion_types[3]
|
||||
if syllable_lst[i - 2] == 'ga':
|
||||
emo_lst[i + 1] = emotion_types[1]
|
||||
elif syllable_lst[i - 2] == 'ge' and syllable_lst[i - 1] == 'en_c':
|
||||
emo_lst[i + 1] = emotion_types[2]
|
||||
else:
|
||||
emo_lst[i + 1] = emotion_types[4]
|
||||
|
||||
fp_label = []
|
||||
for i in range(len(emo_lst)):
|
||||
if emo_lst[i] == emotion_types[0]:
|
||||
fp_label.append(0)
|
||||
elif emo_lst[i] == emotion_types[1]:
|
||||
fp_label.append(1)
|
||||
elif emo_lst[i] == emotion_types[2]:
|
||||
fp_label.append(2)
|
||||
elif emo_lst[i] == emotion_types[3]:
|
||||
continue
|
||||
elif emo_lst[i] == emotion_types[4]:
|
||||
fp_label.append(3)
|
||||
else:
|
||||
pass
|
||||
|
||||
return np.array(fp_label)
|
||||
|
||||
|
||||
class AmDataset(KanttsDataset):
|
||||
"""
|
||||
provide (ling, emo, speaker, mel) pair
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
metafile,
|
||||
root_dir,
|
||||
config,
|
||||
lang_dir=None,
|
||||
allow_cache=False,
|
||||
):
|
||||
self.config = config
|
||||
self.with_duration = True
|
||||
self.nsf_enable = self.config['Model']['KanTtsSAMBERT']['params'].get(
|
||||
'NSF', False)
|
||||
self.fp_enable = self.config['Model']['KanTtsSAMBERT']['params'].get(
|
||||
'FP', False)
|
||||
|
||||
super().__init__(metafile, root_dir)
|
||||
self.allow_cache = allow_cache
|
||||
|
||||
self.ling_unit = KanTtsLinguisticUnit(config, lang_dir)
|
||||
self.padder = Padder()
|
||||
|
||||
self.r = self.config['Model']['KanTtsSAMBERT']['params'][
|
||||
'outputs_per_step']
|
||||
|
||||
if allow_cache:
|
||||
self.manager = Manager()
|
||||
self.caches = self.manager.list()
|
||||
self.caches += [() for _ in range(len(self.meta))]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.meta)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.allow_cache and len(self.caches[idx]) != 0:
|
||||
return self.caches[idx]
|
||||
|
||||
(
|
||||
ling_txt,
|
||||
mel_file,
|
||||
dur_file,
|
||||
f0_file,
|
||||
energy_file,
|
||||
frame_f0_file,
|
||||
frame_uv_file,
|
||||
aug_ling_txt,
|
||||
) = self.meta[idx]
|
||||
|
||||
ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
|
||||
mel_data = np.load(mel_file)
|
||||
dur_data = np.load(dur_file) if dur_file is not None else None
|
||||
f0_data = np.load(f0_file)
|
||||
energy_data = np.load(energy_file)
|
||||
|
||||
# generate fp position label according to fpadd_meta
|
||||
if self.fp_enable and aug_ling_txt is not None:
|
||||
fp_label = get_fp_label(aug_ling_txt)
|
||||
else:
|
||||
fp_label = None
|
||||
|
||||
if self.with_duration:
|
||||
attn_prior = None
|
||||
else:
|
||||
attn_prior = beta_binomial_prior_distribution(
|
||||
len(ling_data[0]), mel_data.shape[0])
|
||||
|
||||
# Concat frame-level f0 and uv to mel_data
|
||||
if self.nsf_enable:
|
||||
frame_f0_data = np.load(frame_f0_file).reshape(-1, 1)
|
||||
frame_uv_data = np.load(frame_uv_file).reshape(-1, 1)
|
||||
mel_data = np.concatenate([mel_data, frame_f0_data, frame_uv_data],
|
||||
axis=1)
|
||||
|
||||
if self.allow_cache:
|
||||
self.caches[idx] = (
|
||||
ling_data,
|
||||
mel_data,
|
||||
dur_data,
|
||||
f0_data,
|
||||
energy_data,
|
||||
attn_prior,
|
||||
fp_label,
|
||||
)
|
||||
|
||||
return (
|
||||
ling_data,
|
||||
mel_data,
|
||||
dur_data,
|
||||
f0_data,
|
||||
energy_data,
|
||||
attn_prior,
|
||||
fp_label,
|
||||
)
|
||||
|
||||
def load_meta(self, metafile, data_dir):
|
||||
with open(metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
aug_ling_dict = {}
|
||||
if self.fp_enable:
|
||||
add_fp_metafile = metafile.replace('fprm', 'fpadd')
|
||||
with open(add_fp_metafile, 'r') as f:
|
||||
fpadd_lines = f.readlines()
|
||||
for line in fpadd_lines:
|
||||
index, aug_ling_txt = line.split('\t')
|
||||
aug_ling_dict[index] = aug_ling_txt
|
||||
|
||||
mel_dir = os.path.join(data_dir, 'mel')
|
||||
dur_dir = os.path.join(data_dir, 'duration')
|
||||
f0_dir = os.path.join(data_dir, 'f0')
|
||||
energy_dir = os.path.join(data_dir, 'energy')
|
||||
frame_f0_dir = os.path.join(data_dir, 'frame_f0')
|
||||
frame_uv_dir = os.path.join(data_dir, 'frame_uv')
|
||||
|
||||
self.with_duration = os.path.exists(dur_dir)
|
||||
|
||||
items = []
|
||||
logging.info('Loading metafile...')
|
||||
for line in tqdm(lines):
|
||||
line = line.strip()
|
||||
index, ling_txt = line.split('\t')
|
||||
mel_file = os.path.join(mel_dir, index + '.npy')
|
||||
if self.with_duration:
|
||||
dur_file = os.path.join(dur_dir, index + '.npy')
|
||||
else:
|
||||
dur_file = None
|
||||
f0_file = os.path.join(f0_dir, index + '.npy')
|
||||
energy_file = os.path.join(energy_dir, index + '.npy')
|
||||
frame_f0_file = os.path.join(frame_f0_dir, index + '.npy')
|
||||
frame_uv_file = os.path.join(frame_uv_dir, index + '.npy')
|
||||
aug_ling_txt = aug_ling_dict.get(index, None)
|
||||
if self.fp_enable and aug_ling_txt is None:
|
||||
logging.warning(f'Missing fpadd meta for {index}')
|
||||
continue
|
||||
|
||||
items.append((
|
||||
ling_txt,
|
||||
mel_file,
|
||||
dur_file,
|
||||
f0_file,
|
||||
energy_file,
|
||||
frame_f0_file,
|
||||
frame_uv_file,
|
||||
aug_ling_txt,
|
||||
))
|
||||
|
||||
return items
|
||||
|
||||
def load_fpadd_meta(self, metafile):
|
||||
with open(metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
items = []
|
||||
logging.info('Loading fpadd metafile...')
|
||||
for line in tqdm(lines):
|
||||
line = line.strip()
|
||||
index, ling_txt = line.split('\t')
|
||||
|
||||
items.append((ling_txt, ))
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def gen_metafile(
|
||||
raw_meta_file,
|
||||
out_dir,
|
||||
train_meta_file,
|
||||
valid_meta_file,
|
||||
badlist=None,
|
||||
split_ratio=0.98,
|
||||
):
|
||||
with open(raw_meta_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
frame_f0_dir = os.path.join(out_dir, 'frame_f0')
|
||||
frame_uv_dir = os.path.join(out_dir, 'frame_uv')
|
||||
mel_dir = os.path.join(out_dir, 'mel')
|
||||
duration_dir = os.path.join(out_dir, 'duration')
|
||||
random.seed(DATASET_RANDOM_SEED)
|
||||
random.shuffle(lines)
|
||||
num_train = int(len(lines) * split_ratio) - 1
|
||||
with open(train_meta_file, 'w') as f:
|
||||
for line in lines[:num_train]:
|
||||
index = line.split('\t')[0]
|
||||
if badlist is not None and index in badlist:
|
||||
continue
|
||||
if (not os.path.exists(
|
||||
os.path.join(frame_f0_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(frame_uv_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(duration_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(mel_dir, index + '.npy'))):
|
||||
continue
|
||||
f.write(line)
|
||||
|
||||
with open(valid_meta_file, 'w') as f:
|
||||
for line in lines[num_train:]:
|
||||
index = line.split('\t')[0]
|
||||
if badlist is not None and index in badlist:
|
||||
continue
|
||||
if (not os.path.exists(
|
||||
os.path.join(frame_f0_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(frame_uv_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(duration_dir, index + '.npy'))
|
||||
or not os.path.exists(
|
||||
os.path.join(mel_dir, index + '.npy'))):
|
||||
continue
|
||||
f.write(line)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
data_dict = {}
|
||||
|
||||
max_input_length = max((len(x[0][0]) for x in batch))
|
||||
max_dur_length = max((x[2].shape[0] for x in batch)) + 1
|
||||
|
||||
# pure linguistic info: sy|tone|syllable_flag|word_segment
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[0]
|
||||
inputs_sy = self.padder._prepare_scalar_inputs(
|
||||
[x[0][0] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
# tone
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[1]
|
||||
inputs_tone = self.padder._prepare_scalar_inputs(
|
||||
[x[0][1] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# syllable_flag
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[2]
|
||||
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
|
||||
[x[0][2] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# word_segment
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[3]
|
||||
inputs_ws = self.padder._prepare_scalar_inputs(
|
||||
[x[0][3] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# emotion category
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[4]
|
||||
data_dict['input_emotions'] = self.padder._prepare_scalar_inputs(
|
||||
[x[0][4] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# speaker category
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[5]
|
||||
data_dict['input_speakers'] = self.padder._prepare_scalar_inputs(
|
||||
[x[0][5] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# fp label category
|
||||
if self.fp_enable:
|
||||
data_dict['fp_label'] = self.padder._prepare_scalar_inputs(
|
||||
[x[6] for x in batch],
|
||||
max_input_length,
|
||||
0,
|
||||
).long()
|
||||
|
||||
data_dict['input_lings'] = torch.stack(
|
||||
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
|
||||
        data_dict['valid_input_lengths'] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )  # The input symbol sequence has a trailing "~" appended, which would skew the duration computation, so subtract 1 from the length.
|
||||
data_dict['valid_output_lengths'] = torch.as_tensor(
|
||||
[len(x[1]) for x in batch], dtype=torch.long)
|
||||
|
||||
max_output_length = torch.max(data_dict['valid_output_lengths']).item()
|
||||
max_output_round_length = self.padder._round_up(
|
||||
max_output_length, self.r)
|
||||
|
||||
data_dict['mel_targets'] = self.padder._prepare_targets(
|
||||
[x[1] for x in batch], max_output_round_length, 0.0)
|
||||
if self.with_duration:
|
||||
data_dict['durations'] = self.padder._prepare_durations(
|
||||
[x[2] for x in batch], max_dur_length, max_output_round_length)
|
||||
else:
|
||||
data_dict['durations'] = None
|
||||
|
||||
if self.with_duration:
|
||||
if self.fp_enable:
|
||||
feats_padding_length = max_dur_length
|
||||
else:
|
||||
feats_padding_length = max_input_length
|
||||
else:
|
||||
feats_padding_length = max_output_round_length
|
||||
|
||||
data_dict['pitch_contours'] = self.padder._prepare_scalar_inputs(
|
||||
[x[3] for x in batch], feats_padding_length, 0.0).float()
|
||||
data_dict['energy_contours'] = self.padder._prepare_scalar_inputs(
|
||||
[x[4] for x in batch], feats_padding_length, 0.0).float()
|
||||
|
||||
if self.with_duration:
|
||||
data_dict['attn_priors'] = None
|
||||
else:
|
||||
data_dict['attn_priors'] = torch.zeros(
|
||||
len(batch), max_output_round_length, max_input_length)
|
||||
for i in range(len(batch)):
|
||||
attn_prior = batch[i][5]
|
||||
data_dict['attn_priors'][
|
||||
i, :attn_prior.shape[0], :attn_prior.shape[1]] = attn_prior
|
||||
return data_dict
|
||||
|
||||
|
||||
def get_am_datasets(
|
||||
metafile,
|
||||
root_dir,
|
||||
lang_dir,
|
||||
config,
|
||||
allow_cache,
|
||||
split_ratio=0.98,
|
||||
):
|
||||
if not isinstance(root_dir, list):
|
||||
root_dir = [root_dir]
|
||||
if not isinstance(metafile, list):
|
||||
metafile = [metafile]
|
||||
|
||||
train_meta_lst = []
|
||||
valid_meta_lst = []
|
||||
|
||||
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
|
||||
|
||||
if fp_enable:
|
||||
am_train_fn = 'am_fprm_train.lst'
|
||||
am_valid_fn = 'am_fprm_valid.lst'
|
||||
else:
|
||||
am_train_fn = 'am_train.lst'
|
||||
am_valid_fn = 'am_valid.lst'
|
||||
|
||||
for raw_metafile, data_dir in zip(metafile, root_dir):
|
||||
train_meta = os.path.join(data_dir, am_train_fn)
|
||||
valid_meta = os.path.join(data_dir, am_valid_fn)
|
||||
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
|
||||
AmDataset.gen_metafile(raw_metafile, data_dir, train_meta,
|
||||
valid_meta, split_ratio)
|
||||
train_meta_lst.append(train_meta)
|
||||
valid_meta_lst.append(valid_meta)
|
||||
|
||||
train_dataset = AmDataset(train_meta_lst, root_dir, config, lang_dir,
|
||||
allow_cache)
|
||||
|
||||
valid_dataset = AmDataset(valid_meta_lst, root_dir, config, lang_dir,
|
||||
allow_cache)
|
||||
|
||||
return train_dataset, valid_dataset
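Similarly, a sketch of consuming the acoustic-model datasets; AmDataset.collate_fn returns a dict of padded tensors (input_lings, mel_targets, durations, pitch_contours, and so on). The paths, batch size and the loaded config dict below are assumptions for illustration only.

from torch.utils.data import DataLoader

train_set, valid_set = get_am_datasets(
    metafile='training_stage/raw_metafile.txt',  # illustrative path
    root_dir='training_stage/am_data',           # illustrative path
    lang_dir=None,
    config=config,
    allow_cache=False,
)
loader = DataLoader(
    train_set, batch_size=32, shuffle=True, collate_fn=train_set.collate_fn)
batch = next(iter(loader))
# batch['input_lings'] has shape (B, max_input_length, 4): sy, tone, syllable_flag, word_segment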
|
||||
|
||||
|
||||
class MaskingActor(object):
|
||||
|
||||
def __init__(self, mask_ratio=0.15):
|
||||
super(MaskingActor, self).__init__()
|
||||
self.mask_ratio = mask_ratio
|
||||
pass
|
||||
|
||||
def _get_random_mask(self, length, p1=0.15):
|
||||
mask = np.random.uniform(0, 1, length)
|
||||
index = 0
|
||||
while index < len(mask):
|
||||
if mask[index] < p1:
|
||||
mask[index] = 1
|
||||
else:
|
||||
mask[index] = 0
|
||||
index += 1
|
||||
|
||||
return mask
|
||||
|
||||
def _input_bert_masking(
|
||||
self,
|
||||
sequence_array,
|
||||
nb_symbol_category,
|
||||
mask_symbol_id,
|
||||
mask,
|
||||
p2=0.8,
|
||||
p3=0.1,
|
||||
p4=0.1,
|
||||
):
|
||||
sequence_array_mask = sequence_array.copy()
|
||||
mask_id = np.where(mask == 1)[0]
|
||||
mask_len = len(mask_id)
|
||||
rand = np.arange(mask_len)
|
||||
np.random.shuffle(rand)
|
||||
|
||||
# [MASK]
|
||||
mask_id_p2 = mask_id[rand[0:int(math.floor(mask_len * p2))]]
|
||||
if len(mask_id_p2) > 0:
|
||||
sequence_array_mask[mask_id_p2] = mask_symbol_id
|
||||
|
||||
# rand
|
||||
mask_id_p3 = mask_id[
|
||||
rand[int(math.floor(mask_len * p2)):int(math.floor(mask_len * p2))
|
||||
+ int(math.floor(mask_len * p3))]]
|
||||
if len(mask_id_p3) > 0:
|
||||
sequence_array_mask[mask_id_p3] = random.randint(
|
||||
0, nb_symbol_category - 1)
|
||||
|
||||
# ori
|
||||
# do nothing
|
||||
|
||||
return sequence_array_mask
|
||||
|
||||
|
||||
class BERTTextDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
provide (ling, ling_sy_masked, bert_mask) pair
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
metafile,
|
||||
root_dir,
|
||||
lang_dir=None,
|
||||
allow_cache=False,
|
||||
):
|
||||
self.meta = []
|
||||
self.config = config
|
||||
|
||||
if not isinstance(metafile, list):
|
||||
metafile = [metafile]
|
||||
if not isinstance(root_dir, list):
|
||||
root_dir = [root_dir]
|
||||
|
||||
for meta_file, data_dir in zip(metafile, root_dir):
|
||||
if not os.path.exists(meta_file):
|
||||
logging.error('meta file not found: {}'.format(meta_file))
|
||||
raise ValueError(
|
||||
'[BERT_Text_Dataset] meta file: {} not found'.format(
|
||||
meta_file))
|
||||
if not os.path.exists(data_dir):
|
||||
logging.error('data dir not found: {}'.format(data_dir))
|
||||
raise ValueError(
|
||||
'[BERT_Text_Dataset] data dir: {} not found'.format(
|
||||
data_dir))
|
||||
self.meta.extend(self.load_meta(meta_file, data_dir))
|
||||
|
||||
self.allow_cache = allow_cache
|
||||
|
||||
self.ling_unit = KanTtsLinguisticUnit(config, lang_dir)
|
||||
self.padder = Padder()
|
||||
self.masking_actor = MaskingActor(
|
||||
self.config['Model']['KanTtsTextsyBERT']['params']['mask_ratio'])
|
||||
|
||||
if allow_cache:
|
||||
self.manager = Manager()
|
||||
self.caches = self.manager.list()
|
||||
self.caches += [() for _ in range(len(self.meta))]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.meta)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if self.allow_cache and len(self.caches[idx]) != 0:
|
||||
ling_data = self.caches[idx][0]
|
||||
bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
|
||||
return (ling_data, ling_sy_masked_data, bert_mask)
|
||||
|
||||
ling_txt = self.meta[idx]
|
||||
|
||||
ling_data = self.ling_unit.encode_symbol_sequence(ling_txt)
|
||||
bert_mask, ling_sy_masked_data = self.bert_masking(ling_data)
|
||||
|
||||
if self.allow_cache:
|
||||
self.caches[idx] = (ling_data, )
|
||||
|
||||
return (ling_data, ling_sy_masked_data, bert_mask)
|
||||
|
||||
def load_meta(self, metafile, data_dir):
|
||||
with open(metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
items = []
|
||||
logging.info('Loading metafile...')
|
||||
for line in tqdm(lines):
|
||||
line = line.strip()
|
||||
index, ling_txt = line.split('\t')
|
||||
|
||||
items.append((ling_txt))
|
||||
|
||||
return items
|
||||
|
||||
@staticmethod
|
||||
def gen_metafile(raw_meta_file, out_dir, split_ratio=0.98):
|
||||
with open(raw_meta_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
random.seed(DATASET_RANDOM_SEED)
|
||||
random.shuffle(lines)
|
||||
num_train = int(len(lines) * split_ratio) - 1
|
||||
with open(os.path.join(out_dir, 'bert_train.lst'), 'w') as f:
|
||||
for line in lines[:num_train]:
|
||||
f.write(line)
|
||||
|
||||
with open(os.path.join(out_dir, 'bert_valid.lst'), 'w') as f:
|
||||
for line in lines[num_train:]:
|
||||
f.write(line)
|
||||
|
||||
def bert_masking(self, ling_data):
|
||||
length = len(ling_data[0])
|
||||
mask = self.masking_actor._get_random_mask(
|
||||
length, p1=self.masking_actor.mask_ratio)
|
||||
mask[-1] = 0
|
||||
|
||||
# sy_masked
|
||||
sy_mask_symbol_id = self.ling_unit.encode_sy([self.ling_unit._mask])[0]
|
||||
ling_sy_masked_data = self.masking_actor._input_bert_masking(
|
||||
ling_data[0],
|
||||
self.ling_unit.get_unit_size()['sy'],
|
||||
sy_mask_symbol_id,
|
||||
mask,
|
||||
p2=0.8,
|
||||
p3=0.1,
|
||||
p4=0.1,
|
||||
)
|
||||
|
||||
return (mask, ling_sy_masked_data)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
data_dict = {}
|
||||
|
||||
max_input_length = max((len(x[0][0]) for x in batch))
|
||||
|
||||
# pure linguistic info: sy|tone|syllable_flag|word_segment
|
||||
# sy
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[0]
|
||||
targets_sy = self.padder._prepare_scalar_inputs(
|
||||
[x[0][0] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
# sy masked
|
||||
inputs_sy = self.padder._prepare_scalar_inputs(
|
||||
[x[1] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
# tone
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[1]
|
||||
inputs_tone = self.padder._prepare_scalar_inputs(
|
||||
[x[0][1] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# syllable_flag
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[2]
|
||||
inputs_syllable_flag = self.padder._prepare_scalar_inputs(
|
||||
[x[0][2] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
# word_segment
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[3]
|
||||
inputs_ws = self.padder._prepare_scalar_inputs(
|
||||
[x[0][3] for x in batch],
|
||||
max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type],
|
||||
).long()
|
||||
|
||||
data_dict['input_lings'] = torch.stack(
|
||||
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
|
||||
        data_dict['valid_input_lengths'] = torch.as_tensor(
            [len(x[0][0]) - 1 for x in batch], dtype=torch.long
        )  # The input symbol sequence has a trailing "~" appended, which would skew the duration computation, so subtract 1 from the length.
|
||||
|
||||
data_dict['targets'] = targets_sy
|
||||
data_dict['bert_masks'] = self.padder._prepare_scalar_inputs(
|
||||
[x[2] for x in batch], max_input_length, 0.0)
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
def get_bert_text_datasets(
|
||||
metafile,
|
||||
root_dir,
|
||||
config,
|
||||
allow_cache,
|
||||
split_ratio=0.98,
|
||||
):
|
||||
if not isinstance(root_dir, list):
|
||||
root_dir = [root_dir]
|
||||
if not isinstance(metafile, list):
|
||||
metafile = [metafile]
|
||||
|
||||
train_meta_lst = []
|
||||
valid_meta_lst = []
|
||||
|
||||
for raw_metafile, data_dir in zip(metafile, root_dir):
|
||||
train_meta = os.path.join(data_dir, 'bert_train.lst')
|
||||
valid_meta = os.path.join(data_dir, 'bert_valid.lst')
|
||||
if not os.path.exists(train_meta) or not os.path.exists(valid_meta):
|
||||
BERTTextDataset.gen_metafile(raw_metafile, data_dir, split_ratio)
|
||||
train_meta_lst.append(train_meta)
|
||||
valid_meta_lst.append(valid_meta)
|
||||
|
||||
train_dataset = BERTTextDataset(config, train_meta_lst, root_dir,
|
||||
allow_cache)
|
||||
|
||||
valid_dataset = BERTTextDataset(config, valid_meta_lst, root_dir,
|
||||
allow_cache)
|
||||
|
||||
return train_dataset, valid_dataset
|
||||
158  modelscope/models/audio/tts/kantts/models/__init__.py  Normal file
@@ -0,0 +1,158 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import torch
from torch.nn.parallel import DistributedDataParallel

import modelscope.models.audio.tts.kantts.train.scheduler as kantts_scheduler
from modelscope.models.audio.tts.kantts.utils.ling_unit.ling_unit import \
    get_fpdict
from .hifigan import (Generator, MultiPeriodDiscriminator,
                      MultiScaleDiscriminator, MultiSpecDiscriminator)
from .pqmf import PQMF
from .sambert.kantts_sambert import KanTtsSAMBERT, KanTtsTextsyBERT


def optimizer_builder(model_params, opt_name, opt_params):
    opt_cls = getattr(torch.optim, opt_name)
    optimizer = opt_cls(model_params, **opt_params)
    return optimizer


def scheduler_builder(optimizer, sche_name, sche_params):
    scheduler_cls = getattr(kantts_scheduler, sche_name)
    scheduler = scheduler_cls(optimizer, **sche_params)
    return scheduler
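Both builders above resolve classes purely by name, so the optimizer is looked up in torch.optim and the scheduler in kantts_scheduler, driven entirely by the config. A minimal illustrative call; the module, learning rate and scheduler parameters are assumptions.

import torch.nn as nn

net = nn.Linear(80, 80)
opt = optimizer_builder(net.parameters(), 'Adam', {'lr': 1e-3})
# sched = scheduler_builder(opt, 'StepLR', {'step_size': 1000, 'gamma': 0.5})  # 'StepLR' must be defined in kantts_scheduler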
|
||||
|
||||
|
||||
def hifigan_model_builder(config, device, rank, distributed):
|
||||
model = {}
|
||||
optimizer = {}
|
||||
scheduler = {}
|
||||
model['discriminator'] = {}
|
||||
optimizer['discriminator'] = {}
|
||||
scheduler['discriminator'] = {}
|
||||
for model_name in config['Model'].keys():
|
||||
if model_name == 'Generator':
|
||||
params = config['Model'][model_name]['params']
|
||||
model['generator'] = Generator(**params).to(device)
|
||||
optimizer['generator'] = optimizer_builder(
|
||||
model['generator'].parameters(),
|
||||
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
|
||||
config['Model'][model_name]['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['generator'] = scheduler_builder(
|
||||
optimizer['generator'],
|
||||
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
|
||||
config['Model'][model_name]['scheduler'].get('params', {}),
|
||||
)
|
||||
else:
|
||||
params = config['Model'][model_name]['params']
|
||||
model['discriminator'][model_name] = globals()[model_name](
|
||||
**params).to(device)
|
||||
optimizer['discriminator'][model_name] = optimizer_builder(
|
||||
model['discriminator'][model_name].parameters(),
|
||||
config['Model'][model_name]['optimizer'].get('type', 'Adam'),
|
||||
config['Model'][model_name]['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['discriminator'][model_name] = scheduler_builder(
|
||||
optimizer['discriminator'][model_name],
|
||||
config['Model'][model_name]['scheduler'].get('type', 'StepLR'),
|
||||
config['Model'][model_name]['scheduler'].get('params', {}),
|
||||
)
|
||||
|
||||
out_channels = config['Model']['Generator']['params']['out_channels']
|
||||
if out_channels > 1:
|
||||
model['pqmf'] = PQMF(
|
||||
subbands=out_channels, **config.get('pqmf', {})).to(device)
|
||||
|
||||
# FIXME: pywavelets buffer leads to gradient error in DDP training
|
||||
# Solution: https://github.com/pytorch/pytorch/issues/22095
|
||||
if distributed:
|
||||
model['generator'] = DistributedDataParallel(
|
||||
model['generator'],
|
||||
device_ids=[rank],
|
||||
output_device=rank,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
for model_name in model['discriminator'].keys():
|
||||
model['discriminator'][model_name] = DistributedDataParallel(
|
||||
model['discriminator'][model_name],
|
||||
device_ids=[rank],
|
||||
output_device=rank,
|
||||
broadcast_buffers=False,
|
||||
)
|
||||
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
def sambert_model_builder(config, device, rank, distributed):
|
||||
model = {}
|
||||
optimizer = {}
|
||||
scheduler = {}
|
||||
|
||||
model['KanTtsSAMBERT'] = KanTtsSAMBERT(
|
||||
config['Model']['KanTtsSAMBERT']['params']).to(device)
|
||||
|
||||
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
|
||||
if fp_enable:
|
||||
fp_dict = {
|
||||
k: torch.from_numpy(v).long().unsqueeze(0).to(device)
|
||||
for k, v in get_fpdict(config).items()
|
||||
}
|
||||
model['KanTtsSAMBERT'].fp_dict = fp_dict
|
||||
|
||||
optimizer['KanTtsSAMBERT'] = optimizer_builder(
|
||||
model['KanTtsSAMBERT'].parameters(),
|
||||
config['Model']['KanTtsSAMBERT']['optimizer'].get('type', 'Adam'),
|
||||
config['Model']['KanTtsSAMBERT']['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['KanTtsSAMBERT'] = scheduler_builder(
|
||||
optimizer['KanTtsSAMBERT'],
|
||||
config['Model']['KanTtsSAMBERT']['scheduler'].get('type', 'StepLR'),
|
||||
config['Model']['KanTtsSAMBERT']['scheduler'].get('params', {}),
|
||||
)
|
||||
|
||||
if distributed:
|
||||
model['KanTtsSAMBERT'] = DistributedDataParallel(
|
||||
model['KanTtsSAMBERT'], device_ids=[rank], output_device=rank)
|
||||
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
def sybert_model_builder(config, device, rank, distributed):
|
||||
model = {}
|
||||
optimizer = {}
|
||||
scheduler = {}
|
||||
|
||||
model['KanTtsTextsyBERT'] = KanTtsTextsyBERT(
|
||||
config['Model']['KanTtsTextsyBERT']['params']).to(device)
|
||||
optimizer['KanTtsTextsyBERT'] = optimizer_builder(
|
||||
model['KanTtsTextsyBERT'].parameters(),
|
||||
config['Model']['KanTtsTextsyBERT']['optimizer'].get('type', 'Adam'),
|
||||
config['Model']['KanTtsTextsyBERT']['optimizer'].get('params', {}),
|
||||
)
|
||||
scheduler['KanTtsTextsyBERT'] = scheduler_builder(
|
||||
optimizer['KanTtsTextsyBERT'],
|
||||
config['Model']['KanTtsTextsyBERT']['scheduler'].get('type', 'StepLR'),
|
||||
config['Model']['KanTtsTextsyBERT']['scheduler'].get('params', {}),
|
||||
)
|
||||
|
||||
if distributed:
|
||||
model['KanTtsTextsyBERT'] = DistributedDataParallel(
|
||||
model['KanTtsTextsyBERT'], device_ids=[rank], output_device=rank)
|
||||
|
||||
return model, optimizer, scheduler
|
||||
|
||||
|
||||
model_dict = {
    'hifigan': hifigan_model_builder,
    'sambert': sambert_model_builder,
    'sybert': sybert_model_builder,
}


def model_builder(config, device='cpu', rank=0, distributed=False):
    builder_func = model_dict[config['model_type']]
    model, optimizer, scheduler = builder_func(config, device, rank,
                                               distributed)
    return model, optimizer, scheduler
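model_builder therefore returns three dicts keyed by sub-model name; for config['model_type'] == 'hifigan' the generator lives under model['generator'] and each discriminator under model['discriminator'][name]. A minimal usage sketch, assuming a vocoder config dict has already been loaded into config:

import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, optimizer, scheduler = model_builder(
    config, device=device, rank=0, distributed=False)
generator = model['generator']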
|
||||
@@ -0,0 +1,4 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

from .hifigan import (Generator, MultiPeriodDiscriminator,
                      MultiScaleDiscriminator, MultiSpecDiscriminator)
613  modelscope/models/audio/tts/kantts/models/hifigan/hifigan.py  Normal file
@@ -0,0 +1,613 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import copy
from distutils.version import LooseVersion

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pytorch_wavelets import DWT1DForward
from torch.nn.utils import spectral_norm, weight_norm

from modelscope.models.audio.tts.kantts.utils.audio_torch import stft
from .layers import (CausalConv1d, CausalConvTranspose1d, Conv1d,
                     ConvTranspose1d, ResidualBlock, SourceModule)

is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')

class Generator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=80,
|
||||
out_channels=1,
|
||||
channels=512,
|
||||
kernel_size=7,
|
||||
upsample_scales=(8, 8, 2, 2),
|
||||
upsample_kernal_sizes=(16, 16, 4, 4),
|
||||
resblock_kernel_sizes=(3, 7, 11),
|
||||
resblock_dilations=[(1, 3, 5), (1, 3, 5), (1, 3, 5)],
|
||||
repeat_upsample=True,
|
||||
bias=True,
|
||||
causal=True,
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
use_weight_norm=True,
|
||||
nsf_params=None,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
|
||||
# check hyperparameters are valid
|
||||
assert kernel_size % 2 == 1, 'Kernal size must be odd number.'
|
||||
assert len(upsample_scales) == len(upsample_kernal_sizes)
|
||||
assert len(resblock_dilations) == len(resblock_kernel_sizes)
|
||||
|
||||
self.upsample_scales = upsample_scales
|
||||
self.repeat_upsample = repeat_upsample
|
||||
self.num_upsamples = len(upsample_kernal_sizes)
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.out_channels = out_channels
|
||||
self.nsf_enable = nsf_params is not None
|
||||
|
||||
self.transpose_upsamples = torch.nn.ModuleList()
|
||||
self.repeat_upsamples = torch.nn.ModuleList() # for repeat upsampling
|
||||
self.conv_blocks = torch.nn.ModuleList()
|
||||
|
||||
conv_cls = CausalConv1d if causal else Conv1d
|
||||
conv_transposed_cls = CausalConvTranspose1d if causal else ConvTranspose1d
|
||||
|
||||
self.conv_pre = conv_cls(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=(kernel_size - 1) // 2)
|
||||
|
||||
for i in range(len(upsample_kernal_sizes)):
|
||||
self.transpose_upsamples.append(
|
||||
torch.nn.Sequential(
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
conv_transposed_cls(
|
||||
channels // (2**i),
|
||||
channels // (2**(i + 1)),
|
||||
upsample_kernal_sizes[i],
|
||||
upsample_scales[i],
|
||||
padding=(upsample_kernal_sizes[i] - upsample_scales[i])
|
||||
// 2,
|
||||
),
|
||||
))
|
||||
|
||||
if repeat_upsample:
|
||||
self.repeat_upsamples.append(
|
||||
nn.Sequential(
|
||||
nn.Upsample(
|
||||
mode='nearest', scale_factor=upsample_scales[i]),
|
||||
getattr(torch.nn, nonlinear_activation)(
|
||||
**nonlinear_activation_params),
|
||||
conv_cls(
|
||||
channels // (2**i),
|
||||
channels // (2**(i + 1)),
|
||||
kernel_size=kernel_size,
|
||||
stride=1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
),
|
||||
))
|
||||
|
||||
for j in range(len(resblock_kernel_sizes)):
|
||||
self.conv_blocks.append(
|
||||
ResidualBlock(
|
||||
channels=channels // (2**(i + 1)),
|
||||
kernel_size=resblock_kernel_sizes[j],
|
||||
dilation=resblock_dilations[j],
|
||||
nonlinear_activation=nonlinear_activation,
|
||||
nonlinear_activation_params=nonlinear_activation_params,
|
||||
causal=causal,
|
||||
))
|
||||
|
||||
self.conv_post = conv_cls(
|
||||
channels // (2**(i + 1)),
|
||||
out_channels,
|
||||
kernel_size,
|
||||
1,
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)
|
||||
|
||||
if self.nsf_enable:
|
||||
self.source_module = SourceModule(
|
||||
nb_harmonics=nsf_params['nb_harmonics'],
|
||||
upsample_ratio=np.cumprod(self.upsample_scales)[-1],
|
||||
sampling_rate=nsf_params['sampling_rate'],
|
||||
)
|
||||
self.source_downs = nn.ModuleList()
|
||||
self.downsample_rates = [1] + self.upsample_scales[::-1][:-1]
|
||||
self.downsample_cum_rates = np.cumprod(self.downsample_rates)
|
||||
|
||||
for i, u in enumerate(self.downsample_cum_rates[::-1]):
|
||||
if u == 1:
|
||||
self.source_downs.append(
|
||||
Conv1d(1, channels // (2**(i + 1)), 1, 1))
|
||||
else:
|
||||
self.source_downs.append(
|
||||
conv_cls(
|
||||
1,
|
||||
channels // (2**(i + 1)),
|
||||
u * 2,
|
||||
u,
|
||||
padding=u // 2,
|
||||
))
|
||||
|
||||
def forward(self, x):
|
||||
if self.nsf_enable:
|
||||
mel = x[:, :-2, :]
|
||||
pitch = x[:, -2:-1, :]
|
||||
uv = x[:, -1:, :]
|
||||
excitation = self.source_module(pitch, uv)
|
||||
else:
|
||||
mel = x
|
||||
|
||||
x = self.conv_pre(mel)
|
||||
for i in range(self.num_upsamples):
|
||||
# FIXME: sin function here seems to be causing issues
|
||||
x = torch.sin(x) + x
|
||||
rep = self.repeat_upsamples[i](x)
|
||||
|
||||
if self.nsf_enable:
|
||||
# Downsampling the excitation signal
|
||||
e = self.source_downs[i](excitation)
|
||||
# augment inputs with the excitation
|
||||
x = rep + e
|
||||
else:
|
||||
# transconv
|
||||
up = self.transpose_upsamples[i](x)
|
||||
x = rep + up
|
||||
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.conv_blocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.conv_blocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for layer in self.transpose_upsamples:
|
||||
layer[-1].remove_weight_norm()
|
||||
for layer in self.repeat_upsamples:
|
||||
layer[-1].remove_weight_norm()
|
||||
for layer in self.conv_blocks:
|
||||
layer.remove_weight_norm()
|
||||
self.conv_pre.remove_weight_norm()
|
||||
self.conv_post.remove_weight_norm()
|
||||
if self.nsf_enable:
|
||||
self.source_module.remove_weight_norm()
|
||||
for layer in self.source_downs:
|
||||
layer.remove_weight_norm()
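For orientation, the Generator above consumes a (batch, in_channels, frames) feature map and upsamples it by the product of upsample_scales (256x with the defaults); when nsf_params is set, the last two input rows are treated as frame-level F0 and voicing. A rough inference sketch with made-up tensor sizes, not taken from this commit:

import torch

gen = Generator()  # defaults: in_channels=80, upsample_scales=(8, 8, 2, 2)
gen.eval()
with torch.no_grad():
    mel = torch.randn(1, 80, 120)  # (batch, mel bins, frames)
    wav = gen(mel)                 # waveform, roughly (1, 1, 120 * 256)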
|
||||
|
||||
|
||||
class PeriodDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
period=3,
|
||||
kernel_sizes=[5, 3],
|
||||
channels=32,
|
||||
downsample_scales=[3, 3, 3, 3, 1],
|
||||
max_downsample_channels=1024,
|
||||
bias=True,
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
use_spectral_norm=False,
|
||||
):
|
||||
super(PeriodDiscriminator, self).__init__()
|
||||
self.period = period
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
self.convs = nn.ModuleList()
|
||||
in_chs, out_chs = in_channels, channels
|
||||
|
||||
for downsample_scale in downsample_scales:
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
(kernel_sizes[0], 1),
|
||||
(downsample_scale, 1),
|
||||
padding=((kernel_sizes[0] - 1) // 2, 0),
|
||||
)),
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
in_chs = out_chs
|
||||
out_chs = min(out_chs * 4, max_downsample_channels)
|
||||
|
||||
self.conv_post = nn.Conv2d(
|
||||
out_chs,
|
||||
out_channels,
|
||||
(kernel_sizes[1] - 1, 1),
|
||||
1,
|
||||
padding=((kernel_sizes[1] - 1) // 2, 0),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
|
||||
# 1d to 2d
|
||||
b, c, t = x.shape
|
||||
if t % self.period != 0: # pad first
|
||||
n_pad = self.period - (t % self.period)
|
||||
x = F.pad(x, (0, n_pad), 'reflect')
|
||||
t = t + n_pad
|
||||
x = x.view(b, c, t // self.period, self.period)
|
||||
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
periods=[2, 3, 5, 7, 11],
|
||||
discriminator_params={
|
||||
'in_channels': 1,
|
||||
'out_channels': 1,
|
||||
'kernel_sizes': [5, 3],
|
||||
'channels': 32,
|
||||
'downsample_scales': [3, 3, 3, 3, 1],
|
||||
'max_downsample_channels': 1024,
|
||||
'bias': True,
|
||||
'nonlinear_activation': 'LeakyReLU',
|
||||
'nonlinear_activation_params': {
|
||||
'negative_slope': 0.1
|
||||
},
|
||||
'use_spectral_norm': False,
|
||||
},
|
||||
):
|
||||
super(MultiPeriodDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList()
|
||||
for period in periods:
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
params['period'] = period
|
||||
self.discriminators += [PeriodDiscriminator(**params)]
|
||||
|
||||
def forward(self, y):
|
||||
y_d_rs = []
|
||||
fmap_rs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
|
||||
return y_d_rs, fmap_rs
|
||||
|
||||
|
||||
class ScaleDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=1,
|
||||
out_channels=1,
|
||||
kernel_sizes=[15, 41, 5, 3],
|
||||
channels=128,
|
||||
max_downsample_channels=1024,
|
||||
max_groups=16,
|
||||
bias=True,
|
||||
downsample_scales=[2, 2, 4, 4, 1],
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
use_spectral_norm=False,
|
||||
):
|
||||
super(ScaleDiscriminator, self).__init__()
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
|
||||
assert len(kernel_sizes) == 4
|
||||
for ks in kernel_sizes:
|
||||
assert ks % 2 == 1
|
||||
|
||||
self.convs = nn.ModuleList()
|
||||
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
channels,
|
||||
kernel_sizes[0],
|
||||
bias=bias,
|
||||
padding=(kernel_sizes[0] - 1) // 2,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
in_chs = channels
|
||||
out_chs = channels
|
||||
groups = 4
|
||||
|
||||
for downsample_scale in downsample_scales:
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_sizes[1],
|
||||
stride=downsample_scale,
|
||||
padding=(kernel_sizes[1] - 1) // 2,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
)),
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
in_chs = out_chs
|
||||
out_chs = min(in_chs * 2, max_downsample_channels)
|
||||
groups = min(groups * 4, max_groups)
|
||||
|
||||
out_chs = min(in_chs * 2, max_downsample_channels)
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv1d(
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=kernel_sizes[2],
|
||||
stride=1,
|
||||
padding=(kernel_sizes[2] - 1) // 2,
|
||||
bias=bias,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv1d(
|
||||
out_chs,
|
||||
out_channels,
|
||||
kernel_size=kernel_sizes[3],
|
||||
stride=1,
|
||||
padding=(kernel_sizes[3] - 1) // 2,
|
||||
bias=bias,
|
||||
))
|
||||
|
||||
def forward(self, x):
|
||||
fmap = []
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = torch.flatten(x, 1, -1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
scales=3,
|
||||
downsample_pooling='DWT',
|
||||
# follow the official implementation setting
|
||||
downsample_pooling_params={
|
||||
'kernel_size': 4,
|
||||
'stride': 2,
|
||||
'padding': 2,
|
||||
},
|
||||
discriminator_params={
|
||||
'in_channels': 1,
|
||||
'out_channels': 1,
|
||||
'kernel_sizes': [15, 41, 5, 3],
|
||||
'channels': 128,
|
||||
'max_downsample_channels': 1024,
|
||||
'max_groups': 16,
|
||||
'bias': True,
|
||||
'downsample_scales': [2, 2, 4, 4, 1],
|
||||
'nonlinear_activation': 'LeakyReLU',
|
||||
'nonlinear_activation_params': {
|
||||
'negative_slope': 0.1
|
||||
},
|
||||
},
|
||||
follow_official_norm=False,
|
||||
):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.discriminators = torch.nn.ModuleList()
|
||||
|
||||
# add discriminators
|
||||
for i in range(scales):
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
if follow_official_norm:
|
||||
params['use_spectral_norm'] = True if i == 0 else False
|
||||
self.discriminators += [ScaleDiscriminator(**params)]
|
||||
|
||||
if downsample_pooling == 'DWT':
|
||||
self.meanpools = nn.ModuleList(
|
||||
[DWT1DForward(wave='db3', J=1),
|
||||
DWT1DForward(wave='db3', J=1)])
|
||||
self.aux_convs = nn.ModuleList([
|
||||
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
|
||||
weight_norm(nn.Conv1d(2, 1, 15, 1, padding=7)),
|
||||
])
|
||||
else:
|
||||
self.meanpools = nn.ModuleList(
|
||||
[nn.AvgPool1d(4, 2, padding=2),
|
||||
nn.AvgPool1d(4, 2, padding=2)])
|
||||
self.aux_convs = None
|
||||
|
||||
def forward(self, y):
|
||||
y_d_rs = []
|
||||
fmap_rs = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
if i != 0:
|
||||
if self.aux_convs is None:
|
||||
y = self.meanpools[i - 1](y)
|
||||
else:
|
||||
yl, yh = self.meanpools[i - 1](y)
|
||||
y = torch.cat([yl, yh[0]], dim=1)
|
||||
y = self.aux_convs[i - 1](y)
|
||||
y = F.leaky_relu(y, 0.1)
|
||||
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
|
||||
return y_d_rs, fmap_rs
|
||||
|
||||
|
||||
class SpecDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels=32,
|
||||
init_kernel=15,
|
||||
kernel_size=11,
|
||||
stride=2,
|
||||
use_spectral_norm=False,
|
||||
fft_size=1024,
|
||||
shift_size=120,
|
||||
win_length=600,
|
||||
window='hann_window',
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
):
|
||||
super(SpecDiscriminator, self).__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
# fft_size // 2 + 1
|
||||
norm_f = weight_norm if not use_spectral_norm else spectral_norm
|
||||
final_kernel = 5
|
||||
post_conv_kernel = 3
|
||||
blocks = 3
|
||||
self.convs = nn.ModuleList()
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
fft_size // 2 + 1,
|
||||
channels,
|
||||
(init_kernel, 1),
|
||||
(1, 1),
|
||||
padding=(init_kernel - 1) // 2,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
for i in range(blocks):
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
(kernel_size, 1),
|
||||
(stride, 1),
|
||||
padding=(kernel_size - 1) // 2,
|
||||
)),
|
||||
getattr(
|
||||
torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
self.convs.append(
|
||||
torch.nn.Sequential(
|
||||
norm_f(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
channels,
|
||||
(final_kernel, 1),
|
||||
(1, 1),
|
||||
padding=(final_kernel - 1) // 2,
|
||||
)),
|
||||
getattr(torch.nn,
|
||||
nonlinear_activation)(**nonlinear_activation_params),
|
||||
))
|
||||
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv2d(
|
||||
channels,
|
||||
1,
|
||||
(post_conv_kernel, 1),
|
||||
(1, 1),
|
||||
padding=((post_conv_kernel - 1) // 2, 0),
|
||||
))
|
||||
self.register_buffer('window', getattr(torch, window)(win_length))
|
||||
|
||||
def forward(self, wav):
|
||||
with torch.no_grad():
|
||||
wav = torch.squeeze(wav, 1)
|
||||
x_mag = stft(wav, self.fft_size, self.shift_size, self.win_length,
|
||||
self.window)
|
||||
x = torch.transpose(x_mag, 2, 1).unsqueeze(-1)
|
||||
fmap = []
|
||||
for layer in self.convs:
|
||||
x = layer(x)
|
||||
fmap.append(x)
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x = x.squeeze(-1)
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiSpecDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
discriminator_params={
|
||||
'channels': 15,
|
||||
'init_kernel': 1,
|
||||
'kernel_sizes': 11,
|
||||
'stride': 2,
|
||||
'use_spectral_norm': False,
|
||||
'window': 'hann_window',
|
||||
'nonlinear_activation': 'LeakyReLU',
|
||||
'nonlinear_activation_params': {
|
||||
'negative_slope': 0.1
|
||||
},
|
||||
},
|
||||
):
|
||||
super(MultiSpecDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList()
|
||||
for fft_size, hop_size, win_length in zip(fft_sizes, hop_sizes,
|
||||
win_lengths):
|
||||
params = copy.deepcopy(discriminator_params)
|
||||
params['fft_size'] = fft_size
|
||||
params['shift_size'] = hop_size
|
||||
params['win_length'] = win_length
|
||||
self.discriminators += [SpecDiscriminator(**params)]
|
||||
|
||||
def forward(self, y):
|
||||
y_d = []
|
||||
fmap = []
|
||||
for i, d in enumerate(self.discriminators):
|
||||
x, x_map = d(y)
|
||||
y_d.append(x)
|
||||
fmap.append(x_map)
|
||||
|
||||
return y_d, fmap
|
||||
288
modelscope/models/audio/tts/kantts/models/hifigan/layers.py
Normal file
288
modelscope/models/audio/tts/kantts/models/hifigan/layers.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.distributions.normal import Normal
|
||||
from torch.distributions.uniform import Uniform
|
||||
from torch.nn.utils import remove_weight_norm, weight_norm
|
||||
|
||||
from modelscope.models.audio.tts.kantts.models.utils import init_weights
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
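# Worked example for the helper above (illustration only): get_padding returns
# the "same"-length padding (kernel_size * dilation - dilation) // 2.
assert get_padding(3) == 1
assert get_padding(3, dilation=2) == 2
assert get_padding(5, dilation=1) == 2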
|
||||
|
||||
class Conv1d(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
):
|
||||
super(Conv1d, self).__init__()
|
||||
self.conv1d = weight_norm(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
))
|
||||
self.conv1d.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1d(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.conv1d)
|
||||
|
||||
|
||||
class CausalConv1d(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros',
|
||||
):
|
||||
super(CausalConv1d, self).__init__()
|
||||
self.pad = (kernel_size - 1) * dilation
|
||||
self.conv1d = weight_norm(
|
||||
nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
))
|
||||
self.conv1d.apply(init_weights)
|
||||
|
||||
def forward(self, x): # bdt
|
||||
x = F.pad(
|
||||
x, (self.pad, 0, 0, 0, 0, 0), 'constant'
|
||||
)  # F.pad pads are specified starting from the last dimension and moving forward.
|
||||
# x = F.pad(x, (self.pad, self.pad, 0, 0, 0, 0), "constant")
|
||||
x = self.conv1d(x)[:, :, :x.size(2)]
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.conv1d)
|
||||
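# Hedged sketch (illustration only): the causal variant above pads only on the
# left, so the output keeps the input length and frame t never sees inputs
# later than t. The shapes below are assumptions.
causal = CausalConv1d(in_channels=4, out_channels=4, kernel_size=3)
x = torch.randn(1, 4, 100)
y = causal(x)   # -> (1, 4, 100), left-padded by (kernel_size - 1) * dilation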
|
||||
|
||||
class ConvTranspose1d(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
):
|
||||
super(ConvTranspose1d, self).__init__()
|
||||
self.deconv = weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=padding,
|
||||
output_padding=0,
|
||||
))
|
||||
self.deconv.apply(init_weights)
|
||||
|
||||
def forward(self, x):
|
||||
return self.deconv(x)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.deconv)
|
||||
|
||||
|
||||
# FIXME: HACK to get shape right
|
||||
class CausalConvTranspose1d(torch.nn.Module):
|
||||
"""CausalConvTranspose1d module with customized initialization."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
):
|
||||
"""Initialize CausalConvTranspose1d module."""
|
||||
super(CausalConvTranspose1d, self).__init__()
|
||||
self.deconv = weight_norm(
|
||||
nn.ConvTranspose1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
output_padding=0,
|
||||
))
|
||||
self.stride = stride
|
||||
self.deconv.apply(init_weights)
|
||||
self.pad = kernel_size - stride
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T_in).
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T_out).
|
||||
"""
|
||||
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
|
||||
return self.deconv(x)[:, :, :-self.pad]
|
||||
# return self.deconv(x)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.deconv)
|
||||
|
||||
|
||||
class ResidualBlock(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels,
|
||||
kernel_size=3,
|
||||
dilation=(1, 3, 5),
|
||||
nonlinear_activation='LeakyReLU',
|
||||
nonlinear_activation_params={'negative_slope': 0.1},
|
||||
causal=False,
|
||||
):
|
||||
super(ResidualBlock, self).__init__()
|
||||
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
|
||||
conv_cls = CausalConv1d if causal else Conv1d
|
||||
self.convs1 = nn.ModuleList([
|
||||
conv_cls(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[i],
|
||||
padding=get_padding(kernel_size, dilation[i]),
|
||||
) for i in range(len(dilation))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
conv_cls(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1),
|
||||
) for i in range(len(dilation))
|
||||
])
|
||||
|
||||
self.activation = getattr(
|
||||
torch.nn, nonlinear_activation)(**nonlinear_activation_params)
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = self.activation(x)
|
||||
xt = c1(xt)
|
||||
xt = self.activation(xt)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for layer in self.convs1:
|
||||
layer.remove_weight_norm()
|
||||
for layer in self.convs2:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
|
||||
class SourceModule(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
nb_harmonics,
|
||||
upsample_ratio,
|
||||
sampling_rate,
|
||||
alpha=0.1,
|
||||
sigma=0.003):
|
||||
super(SourceModule, self).__init__()
|
||||
|
||||
self.nb_harmonics = nb_harmonics
|
||||
self.upsample_ratio = upsample_ratio
|
||||
self.sampling_rate = sampling_rate
|
||||
self.alpha = alpha
|
||||
self.sigma = sigma
|
||||
|
||||
self.ffn = nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(self.nb_harmonics + 1, 1, kernel_size=1, stride=1)),
|
||||
nn.Tanh(),
|
||||
)
|
||||
|
||||
def forward(self, pitch, uv):
|
||||
"""
|
||||
:param pitch: [B, 1, frame_len], Hz
|
||||
:param uv: [B, 1, frame_len] vuv flag
|
||||
:return: [B, 1, sample_len]
|
||||
"""
|
||||
with torch.no_grad():
|
||||
pitch_samples = F.interpolate(
|
||||
pitch, scale_factor=(self.upsample_ratio), mode='nearest')
|
||||
uv_samples = F.interpolate(
|
||||
uv, scale_factor=(self.upsample_ratio), mode='nearest')
|
||||
|
||||
F_mat = torch.zeros(
|
||||
(pitch_samples.size(0), self.nb_harmonics + 1,
|
||||
pitch_samples.size(-1))).to(pitch_samples.device)
|
||||
for i in range(self.nb_harmonics + 1):
|
||||
F_mat[:, i:i
|
||||
+ 1, :] = pitch_samples * (i + 1) / self.sampling_rate
|
||||
|
||||
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
||||
u_dist = Uniform(low=-np.pi, high=np.pi)
|
||||
phase_vec = u_dist.sample(
|
||||
sample_shape=(pitch.size(0), self.nb_harmonics + 1,
|
||||
1)).to(F_mat.device)
|
||||
phase_vec[:, 0, :] = 0
|
||||
|
||||
n_dist = Normal(loc=0.0, scale=self.sigma)
|
||||
noise = n_dist.sample(
|
||||
sample_shape=(
|
||||
pitch_samples.size(0),
|
||||
self.nb_harmonics + 1,
|
||||
pitch_samples.size(-1),
|
||||
)).to(F_mat.device)
|
||||
|
||||
e_voice = self.alpha * torch.sin(theta_mat + phase_vec) + noise
|
||||
e_unvoice = self.alpha / 3 / self.sigma * noise
|
||||
|
||||
e = e_voice * uv_samples + e_unvoice * (1 - uv_samples)
|
||||
|
||||
return self.ffn(e)
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.ffn[0])
|
||||
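# Hedged usage sketch (illustration only): the source module above turns a
# frame-level F0 contour and a voiced/unvoiced flag into a sample-level
# excitation signal. The concrete numbers are assumptions; the module's own
# torch imports are assumed.
source = SourceModule(nb_harmonics=8, upsample_ratio=240, sampling_rate=24000)
f0 = torch.full((1, 1, 100), 220.0)   # 100 frames of a 220 Hz pitch
uv = torch.ones(1, 1, 100)            # all frames voiced
excitation = source(f0, uv)           # -> (1, 1, 100 * 240)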
133
modelscope/models/audio/tts/kantts/models/pqmf.py
Normal file
133
modelscope/models/audio/tts/kantts/models/pqmf.py
Normal file
@@ -0,0 +1,133 @@
|
||||
# The implementation is adapted from kan-bayashi's ParallelWaveGAN,
|
||||
# made publicly available under the MIT License at https://github.com/kan-bayashi/ParallelWaveGAN
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from scipy.signal import kaiser
|
||||
|
||||
|
||||
def design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0):
|
||||
"""Design prototype filter for PQMF.
|
||||
|
||||
This method is based on `A Kaiser window approach for the design of prototype
|
||||
filters of cosine modulated filterbanks`_.
|
||||
|
||||
Args:
|
||||
taps (int): The number of filter taps.
|
||||
cutoff_ratio (float): Cut-off frequency ratio.
|
||||
beta (float): Beta coefficient for kaiser window.
|
||||
|
||||
Returns:
|
||||
ndarray: Impulse response of prototype filter (taps + 1,).
|
||||
|
||||
.. _`A Kaiser window approach for the design of prototype filters of cosine modulated filterbanks`:
|
||||
https://ieeexplore.ieee.org/abstract/document/681427
|
||||
|
||||
"""
|
||||
# check the arguments are valid
|
||||
assert taps % 2 == 0, 'The number of taps must be an even number.'
|
||||
assert 0.0 < cutoff_ratio < 1.0, 'Cutoff ratio must be > 0.0 and < 1.0.'
|
||||
|
||||
# make initial filter
|
||||
omega_c = np.pi * cutoff_ratio
|
||||
with np.errstate(invalid='ignore'):
|
||||
h_i = np.sin(omega_c * (np.arange(taps + 1) - 0.5 * taps)) / (
|
||||
np.pi * (np.arange(taps + 1) - 0.5 * taps))
|
||||
h_i[taps
|
||||
// 2] = np.cos(0) * cutoff_ratio # fix nan due to indeterminate form
|
||||
|
||||
# apply kaiser window
|
||||
w = kaiser(taps + 1, beta)
|
||||
h = h_i * w
|
||||
|
||||
return h
|
||||
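# Worked example for the prototype-filter design above (illustration only):
# with the defaults the function returns taps + 1 = 63 Kaiser-windowed sinc
# coefficients, which feed the PQMF analysis/synthesis banks below.
h_proto = design_prototype_filter(taps=62, cutoff_ratio=0.142, beta=9.0)
assert h_proto.shape == (63,)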
|
||||
|
||||
class PQMF(torch.nn.Module):
|
||||
"""PQMF module.
|
||||
|
||||
This module is based on `Near-perfect-reconstruction pseudo-QMF banks`_.
|
||||
|
||||
.. _`Near-perfect-reconstruction pseudo-QMF banks`:
|
||||
https://ieeexplore.ieee.org/document/258122
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, subbands=4, taps=62, cutoff_ratio=0.142, beta=9.0):
|
||||
"""Initilize PQMF module.
|
||||
|
||||
The cutoff_ratio and beta parameters are optimized for #subbands = 4.
|
||||
See discussion in https://github.com/kan-bayashi/ParallelWaveGAN/issues/195.
|
||||
|
||||
Args:
|
||||
subbands (int): The number of subbands.
|
||||
taps (int): The number of filter taps.
|
||||
cutoff_ratio (float): Cut-off frequency ratio.
|
||||
beta (float): Beta coefficient for kaiser window.
|
||||
|
||||
"""
|
||||
super(PQMF, self).__init__()
|
||||
|
||||
# build analysis & synthesis filter coefficients
|
||||
h_proto = design_prototype_filter(taps, cutoff_ratio, beta)
|
||||
h_analysis = np.zeros((subbands, len(h_proto)))
|
||||
h_synthesis = np.zeros((subbands, len(h_proto)))
|
||||
for k in range(subbands):
|
||||
h_analysis[k] = (
|
||||
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
|
||||
(np.pi / (2 * subbands)) * # noqa W504
|
||||
(np.arange(taps + 1) - (taps / 2))
|
||||
+ (-1)**k * np.pi / 4))
|
||||
h_synthesis[k] = (
|
||||
2 * h_proto * np.cos((2 * k + 1) * # noqa W504
|
||||
(np.pi / (2 * subbands)) * # noqa W504
|
||||
(np.arange(taps + 1) - (taps / 2))
|
||||
- (-1)**k * np.pi / 4))
|
||||
|
||||
# convert to tensor
|
||||
analysis_filter = torch.from_numpy(h_analysis).float().unsqueeze(1)
|
||||
synthesis_filter = torch.from_numpy(h_synthesis).float().unsqueeze(0)
|
||||
|
||||
# register coefficients as buffer
|
||||
self.register_buffer('analysis_filter', analysis_filter)
|
||||
self.register_buffer('synthesis_filter', synthesis_filter)
|
||||
|
||||
# filter for downsampling & upsampling
|
||||
updown_filter = torch.zeros((subbands, subbands, subbands)).float()
|
||||
for k in range(subbands):
|
||||
updown_filter[k, k, 0] = 1.0
|
||||
self.register_buffer('updown_filter', updown_filter)
|
||||
self.subbands = subbands
|
||||
|
||||
# keep padding info
|
||||
self.pad_fn = torch.nn.ConstantPad1d(taps // 2, 0.0)
|
||||
|
||||
def analysis(self, x):
|
||||
"""Analysis with PQMF.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, 1, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, subbands, T // subbands).
|
||||
|
||||
"""
|
||||
x = F.conv1d(self.pad_fn(x), self.analysis_filter)
|
||||
return F.conv1d(x, self.updown_filter, stride=self.subbands)
|
||||
|
||||
def synthesis(self, x):
|
||||
"""Synthesis with PQMF.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, subbands, T // subbands).
|
||||
|
||||
Returns:
|
||||
Tensor: Output tensor (B, 1, T).
|
||||
|
||||
"""
|
||||
# NOTE(kan-bayashi): Power will be decreased, so multiply by # subbands here.
|
||||
# Not sure this is the correct way, it is better to check again.
|
||||
x = F.conv_transpose1d(
|
||||
x, self.updown_filter * self.subbands, stride=self.subbands)
|
||||
return F.conv1d(self.pad_fn(x), self.synthesis_filter)
|
||||
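# Hedged round-trip sketch (illustration only): analysis splits a waveform into
# subbands at 1/subbands the rate, and synthesis reassembles it with
# near-perfect reconstruction. The input length is an assumption and should be
# divisible by the number of subbands.
pqmf = PQMF(subbands=4)
wav = torch.randn(1, 1, 2400)
bands = pqmf.analysis(wav)        # -> (1, 4, 600)
wav_hat = pqmf.synthesis(bands)   # -> (1, 1, 2400)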
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -42,7 +41,7 @@ class Prenet(nn.Module):
|
||||
self.fcs.append(nn.ReLU())
|
||||
self.fcs.append(nn.Dropout(0.5))
|
||||
|
||||
if (out_units):
|
||||
if out_units:
|
||||
self.fcs.append(nn.Linear(prenet_units[-1], out_units))
|
||||
|
||||
def forward(self, input):
|
||||
@@ -105,7 +104,7 @@ class MultiHeadSelfAttention(nn.Module):
|
||||
-1)) # b x l x (n*d)
|
||||
|
||||
output = self.dropout(self.fc(output))
|
||||
if (output.size(-1) == residual.size(-1)):
|
||||
if output.size(-1) == residual.size(-1):
|
||||
output = output + residual
|
||||
|
||||
return output, attn
|
||||
@@ -162,16 +161,18 @@ class PositionwiseConvFeedForward(nn.Module):
|
||||
class FFTBlock(nn.Module):
|
||||
"""FFT Block"""
|
||||
|
||||
def __init__(self,
|
||||
d_in,
|
||||
d_model,
|
||||
n_head,
|
||||
d_head,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout,
|
||||
dropout_attn=0.0,
|
||||
dropout_relu=0.0):
|
||||
def __init__(
|
||||
self,
|
||||
d_in,
|
||||
d_model,
|
||||
n_head,
|
||||
d_head,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout,
|
||||
dropout_attn=0.0,
|
||||
dropout_relu=0.0,
|
||||
):
|
||||
super(FFTBlock, self).__init__()
|
||||
self.slf_attn = MultiHeadSelfAttention(
|
||||
n_head,
|
||||
@@ -239,7 +240,7 @@ class MultiHeadPNCAAttention(nn.Module):
|
||||
x_k = x_k.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
|
||||
x_v = x_v.permute(2, 0, 1, 3).contiguous().view(-1, len_x, d_head)
|
||||
|
||||
if (self.x_state_size):
|
||||
if self.x_state_size:
|
||||
self.x_k = torch.cat([self.x_k, x_k], dim=1)
|
||||
self.x_v = torch.cat([self.x_v, x_v], dim=1)
|
||||
else:
|
||||
@@ -251,7 +252,7 @@ class MultiHeadPNCAAttention(nn.Module):
|
||||
return x_q, x_k, x_v
|
||||
|
||||
def update_h_state(self, h):
|
||||
if (self.h_state_size == h.size(1)):
|
||||
if self.h_state_size == h.size(1):
|
||||
return None, None
|
||||
|
||||
d_head, n_head = self.d_head, self.n_head
|
||||
@@ -323,16 +324,18 @@ class MultiHeadPNCAAttention(nn.Module):
|
||||
class PNCABlock(nn.Module):
|
||||
"""PNCA Block"""
|
||||
|
||||
def __init__(self,
|
||||
d_model,
|
||||
d_mem,
|
||||
n_head,
|
||||
d_head,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout,
|
||||
dropout_attn=0.0,
|
||||
dropout_relu=0.0):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
d_mem,
|
||||
n_head,
|
||||
d_head,
|
||||
d_inner,
|
||||
kernel_size,
|
||||
dropout,
|
||||
dropout_attn=0.0,
|
||||
dropout_relu=0.0,
|
||||
):
|
||||
super(PNCABlock, self).__init__()
|
||||
self.pnca_attn = MultiHeadPNCAAttention(
|
||||
n_head,
|
||||
@@ -1,10 +1,9 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .base import Prenet
|
||||
from . import Prenet
|
||||
from .fsmn import FsmnEncoderV2
|
||||
|
||||
|
||||
@@ -22,8 +21,8 @@ class LengthRegulator(nn.Module):
|
||||
reps_cumsum = torch.cumsum(
|
||||
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
|
||||
range_ = torch.arange(max_len).to(inputs.device)[None, :, None]
|
||||
mult = ((reps_cumsum[:, :, :-1] <= range_)
|
||||
& (reps_cumsum[:, :, 1:] > range_)) # yapf:disable
|
||||
mult = (reps_cumsum[:, :, :-1] <= range_) & (
|
||||
reps_cumsum[:, :, 1:] > range_)
|
||||
mult = mult.float()
|
||||
out = torch.matmul(mult, inputs)
|
||||
|
||||
@@ -32,7 +31,7 @@ class LengthRegulator(nn.Module):
|
||||
|
||||
seq_len = out.size(1)
|
||||
padding = self.r - int(seq_len) % self.r
|
||||
if (padding < self.r):
|
||||
if padding < self.r:
|
||||
out = F.pad(
|
||||
out.transpose(1, 2), (0, padding, 0, 0, 0, 0), value=0.0)
|
||||
out = out.transpose(1, 2)
|
||||
@@ -51,7 +50,8 @@ class VarRnnARPredictor(nn.Module):
|
||||
rnn_units,
|
||||
num_layers=2,
|
||||
batch_first=True,
|
||||
bidirectional=False)
|
||||
bidirectional=False,
|
||||
)
|
||||
self.fc = nn.Linear(rnn_units, 1)
|
||||
|
||||
def forward(self, inputs, cond, h=None, masks=None):
|
||||
@@ -89,19 +89,35 @@ class VarRnnARPredictor(nn.Module):
|
||||
|
||||
class VarFsmnRnnNARPredictor(nn.Module):
|
||||
|
||||
def __init__(self, in_dim, filter_size, fsmn_num_layers, num_memory_units,
|
||||
ffn_inner_dim, dropout, shift, lstm_units):
|
||||
def __init__(
|
||||
self,
|
||||
in_dim,
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout,
|
||||
shift,
|
||||
lstm_units,
|
||||
):
|
||||
super(VarFsmnRnnNARPredictor, self).__init__()
|
||||
|
||||
self.fsmn = FsmnEncoderV2(filter_size, fsmn_num_layers, in_dim,
|
||||
num_memory_units, ffn_inner_dim, dropout,
|
||||
shift)
|
||||
self.fsmn = FsmnEncoderV2(
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
in_dim,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout,
|
||||
shift,
|
||||
)
|
||||
self.blstm = nn.LSTM(
|
||||
num_memory_units,
|
||||
lstm_units,
|
||||
num_layers=1,
|
||||
batch_first=True,
|
||||
bidirectional=True)
|
||||
bidirectional=True,
|
||||
)
|
||||
self.fc = nn.Linear(2 * lstm_units, 1)
|
||||
|
||||
def forward(self, inputs, masks=None):
|
||||
@@ -0,0 +1,73 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numba as nb
|
||||
import numpy as np
|
||||
|
||||
|
||||
@nb.jit(nopython=True)
|
||||
def mas(attn_map, width=1):
|
||||
# assumes mel x text
|
||||
opt = np.zeros_like(attn_map)
|
||||
attn_map = np.log(attn_map)
|
||||
attn_map[0, 1:] = -np.inf
|
||||
log_p = np.zeros_like(attn_map)
|
||||
log_p[0, :] = attn_map[0, :]
|
||||
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
|
||||
for i in range(1, attn_map.shape[0]):
|
||||
for j in range(attn_map.shape[1]): # for each text dim
|
||||
prev_j = np.arange(max(0, j - width), j + 1)
|
||||
prev_log = np.array(
|
||||
[log_p[i - 1, prev_idx] for prev_idx in prev_j])
|
||||
|
||||
ind = np.argmax(prev_log)
|
||||
log_p[i, j] = attn_map[i, j] + prev_log[ind]
|
||||
prev_ind[i, j] = prev_j[ind]
|
||||
|
||||
# now backtrack
|
||||
curr_text_idx = attn_map.shape[1] - 1
|
||||
for i in range(attn_map.shape[0] - 1, -1, -1):
|
||||
opt[i, curr_text_idx] = 1
|
||||
curr_text_idx = prev_ind[i, curr_text_idx]
|
||||
opt[0, curr_text_idx] = 1
|
||||
return opt
|
||||
|
||||
|
||||
@nb.jit(nopython=True)
|
||||
def mas_width1(attn_map):
|
||||
"""mas with hardcoded width=1"""
|
||||
# assumes mel x text
|
||||
opt = np.zeros_like(attn_map)
|
||||
attn_map = np.log(attn_map)
|
||||
attn_map[0, 1:] = -np.inf
|
||||
log_p = np.zeros_like(attn_map)
|
||||
log_p[0, :] = attn_map[0, :]
|
||||
prev_ind = np.zeros_like(attn_map, dtype=np.int64)
|
||||
for i in range(1, attn_map.shape[0]):
|
||||
for j in range(attn_map.shape[1]): # for each text dim
|
||||
prev_log = log_p[i - 1, j]
|
||||
prev_j = j
|
||||
|
||||
if j - 1 >= 0 and log_p[i - 1, j - 1] >= log_p[i - 1, j]:
|
||||
prev_log = log_p[i - 1, j - 1]
|
||||
prev_j = j - 1
|
||||
|
||||
log_p[i, j] = attn_map[i, j] + prev_log
|
||||
prev_ind[i, j] = prev_j
|
||||
|
||||
# now backtrack
|
||||
curr_text_idx = attn_map.shape[1] - 1
|
||||
for i in range(attn_map.shape[0] - 1, -1, -1):
|
||||
opt[i, curr_text_idx] = 1
|
||||
curr_text_idx = prev_ind[i, curr_text_idx]
|
||||
opt[0, curr_text_idx] = 1
|
||||
return opt
|
||||
|
||||
|
||||
@nb.jit(nopython=True, parallel=True)
|
||||
def b_mas(b_attn_map, in_lens, out_lens, width=1):
|
||||
assert width == 1
|
||||
attn_out = np.zeros_like(b_attn_map)
|
||||
|
||||
for b in nb.prange(b_attn_map.shape[0]):
|
||||
out = mas_width1(b_attn_map[b, 0, :out_lens[b], :in_lens[b]])
|
||||
attn_out[b, 0, :out_lens[b], :in_lens[b]] = out
|
||||
return attn_out
|
||||
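# Hedged usage sketch (illustration only): binarizing a batched soft alignment
# with monotonic alignment search. Shapes are assumptions; the attention map is
# laid out as (batch, 1, mel_frames, text_tokens) with rows summing to one.
soft_attn = np.random.rand(1, 1, 50, 12)
soft_attn = soft_attn / soft_attn.sum(-1, keepdims=True)
hard_attn = b_mas(soft_attn, np.array([12]), np.array([50]), width=1)
# hard_attn places exactly one 1 per mel frame, forming a monotonic path.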
131
modelscope/models/audio/tts/kantts/models/sambert/attention.py
Normal file
131
modelscope/models/audio/tts/kantts/models/sambert/attention.py
Normal file
@@ -0,0 +1,131 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class ConvNorm(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=None,
|
||||
dilation=1,
|
||||
bias=True,
|
||||
w_init_gain='linear',
|
||||
):
|
||||
super(ConvNorm, self).__init__()
|
||||
if padding is None:
|
||||
assert kernel_size % 2 == 1
|
||||
padding = int(dilation * (kernel_size - 1) / 2)
|
||||
|
||||
self.conv = torch.nn.Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size=kernel_size,
|
||||
stride=stride,
|
||||
padding=padding,
|
||||
dilation=dilation,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
torch.nn.init.xavier_uniform_(
|
||||
self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))
|
||||
|
||||
def forward(self, signal):
|
||||
conv_signal = self.conv(signal)
|
||||
return conv_signal
|
||||
|
||||
|
||||
class ConvAttention(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_mel_channels=80,
|
||||
n_text_channels=512,
|
||||
n_att_channels=80,
|
||||
temperature=1.0,
|
||||
use_query_proj=True,
|
||||
):
|
||||
super(ConvAttention, self).__init__()
|
||||
self.temperature = temperature
|
||||
self.att_scaling_factor = np.sqrt(n_att_channels)
|
||||
self.softmax = torch.nn.Softmax(dim=3)
|
||||
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||
self.attn_proj = torch.nn.Conv2d(n_att_channels, 1, kernel_size=1)
|
||||
self.use_query_proj = bool(use_query_proj)
|
||||
|
||||
self.key_proj = nn.Sequential(
|
||||
ConvNorm(
|
||||
n_text_channels,
|
||||
n_text_channels * 2,
|
||||
kernel_size=3,
|
||||
bias=True,
|
||||
w_init_gain='relu',
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(
|
||||
n_text_channels * 2, n_att_channels, kernel_size=1, bias=True),
|
||||
)
|
||||
|
||||
self.query_proj = nn.Sequential(
|
||||
ConvNorm(
|
||||
n_mel_channels,
|
||||
n_mel_channels * 2,
|
||||
kernel_size=3,
|
||||
bias=True,
|
||||
w_init_gain='relu',
|
||||
),
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(
|
||||
n_mel_channels * 2, n_mel_channels, kernel_size=1, bias=True),
|
||||
torch.nn.ReLU(),
|
||||
ConvNorm(n_mel_channels, n_att_channels, kernel_size=1, bias=True),
|
||||
)
|
||||
|
||||
def forward(self, queries, keys, mask=None, attn_prior=None):
|
||||
"""Attention mechanism for flowtron parallel
|
||||
Unlike in Flowtron, we have no restrictions such as causality etc,
|
||||
since we only need this during training.
|
||||
|
||||
Args:
|
||||
queries (torch.tensor): B x C x T1 tensor
|
||||
(probably going to be mel data)
|
||||
keys (torch.tensor): B x C2 x T2 tensor (text data)
|
||||
mask (torch.tensor): uint8 binary mask for variable length entries
|
||||
(should be in the T2 domain)
|
||||
Output:
|
||||
attn (torch.tensor): B x 1 x T1 x T2 attention mask.
|
||||
Final dim T2 should sum to 1
|
||||
"""
|
||||
keys_enc = self.key_proj(keys) # B x n_attn_dims x T2
|
||||
|
||||
# Beware can only do this since query_dim = attn_dim = n_mel_channels
|
||||
if self.use_query_proj:
|
||||
queries_enc = self.query_proj(queries)
|
||||
else:
|
||||
queries_enc = queries
|
||||
|
||||
# different ways of computing attn,
|
||||
# one is isotropic gaussians (per phoneme)
|
||||
# Simplistic Gaussian Isotropic Attention
|
||||
|
||||
# B x n_attn_dims x T1 x T2
|
||||
attn = (queries_enc[:, :, :, None] - keys_enc[:, :, None])**2
|
||||
# compute log likelihood from a gaussian
|
||||
attn = -0.0005 * attn.sum(1, keepdim=True)
|
||||
if attn_prior is not None:
|
||||
attn = self.log_softmax(attn) + torch.log(attn_prior[:, None]
|
||||
+ 1e-8)
|
||||
|
||||
attn_logprob = attn.clone()
|
||||
|
||||
if mask is not None:
|
||||
attn.data.masked_fill_(
|
||||
mask.unsqueeze(1).unsqueeze(1), -float('inf'))
|
||||
|
||||
attn = self.softmax(attn) # Softmax along T2
|
||||
return attn, attn_logprob
|
||||
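# Hedged usage sketch (illustration only): aligning mel frames (queries) with
# text encodings (keys) during training. Shapes are assumptions taken from the
# docstring above; the module's own torch imports are assumed.
attention = ConvAttention(n_mel_channels=80, n_text_channels=512, n_att_channels=80)
mels = torch.randn(2, 80, 200)    # B x n_mel_channels x T1
text = torch.randn(2, 512, 40)    # B x n_text_channels x T2
attn, attn_logprob = attention(mels, text)   # attn: B x 1 x T1 x T2, softmax over T2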
@@ -1,8 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
"""
|
||||
FSMN Pytorch Version
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -27,7 +23,8 @@ class FeedForwardNet(nn.Module):
|
||||
d_out,
|
||||
kernel_size=kernel_size[1],
|
||||
padding=(kernel_size[1] - 1) // 2,
|
||||
bias=False)
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
@@ -63,8 +60,10 @@ class MemoryBlockV2(nn.Module):
|
||||
|
||||
x = F.pad(
|
||||
input, (0, 0, self.lp, self.rp, 0, 0), mode='constant', value=0.0)
|
||||
output = self.conv_dw(x.contiguous().transpose(
|
||||
1, 2)).contiguous().transpose(1, 2)
|
||||
output = (
|
||||
self.conv_dw(x.contiguous().transpose(1,
|
||||
2)).contiguous().transpose(
|
||||
1, 2))
|
||||
output += input
|
||||
output = self.dropout(output)
|
||||
|
||||
@@ -76,14 +75,16 @@ class MemoryBlockV2(nn.Module):
|
||||
|
||||
class FsmnEncoderV2(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
input_dim,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout=0.0,
|
||||
shift=0):
|
||||
def __init__(
|
||||
self,
|
||||
filter_size,
|
||||
fsmn_num_layers,
|
||||
input_dim,
|
||||
num_memory_units,
|
||||
ffn_inner_dim,
|
||||
dropout=0.0,
|
||||
shift=0,
|
||||
):
|
||||
super(FsmnEncoderV2, self).__init__()
|
||||
|
||||
self.filter_size = filter_size
|
||||
@@ -119,7 +120,7 @@ class FsmnEncoderV2(nn.Module):
|
||||
context = ffn(x)
|
||||
memory = memory_block(context, mask)
|
||||
memory = F.dropout(memory, self.dropout, self.training)
|
||||
if (memory.size(-1) == x.size(-1)):
|
||||
if memory.size(-1) == x.size(-1):
|
||||
memory += x
|
||||
x = memory
|
||||
|
||||
1032
modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py
Normal file
1032
modelscope/models/audio/tts/kantts/models/sambert/kantts_sambert.py
Normal file
File diff suppressed because it is too large
@@ -1,5 +1,4 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -15,14 +14,16 @@ class SinusoidalPositionEncoder(nn.Module):
|
||||
self.depth = depth
|
||||
self.position_enc = nn.Parameter(
|
||||
self.get_sinusoid_encoding_table(max_len, depth).unsqueeze(0),
|
||||
requires_grad=False)
|
||||
requires_grad=False,
|
||||
)
|
||||
|
||||
def forward(self, input):
|
||||
bz_in, len_in, _ = input.size()
|
||||
if len_in > self.max_len:
|
||||
self.max_len = len_in
|
||||
self.position_enc.data = self.get_sinusoid_encoding_table(
|
||||
self.max_len, self.depth).unsqueeze(0).to(input.device)
|
||||
self.position_enc.data = (
|
||||
self.get_sinusoid_encoding_table(
|
||||
self.max_len, self.depth).unsqueeze(0).to(input.device))
|
||||
|
||||
output = input + self.position_enc[:, :len_in, :].expand(bz_in, -1, -1)
|
||||
|
||||
@@ -74,8 +75,8 @@ class DurSinusoidalPositionEncoder(nn.Module):
|
||||
reps_cumsum = torch.cumsum(
|
||||
F.pad(reps.float(), (1, 0, 0, 0), value=0.0), dim=1)[:, None, :]
|
||||
range_ = torch.arange(max_len).to(durations.device)[None, :, None]
|
||||
mult = ((reps_cumsum[:, :, :-1] <= range_)
|
||||
& (reps_cumsum[:, :, 1:] > range_)) # yapf:disable
|
||||
mult = (reps_cumsum[:, :, :-1] <= range_) & (
|
||||
reps_cumsum[:, :, 1:] > range_)
|
||||
mult = mult.float()
|
||||
offsets = torch.matmul(mult,
|
||||
reps_cumsum[:,
|
||||
@@ -88,7 +89,7 @@ class DurSinusoidalPositionEncoder(nn.Module):
|
||||
|
||||
seq_len = dur_pos.size(1)
|
||||
padding = self.outputs_per_step - int(seq_len) % self.outputs_per_step
|
||||
if (padding < self.outputs_per_step):
|
||||
if padding < self.outputs_per_step:
|
||||
dur_pos = F.pad(dur_pos, (0, padding, 0, 0), value=0.0)
|
||||
|
||||
position_embedding = dur_pos[:, :, None] / self.inv_timescales[None,
|
||||
26
modelscope/models/audio/tts/kantts/models/utils.py
Normal file
26
modelscope/models/audio/tts/kantts/models/utils.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
|
||||
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
batch_size = lengths.shape[0]
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
|
||||
ids = (
|
||||
torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
|
||||
-1).to(lengths.device))
|
||||
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
|
||||
|
||||
return mask
|
||||
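# Worked example for the helper above (illustration only): True marks padded
# positions beyond each sequence's length.
lengths = torch.tensor([3, 5])
mask = get_mask_from_lengths(lengths)
# tensor([[False, False, False,  True,  True],
#         [False, False, False, False, False]])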
@@ -0,0 +1,774 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from glob import glob
|
||||
|
||||
import numpy as np
|
||||
import yaml
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .core.dsp import (load_wav, melspectrogram, save_wav, trim_silence,
|
||||
trim_silence_with_interval)
|
||||
from .core.utils import (align_length, average_by_duration, compute_mean,
|
||||
compute_std, encode_16bits, f0_norm_mean_std,
|
||||
get_energy, get_pitch, norm_mean_std,
|
||||
parse_interval_file, volume_normalize)
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
default_audio_config = {
|
||||
# Preprocess
|
||||
'wav_normalize': True,
|
||||
'trim_silence': True,
|
||||
'trim_silence_threshold_db': 60,
|
||||
'preemphasize': False,
|
||||
# Feature extraction
|
||||
'sampling_rate': 24000,
|
||||
'hop_length': 240,
|
||||
'win_length': 1024,
|
||||
'n_mels': 80,
|
||||
'n_fft': 1024,
|
||||
'fmin': 50.0,
|
||||
'fmax': 7600.0,
|
||||
'min_level_db': -100,
|
||||
'ref_level_db': 20,
|
||||
'phone_level_feature': True,
|
||||
'num_workers': 16,
|
||||
# Normalization
|
||||
'norm_type': 'mean_std', # 'mean_std', 'global norm'
|
||||
'max_norm': 1.0,
|
||||
'symmetric': False,
|
||||
}
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
|
||||
def __init__(self, config=None):
|
||||
if not isinstance(config, dict):
|
||||
logging.warning(
|
||||
'[AudioProcessor] config is not a dict, falling back to the default config.'
|
||||
)
|
||||
self.config = default_audio_config
|
||||
else:
|
||||
self.config = config
|
||||
|
||||
for key in self.config:
|
||||
setattr(self, key, self.config[key])
|
||||
|
||||
self.min_wav_length = int(self.config['sampling_rate'] * 0.5)
|
||||
|
||||
self.badcase_list = []
|
||||
self.pcm_dict = {}
|
||||
self.mel_dict = {}
|
||||
self.f0_dict = {}
|
||||
self.uv_dict = {}
|
||||
self.nccf_dict = {}
|
||||
self.f0uv_dict = {}
|
||||
self.energy_dict = {}
|
||||
self.dur_dict = {}
|
||||
logging.info('[AudioProcessor] Initialize AudioProcessor.')
|
||||
logging.info('[AudioProcessor] config params:')
|
||||
for key in self.config:
|
||||
logging.info('[AudioProcessor] %s: %s', key, self.config[key])
|
||||
|
||||
def calibrate_SyllableDuration(self, raw_dur_dir, raw_metafile,
|
||||
out_cali_duration_dir):
|
||||
with open(raw_metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
|
||||
output_dur_dir = out_cali_duration_dir
|
||||
os.makedirs(output_dur_dir, exist_ok=True)
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
index, symbols = line.split('\t')
|
||||
symbols = [
|
||||
symbol.strip('{').strip('}').split('$')[0]
|
||||
for symbol in symbols.strip().split(' ')
|
||||
]
|
||||
dur_file = os.path.join(raw_dur_dir, index + '.npy')
|
||||
phone_file = os.path.join(raw_dur_dir, index + '.phone')
|
||||
if not os.path.exists(dur_file) or not os.path.exists(phone_file):
|
||||
logging.warning(
|
||||
'[AudioProcessor] dur file or phone file not exists: %s',
|
||||
index)
|
||||
continue
|
||||
with open(phone_file, 'r') as f:
|
||||
phones = f.readlines()
|
||||
dur = np.load(dur_file)
|
||||
cali_duration = []
|
||||
|
||||
dur_idx = 0
|
||||
syll_idx = 0
|
||||
|
||||
while dur_idx < len(dur) and syll_idx < len(symbols):
|
||||
if phones[dur_idx].strip() == 'sil':
|
||||
dur_idx += 1
|
||||
continue
|
||||
|
||||
if phones[dur_idx].strip(
|
||||
) == 'sp' and symbols[syll_idx][0] != '#':
|
||||
dur_idx += 1
|
||||
continue
|
||||
|
||||
if symbols[syll_idx] in ['ga', 'go', 'ge']:
|
||||
cali_duration.append(0)
|
||||
syll_idx += 1
|
||||
# print("NONE", symbols[syll_idx], 0)
|
||||
continue
|
||||
|
||||
if symbols[syll_idx][0] == '#':
|
||||
if phones[dur_idx].strip() != 'sp':
|
||||
cali_duration.append(0)
|
||||
# print("NONE", symbols[syll_idx], 0)
|
||||
syll_idx += 1
|
||||
continue
|
||||
else:
|
||||
cali_duration.append(dur[dur_idx])
|
||||
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
|
||||
dur_idx += 1
|
||||
syll_idx += 1
|
||||
continue
|
||||
# A corresponding phone is found
|
||||
cali_duration.append(dur[dur_idx])
|
||||
# print(phones[dur_idx].strip(), symbols[syll_idx], dur[dur_idx])
|
||||
dur_idx += 1
|
||||
syll_idx += 1
|
||||
# Add #4 phone duration
|
||||
cali_duration.append(0)
|
||||
if len(cali_duration) != len(symbols):
|
||||
logging.error('[Duration Calibrating] Syllable duration {}\
|
||||
is not equal to the number of symbols {}, index: {}'.
|
||||
format(len(cali_duration), len(symbols), index))
|
||||
continue
|
||||
|
||||
# Align with mel frames
|
||||
durs = np.array(cali_duration)
|
||||
if len(self.mel_dict) > 0:
|
||||
pair_mel = self.mel_dict.get(index, None)
|
||||
if pair_mel is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Interval file %s has no corresponding mel',
|
||||
index,
|
||||
)
|
||||
continue
|
||||
mel_frames = pair_mel.shape[0]
|
||||
dur_frames = np.sum(durs)
|
||||
if np.sum(durs) > mel_frames:
|
||||
durs[-2] -= dur_frames - mel_frames
|
||||
elif np.sum(durs) < mel_frames:
|
||||
durs[-2] += mel_frames - np.sum(durs)
|
||||
|
||||
if durs[-2] < 0:
|
||||
logging.error(
|
||||
'[AudioProcessor] Duration calibrating failed for %s, mismatch frames %s',
|
||||
index,
|
||||
durs[-2],
|
||||
)
|
||||
self.badcase_list.append(index)
|
||||
continue
|
||||
|
||||
self.dur_dict[index] = durs
|
||||
|
||||
np.save(
|
||||
os.path.join(output_dur_dir, index + '.npy'),
|
||||
self.dur_dict[index])
|
||||
|
||||
def amp_normalize(self, src_wav_dir, out_wav_dir):
|
||||
if self.wav_normalize:
|
||||
logging.info('[AudioProcessor] Amplitude normalization started')
|
||||
os.makedirs(out_wav_dir, exist_ok=True)
|
||||
res = volume_normalize(src_wav_dir, out_wav_dir)
|
||||
logging.info('[AudioProcessor] Amplitude normalization finished')
|
||||
return res
|
||||
else:
|
||||
logging.info('[AudioProcessor] No amplitude normalization')
|
||||
os.symlink(src_wav_dir, out_wav_dir, target_is_directory=True)
|
||||
return True
|
||||
|
||||
def get_pcm_dict(self, src_wav_dir):
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
if len(self.pcm_dict) > 0:
|
||||
return self.pcm_dict
|
||||
|
||||
logging.info('[AudioProcessor] Start to load pcm from %s', src_wav_dir)
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_path in wav_list:
|
||||
future = executor.submit(load_wav, wav_path,
|
||||
self.sampling_rate)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
wav_name = os.path.splitext(os.path.basename(wav_path))[0]
|
||||
futures.append((future, wav_name))
|
||||
for future, wav_name in futures:
|
||||
pcm = future.result()
|
||||
if len(pcm) < self.min_wav_length:
|
||||
logging.warning('[AudioProcessor] %s is too short, skip',
|
||||
wav_name)
|
||||
self.badcase_list.append(wav_name)
|
||||
continue
|
||||
self.pcm_dict[wav_name] = pcm
|
||||
|
||||
return self.pcm_dict
|
||||
|
||||
def trim_silence_wav(self, src_wav_dir, out_wav_dir=None):
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
logging.info('[AudioProcessor] Trim silence started')
|
||||
if out_wav_dir is None:
|
||||
out_wav_dir = src_wav_dir
|
||||
else:
|
||||
os.makedirs(out_wav_dir, exist_ok=True)
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
trim_silence,
|
||||
pcm_data,
|
||||
self.trim_silence_threshold_db,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
for future, wav_basename in tqdm(futures):
|
||||
pcm = future.result()
|
||||
if len(pcm) < self.min_wav_length:
|
||||
logging.warning('[AudioProcessor] %s is too short, skip',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
self.pcm_dict.pop(wav_basename)
|
||||
continue
|
||||
self.pcm_dict[wav_basename] = pcm
|
||||
save_wav(
|
||||
self.pcm_dict[wav_basename],
|
||||
os.path.join(out_wav_dir, wav_basename + '.wav'),
|
||||
self.sampling_rate,
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Trim silence finished')
|
||||
return True
|
||||
|
||||
def trim_silence_wav_with_interval(self,
|
||||
src_wav_dir,
|
||||
dur_dir,
|
||||
out_wav_dir=None):
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
logging.info('[AudioProcessor] Trim silence with interval started')
|
||||
if out_wav_dir is None:
|
||||
out_wav_dir = src_wav_dir
|
||||
else:
|
||||
os.makedirs(out_wav_dir, exist_ok=True)
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
trim_silence_with_interval,
|
||||
pcm_data,
|
||||
self.dur_dict.get(wav_basename, None),
|
||||
self.hop_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
for future, wav_basename in tqdm(futures):
|
||||
trimed_pcm = future.result()
|
||||
if trimed_pcm is None:
|
||||
continue
|
||||
if len(trimed_pcm) < self.min_wav_length:
|
||||
logging.warning('[AudioProcessor] %s is too short, skip',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
self.pcm_dict.pop(wav_basename)
|
||||
continue
|
||||
self.pcm_dict[wav_basename] = trimed_pcm
|
||||
save_wav(
|
||||
self.pcm_dict[wav_basename],
|
||||
os.path.join(out_wav_dir, wav_basename + '.wav'),
|
||||
self.sampling_rate,
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Trim silence finished')
|
||||
return True
|
||||
|
||||
def mel_extract(self, src_wav_dir, out_feature_dir):
|
||||
os.makedirs(out_feature_dir, exist_ok=True)
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
|
||||
logging.info('[AudioProcessor] Melspec extraction started')
|
||||
|
||||
# Get global normed mel spec
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
melspectrogram,
|
||||
pcm_data,
|
||||
self.sampling_rate,
|
||||
self.n_fft,
|
||||
self.hop_length,
|
||||
self.win_length,
|
||||
self.n_mels,
|
||||
self.max_norm,
|
||||
self.min_level_db,
|
||||
self.ref_level_db,
|
||||
self.fmin,
|
||||
self.fmax,
|
||||
self.symmetric,
|
||||
self.preemphasize,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Melspec extraction failed for %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
melspec = result
|
||||
self.mel_dict[wav_basename] = melspec
|
||||
|
||||
logging.info('[AudioProcessor] Melspec extraction finished')
|
||||
|
||||
# FIXME: is this step necessary?
|
||||
# Do mean std norm on global-normed melspec
|
||||
logging.info('Melspec statistic proceeding...')
|
||||
mel_mean = compute_mean(list(self.mel_dict.values()), dims=self.n_mels)
|
||||
mel_std = compute_std(
|
||||
list(self.mel_dict.values()), mel_mean, dims=self.n_mels)
|
||||
logging.info('Melspec statistic done')
|
||||
np.savetxt(
|
||||
os.path.join(out_feature_dir, 'mel_mean.txt'),
|
||||
mel_mean,
|
||||
fmt='%.6f')
|
||||
np.savetxt(
|
||||
os.path.join(out_feature_dir, 'mel_std.txt'), mel_std, fmt='%.6f')
|
||||
logging.info(
|
||||
'[AudioProcessor] melspec mean and std saved to:\n{},\n{}'.format(
|
||||
os.path.join(out_feature_dir, 'mel_mean.txt'),
|
||||
os.path.join(out_feature_dir, 'mel_std.txt'),
|
||||
))
|
||||
|
||||
logging.info('[AudioProcessor] Melspec mean std norm is proceeding...')
|
||||
for wav_basename in self.mel_dict:
|
||||
melspec = self.mel_dict[wav_basename]
|
||||
norm_melspec = norm_mean_std(melspec, mel_mean, mel_std)
|
||||
np.save(
|
||||
os.path.join(out_feature_dir, wav_basename + '.npy'),
|
||||
norm_melspec)
|
||||
|
||||
logging.info('[AudioProcessor] Melspec normalization finished')
|
||||
logging.info('[AudioProcessor] Normed Melspec saved to %s',
|
||||
out_feature_dir)
|
||||
|
||||
return True
|
||||
|
||||
def duration_generate(self, src_interval_dir, out_feature_dir):
|
||||
os.makedirs(out_feature_dir, exist_ok=True)
|
||||
interval_list = glob(os.path.join(src_interval_dir, '*.interval'))
|
||||
|
||||
logging.info('[AudioProcessor] Duration generation started')
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(interval_list)) as progress:
|
||||
futures = []
|
||||
for interval_file_path in interval_list:
|
||||
future = executor.submit(
|
||||
parse_interval_file,
|
||||
interval_file_path,
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future,
|
||||
os.path.splitext(
|
||||
os.path.basename(interval_file_path))[0]))
|
||||
|
||||
logging.info(
|
||||
'[AudioProcessor] Duration align with mel is proceeding...')
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Duration generate failed for %s',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
durs, phone_list = result
|
||||
# Align length with melspec
|
||||
if len(self.mel_dict) > 0:
|
||||
pair_mel = self.mel_dict.get(wav_basename, None)
|
||||
if pair_mel is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Interval file %s has no corresponding mel',
|
||||
wav_basename,
|
||||
)
|
||||
continue
|
||||
mel_frames = pair_mel.shape[0]
|
||||
dur_frames = np.sum(durs)
|
||||
if np.sum(durs) > mel_frames:
|
||||
durs[-1] -= dur_frames - mel_frames
|
||||
elif np.sum(durs) < mel_frames:
|
||||
durs[-1] += mel_frames - np.sum(durs)
|
||||
|
||||
if durs[-1] < 0:
|
||||
logging.error(
|
||||
'[AudioProcessor] Duration align failed for %s, mismatch frames %s',
|
||||
wav_basename,
|
||||
durs[-1],
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
continue
|
||||
|
||||
self.dur_dict[wav_basename] = durs
|
||||
|
||||
np.save(
|
||||
os.path.join(out_feature_dir, wav_basename + '.npy'),
|
||||
durs)
|
||||
with open(
|
||||
os.path.join(out_feature_dir,
|
||||
wav_basename + '.phone'), 'w') as f:
|
||||
f.write('\n'.join(phone_list))
|
||||
logging.info('[AudioProcessor] Duration generate finished')
|
||||
|
||||
return True
|
||||
|
||||
def pitch_extract(self, src_wav_dir, out_f0_dir, out_frame_f0_dir,
|
||||
out_frame_uv_dir):
|
||||
os.makedirs(out_f0_dir, exist_ok=True)
|
||||
os.makedirs(out_frame_f0_dir, exist_ok=True)
|
||||
os.makedirs(out_frame_uv_dir, exist_ok=True)
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
mel_dict = self.mel_dict
|
||||
|
||||
logging.info('[AudioProcessor] Pitch extraction started')
|
||||
# Get raw pitch
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(
|
||||
get_pitch,
|
||||
encode_16bits(pcm_data),
|
||||
self.sampling_rate,
|
||||
self.hop_length,
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
logging.info(
|
||||
'[AudioProcessor] Pitch align with mel is proceeding...')
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Pitch extraction failed for %s',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
f0, uv, f0uv = result
|
||||
if len(mel_dict) > 0:
|
||||
f0 = align_length(f0, mel_dict.get(wav_basename, None))
|
||||
uv = align_length(uv, mel_dict.get(wav_basename, None))
|
||||
f0uv = align_length(f0uv,
|
||||
mel_dict.get(wav_basename, None))
|
||||
|
||||
if f0 is None or uv is None or f0uv is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Pitch length mismatch with mel in %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
continue
|
||||
self.f0_dict[wav_basename] = f0
|
||||
self.uv_dict[wav_basename] = uv
|
||||
self.f0uv_dict[wav_basename] = f0uv
|
||||
|
||||
# Normalize f0
|
||||
logging.info('[AudioProcessor] Pitch normalization is proceeding...')
|
||||
f0_mean = compute_mean(list(self.f0uv_dict.values()), dims=1)
|
||||
f0_std = compute_std(list(self.f0uv_dict.values()), f0_mean, dims=1)
|
||||
np.savetxt(
|
||||
os.path.join(out_f0_dir, 'f0_mean.txt'), f0_mean, fmt='%.6f')
|
||||
np.savetxt(os.path.join(out_f0_dir, 'f0_std.txt'), f0_std, fmt='%.6f')
|
||||
logging.info(
|
||||
'[AudioProcessor] f0 mean and std saved to:\n{},\n{}'.format(
|
||||
os.path.join(out_f0_dir, 'f0_mean.txt'),
|
||||
os.path.join(out_f0_dir, 'f0_std.txt'),
|
||||
))
|
||||
logging.info('[AudioProcessor] Pitch mean std norm is proceeding...')
|
||||
for wav_basename in self.f0uv_dict:
|
||||
f0 = self.f0uv_dict[wav_basename]
|
||||
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
|
||||
self.f0uv_dict[wav_basename] = norm_f0
|
||||
|
||||
for wav_basename in self.f0_dict:
|
||||
f0 = self.f0_dict[wav_basename]
|
||||
norm_f0 = f0_norm_mean_std(f0, f0_mean, f0_std)
|
||||
self.f0_dict[wav_basename] = norm_f0
|
||||
|
||||
# save frame f0 to a specific dir
|
||||
for wav_basename in self.f0_dict:
|
||||
np.save(
|
||||
os.path.join(out_frame_f0_dir, wav_basename + '.npy'),
|
||||
self.f0_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
for wav_basename in self.uv_dict:
|
||||
np.save(
|
||||
os.path.join(out_frame_uv_dir, wav_basename + '.npy'),
|
||||
self.uv_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
# phone level average
|
||||
# if there is no duration then save the frame-level f0
|
||||
if self.phone_level_feature and len(self.dur_dict) > 0:
|
||||
logging.info(
|
||||
'[AudioProcessor] Pitch turn to phone-level is proceeding...')
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(self.f0uv_dict)) as progress:
|
||||
futures = []
|
||||
for wav_basename in self.f0uv_dict:
|
||||
future = executor.submit(
|
||||
average_by_duration,
|
||||
self.f0uv_dict.get(wav_basename, None),
|
||||
self.dur_dict.get(wav_basename, None),
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Pitch extraction failed in phone level avg for: %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
avg_f0 = result
|
||||
self.f0uv_dict[wav_basename] = avg_f0
|
||||
|
||||
for wav_basename in self.f0uv_dict:
|
||||
np.save(
|
||||
os.path.join(out_f0_dir, wav_basename + '.npy'),
|
||||
self.f0uv_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Pitch normalization finished')
|
||||
logging.info('[AudioProcessor] Normed f0 saved to %s', out_f0_dir)
|
||||
logging.info('[AudioProcessor] Pitch extraction finished')
|
||||
|
||||
return True
|
||||
|
||||
def energy_extract(self, src_wav_dir, out_energy_dir,
|
||||
out_frame_energy_dir):
|
||||
os.makedirs(out_energy_dir, exist_ok=True)
|
||||
os.makedirs(out_frame_energy_dir, exist_ok=True)
|
||||
wav_list = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
pcm_dict = self.get_pcm_dict(src_wav_dir)
|
||||
mel_dict = self.mel_dict
|
||||
|
||||
logging.info('[AudioProcessor] Energy extraction started')
|
||||
# Get raw energy
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(wav_list)) as progress:
|
||||
futures = []
|
||||
for wav_basename, pcm_data in pcm_dict.items():
|
||||
future = executor.submit(get_energy, pcm_data, self.hop_length,
|
||||
self.win_length, self.n_fft)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Energy extraction failed for %s',
|
||||
wav_basename)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
energy = result
|
||||
if len(mel_dict) > 0:
|
||||
energy = align_length(energy,
|
||||
mel_dict.get(wav_basename, None))
|
||||
if energy is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Energy length mismatch with mel in %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
continue
|
||||
self.energy_dict[wav_basename] = energy
|
||||
|
||||
logging.info('Energy statistic proceeding...')
|
||||
# Normalize energy
|
||||
energy_mean = compute_mean(list(self.energy_dict.values()), dims=1)
|
||||
energy_std = compute_std(
|
||||
list(self.energy_dict.values()), energy_mean, dims=1)
|
||||
np.savetxt(
|
||||
os.path.join(out_energy_dir, 'energy_mean.txt'),
|
||||
energy_mean,
|
||||
fmt='%.6f')
|
||||
np.savetxt(
|
||||
os.path.join(out_energy_dir, 'energy_std.txt'),
|
||||
energy_std,
|
||||
fmt='%.6f')
|
||||
logging.info(
|
||||
'[AudioProcessor] energy mean and std saved to:\n{},\n{}'.format(
|
||||
os.path.join(out_energy_dir, 'energy_mean.txt'),
|
||||
os.path.join(out_energy_dir, 'energy_std.txt'),
|
||||
))
|
||||
|
||||
logging.info('[AudioProcessor] Energy mean std norm is proceeding...')
|
||||
for wav_basename in self.energy_dict:
|
||||
energy = self.energy_dict[wav_basename]
|
||||
norm_energy = f0_norm_mean_std(energy, energy_mean, energy_std)
|
||||
self.energy_dict[wav_basename] = norm_energy
|
||||
|
||||
# save frame energy to a specific dir
|
||||
for wav_basename in self.energy_dict:
|
||||
np.save(
|
||||
os.path.join(out_frame_energy_dir, wav_basename + '.npy'),
|
||||
self.energy_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
# phone level average
|
||||
# if there is no duration then save the frame-level energy
|
||||
if self.phone_level_feature and len(self.dur_dict) > 0:
|
||||
with ProcessPoolExecutor(
|
||||
max_workers=self.num_workers) as executor, tqdm(
|
||||
total=len(self.energy_dict)) as progress:
|
||||
futures = []
|
||||
for wav_basename in self.energy_dict:
|
||||
future = executor.submit(
|
||||
average_by_duration,
|
||||
self.energy_dict.get(wav_basename, None),
|
||||
self.dur_dict.get(wav_basename, None),
|
||||
)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append((future, wav_basename))
|
||||
|
||||
for future, wav_basename in futures:
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logging.warning(
|
||||
'[AudioProcessor] Energy extraction failed in phone level avg for: %s',
|
||||
wav_basename,
|
||||
)
|
||||
self.badcase_list.append(wav_basename)
|
||||
else:
|
||||
avg_energy = result
|
||||
self.energy_dict[wav_basename] = avg_energy
|
||||
|
||||
for wav_basename in self.energy_dict:
|
||||
np.save(
|
||||
os.path.join(out_energy_dir, wav_basename + '.npy'),
|
||||
self.energy_dict[wav_basename].reshape(-1),
|
||||
)
|
||||
|
||||
logging.info('[AudioProcessor] Energy normalization finished')
|
||||
logging.info('[AudioProcessor] Normed Energy saved to %s',
|
||||
out_energy_dir)
|
||||
logging.info('[AudioProcessor] Energy extraction finished')
|
||||
|
||||
return True
|
||||
|
||||
def process(self, src_voice_dir, out_data_dir, aux_metafile=None):
|
||||
succeed = True
|
||||
|
||||
raw_wav_dir = os.path.join(src_voice_dir, 'wav')
|
||||
src_interval_dir = os.path.join(src_voice_dir, 'interval')
|
||||
|
||||
out_mel_dir = os.path.join(out_data_dir, 'mel')
|
||||
out_f0_dir = os.path.join(out_data_dir, 'f0')
|
||||
out_frame_f0_dir = os.path.join(out_data_dir, 'frame_f0')
|
||||
out_frame_uv_dir = os.path.join(out_data_dir, 'frame_uv')
|
||||
out_energy_dir = os.path.join(out_data_dir, 'energy')
|
||||
out_frame_energy_dir = os.path.join(out_data_dir, 'frame_energy')
|
||||
out_duration_dir = os.path.join(out_data_dir, 'raw_duration')
|
||||
out_cali_duration_dir = os.path.join(out_data_dir, 'duration')
|
||||
|
||||
os.makedirs(out_data_dir, exist_ok=True)
|
||||
|
||||
with_duration = os.path.exists(src_interval_dir)
|
||||
train_wav_dir = os.path.join(out_data_dir, 'wav')
|
||||
|
||||
succeed = self.amp_normalize(raw_wav_dir, train_wav_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] amp_normalize failed, exit')
|
||||
return False
|
||||
|
||||
if with_duration:
|
||||
# Raw duration, non-trimmed
|
||||
succeed = self.duration_generate(src_interval_dir,
|
||||
out_duration_dir)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'[AudioProcessor] duration_generate failed, exit')
|
||||
return False
|
||||
|
||||
if self.trim_silence:
|
||||
if with_duration:
|
||||
succeed = self.trim_silence_wav_with_interval(
|
||||
train_wav_dir, out_duration_dir)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'[AudioProcessor] trim_silence_wav_with_interval failed, exit'
|
||||
)
|
||||
return False
|
||||
else:
|
||||
succeed = self.trim_silence_wav(train_wav_dir)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'[AudioProcessor] trim_silence_wav failed, exit')
|
||||
return False
|
||||
|
||||
succeed = self.mel_extract(train_wav_dir, out_mel_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] mel_extract failed, exit')
|
||||
return False
|
||||
|
||||
if aux_metafile is not None and with_duration:
|
||||
self.calibrate_SyllableDuration(out_duration_dir, aux_metafile,
|
||||
out_cali_duration_dir)
|
||||
|
||||
succeed = self.pitch_extract(train_wav_dir, out_f0_dir,
|
||||
out_frame_f0_dir, out_frame_uv_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] pitch_extract failed, exit')
|
||||
return False
|
||||
|
||||
succeed = self.energy_extract(train_wav_dir, out_energy_dir,
|
||||
out_frame_energy_dir)
|
||||
if not succeed:
|
||||
logging.error('[AudioProcessor] energy_extract failed, exit')
|
||||
return False
|
||||
|
||||
# recording badcase list
|
||||
with open(os.path.join(out_data_dir, 'badlist.txt'), 'w') as f:
|
||||
f.write('\n'.join(self.badcase_list))
|
||||
|
||||
logging.info('[AudioProcessor] All features extracted successfully!')
|
||||
|
||||
return succeed
|
||||
@@ -0,0 +1,240 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import librosa
|
||||
import librosa.filters
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
|
||||
|
||||
def _stft(y, hop_length, win_length, n_fft):
|
||||
return librosa.stft(
|
||||
y=y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
|
||||
|
||||
|
||||
def _istft(y, hop_length, win_length):
|
||||
return librosa.istft(y, hop_length=hop_length, win_length=win_length)
|
||||
|
||||
|
||||
def _db_to_amp(x):
|
||||
return np.power(10.0, x * 0.05)
|
||||
|
||||
|
||||
def _amp_to_db(x):
|
||||
return 20 * np.log10(np.maximum(1e-5, x))
|
||||
|
||||
|
||||
def load_wav(path, sr):
|
||||
return librosa.load(path, sr=sr)[0]
|
||||
|
||||
|
||||
def save_wav(wav, path, sr):
|
||||
if wav.dtype == np.float32 or wav.dtype == np.float64:
|
||||
quant_wav = 32767 * wav
|
||||
else:
|
||||
quant_wav = wav
|
||||
# maximize the volume to avoid clipping
|
||||
# wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
wavfile.write(path, sr, quant_wav.astype(np.int16))
|
||||
|
||||
|
||||
def trim_silence(wav, top_db, hop_length, win_length):
|
||||
trimmed_wav, _ = librosa.effects.trim(
|
||||
wav, top_db=top_db, frame_length=win_length, hop_length=hop_length)
|
||||
return trimmed_wav
|
||||
|
||||
|
||||
def trim_silence_with_interval(wav, interval, hop_length):
|
||||
if interval is None:
|
||||
return None
|
||||
leading_sil = interval[0]
|
||||
tailing_sil = interval[-1]
|
||||
trim_wav = wav[leading_sil * hop_length:-tailing_sil * hop_length]
|
||||
return trim_wav
|
||||
|
||||
|
||||
def preemphasis(wav, k=0.98, preemphasize=False):
|
||||
if preemphasize:
|
||||
return signal.lfilter([1, -k], [1], wav)
|
||||
return wav
|
||||
|
||||
|
||||
def inv_preemphasis(wav, k=0.98, inv_preemphasize=False):
|
||||
if inv_preemphasize:
|
||||
return signal.lfilter([1], [1, -k], wav)
|
||||
return wav
|
||||
|
||||
|
||||
def _normalize(S, max_norm=1.0, min_level_db=-100, symmetric=False):
|
||||
if symmetric:
|
||||
return np.clip(
|
||||
(2 * max_norm) * ((S - min_level_db) / (-min_level_db)) - max_norm,
|
||||
-max_norm,
|
||||
max_norm,
|
||||
)
|
||||
else:
|
||||
return np.clip(max_norm * ((S - min_level_db) / (-min_level_db)), 0,
|
||||
max_norm)
|
||||
|
||||
|
||||
def _denormalize(D, max_norm=1.0, min_level_db=-100, symmetric=False):
|
||||
if symmetric:
|
||||
return ((np.clip(D, -max_norm, max_norm) + max_norm) * -min_level_db
|
||||
/ # noqa W504
|
||||
(2 * max_norm)) + min_level_db
|
||||
else:
|
||||
return (np.clip(D, 0, max_norm) * -min_level_db
|
||||
/ max_norm) + min_level_db
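A quick sanity sketch (not part of the original file): with the defaults above, _normalize and _denormalize are inverses for dB values inside [min_level_db, 0]; the sample values below are invented for illustration.
import numpy as np

S_db = np.array([-80.0, -40.0, -3.0])  # spectrogram magnitudes in dB (toy values)
S_norm = _normalize(S_db, max_norm=1.0, min_level_db=-100, symmetric=False)
S_back = _denormalize(S_norm, max_norm=1.0, min_level_db=-100, symmetric=False)
assert np.allclose(S_db, S_back)  # round trip is lossless inside the clip range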
|
||||
|
||||
|
||||
def _griffin_lim(S, n_fft, hop_length, win_length, griffin_lim_iters=60):
|
||||
angles = np.exp(2j * np.pi * np.random.rand(*S.shape))
|
||||
S_complex = np.abs(S).astype(np.complex128)
|
||||
y = _istft(
|
||||
S_complex * angles, hop_length=hop_length, win_length=win_length)
|
||||
for i in range(griffin_lim_iters):
|
||||
angles = np.exp(1j * np.angle(
|
||||
_stft(
|
||||
y, n_fft=n_fft, hop_length=hop_length, win_length=win_length)))
|
||||
y = _istft(
|
||||
S_complex * angles, hop_length=hop_length, win_length=win_length)
|
||||
return y
|
||||
|
||||
|
||||
def spectrogram(
|
||||
y,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
symmetric=False,
|
||||
):
|
||||
D = _stft(preemphasis(y), hop_length, win_length, n_fft)
|
||||
S = _amp_to_db(np.abs(D)) - ref_level_db
|
||||
return _normalize(S, max_norm, min_level_db, symmetric)
|
||||
|
||||
|
||||
def inv_spectrogram(
|
||||
spectrogram,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
symmetric=False,
|
||||
power=1.5,
|
||||
):
|
||||
S = _db_to_amp(
|
||||
_denormalize(spectrogram, max_norm, min_level_db, symmetric)
|
||||
+ ref_level_db)
|
||||
return _griffin_lim(S**power, n_fft, hop_length, win_length)
|
||||
|
||||
|
||||
def _build_mel_basis(sample_rate, n_fft=1024, fmin=50, fmax=8000, n_mels=80):
|
||||
assert fmax <= sample_rate // 2
|
||||
return librosa.filters.mel(
|
||||
sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
|
||||
|
||||
|
||||
# mel linear Conversions
|
||||
_mel_basis = None
|
||||
_inv_mel_basis = None
|
||||
|
||||
|
||||
def _linear_to_mel(spectrogram,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
n_mels=80):
|
||||
global _mel_basis
|
||||
if _mel_basis is None:
|
||||
_mel_basis = _build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels)
|
||||
return np.dot(_mel_basis, spectrogram)
|
||||
|
||||
|
||||
def _mel_to_linear(mel_spectrogram,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
n_mels=80):
|
||||
global _inv_mel_basis
|
||||
if _inv_mel_basis is None:
|
||||
_inv_mel_basis = np.linalg.pinv(
|
||||
_build_mel_basis(sample_rate, n_fft, fmin, fmax, n_mels))
|
||||
return np.maximum(1e-10, np.dot(_inv_mel_basis, mel_spectrogram))
|
||||
|
||||
|
||||
def melspectrogram(
|
||||
y,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mels=80,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
symmetric=False,
|
||||
preemphasize=False,
|
||||
):
|
||||
D = _stft(
|
||||
preemphasis(y, preemphasize=preemphasize),
|
||||
hop_length=hop_length,
|
||||
win_length=win_length,
|
||||
n_fft=n_fft,
|
||||
)
|
||||
S = (
|
||||
_amp_to_db(
|
||||
_linear_to_mel(
|
||||
np.abs(D),
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
n_mels=n_mels,
|
||||
)) - ref_level_db)
|
||||
return _normalize(
|
||||
S, max_norm=max_norm, min_level_db=min_level_db, symmetric=symmetric).T
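A hedged usage sketch (assumption, not part of the commit): extract an 80-bin mel spectrogram from a 16 kHz mono recording using the defaults above; '/path/to/sample.wav' is a placeholder path.
y = load_wav('/path/to/sample.wav', sr=16000)  # placeholder path, not a real asset
mel = melspectrogram(y, sample_rate=16000, n_fft=1024, hop_length=256,
                     win_length=1024, n_mels=80)
print(mel.shape)  # (num_frames, 80); the result is transposed to time-major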
|
||||
|
||||
|
||||
def inv_mel_spectrogram(
|
||||
mel_spectrogram,
|
||||
sample_rate,
|
||||
n_fft=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mels=80,
|
||||
max_norm=1.0,
|
||||
min_level_db=-100,
|
||||
ref_level_db=20,
|
||||
fmin=50,
|
||||
fmax=8000,
|
||||
power=1.5,
|
||||
symmetric=False,
|
||||
preemphasize=False,
|
||||
):
|
||||
D = _denormalize(
|
||||
mel_spectrogram,
|
||||
max_norm=max_norm,
|
||||
min_level_db=min_level_db,
|
||||
symmetric=symmetric,
|
||||
)
|
||||
S = _mel_to_linear(
|
||||
_db_to_amp(D + ref_level_db),
|
||||
sample_rate=sample_rate,
|
||||
n_fft=n_fft,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
n_mels=n_mels,
|
||||
)
|
||||
return inv_preemphasis(
|
||||
_griffin_lim(S**power, n_fft, hop_length, win_length),
|
||||
preemphasize=preemphasize,
|
||||
)
|
||||
@@ -0,0 +1,480 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from concurrent.futures import ProcessPoolExecutor
|
||||
from glob import glob
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import pysptk
|
||||
import sox
|
||||
from scipy.io import wavfile
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .dsp import _stft
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
anchor_hist = np.array([
|
||||
0.0,
|
||||
0.00215827,
|
||||
0.00354383,
|
||||
0.00442313,
|
||||
0.00490274,
|
||||
0.00532907,
|
||||
0.00602185,
|
||||
0.00690115,
|
||||
0.00810019,
|
||||
0.00948574,
|
||||
0.0120437,
|
||||
0.01489475,
|
||||
0.01873168,
|
||||
0.02302158,
|
||||
0.02872369,
|
||||
0.03669065,
|
||||
0.04636291,
|
||||
0.05843325,
|
||||
0.07700506,
|
||||
0.11052491,
|
||||
0.16802558,
|
||||
0.25997868,
|
||||
0.37942979,
|
||||
0.50730083,
|
||||
0.62006395,
|
||||
0.71092459,
|
||||
0.76877165,
|
||||
0.80762057,
|
||||
0.83458566,
|
||||
0.85672795,
|
||||
0.87660538,
|
||||
0.89251266,
|
||||
0.90578204,
|
||||
0.91569411,
|
||||
0.92541966,
|
||||
0.93383959,
|
||||
0.94162004,
|
||||
0.94940048,
|
||||
0.95539568,
|
||||
0.96136424,
|
||||
0.9670397,
|
||||
0.97290168,
|
||||
0.97705835,
|
||||
0.98116174,
|
||||
0.98465228,
|
||||
0.98814282,
|
||||
0.99152678,
|
||||
0.99421796,
|
||||
0.9965894,
|
||||
0.99840128,
|
||||
1.0,
|
||||
])
|
||||
|
||||
anchor_bins = np.array([
|
||||
0.033976,
|
||||
0.03529014,
|
||||
0.03660428,
|
||||
0.03791842,
|
||||
0.03923256,
|
||||
0.0405467,
|
||||
0.04186084,
|
||||
0.04317498,
|
||||
0.04448912,
|
||||
0.04580326,
|
||||
0.0471174,
|
||||
0.04843154,
|
||||
0.04974568,
|
||||
0.05105982,
|
||||
0.05237396,
|
||||
0.0536881,
|
||||
0.05500224,
|
||||
0.05631638,
|
||||
0.05763052,
|
||||
0.05894466,
|
||||
0.0602588,
|
||||
0.06157294,
|
||||
0.06288708,
|
||||
0.06420122,
|
||||
0.06551536,
|
||||
0.0668295,
|
||||
0.06814364,
|
||||
0.06945778,
|
||||
0.07077192,
|
||||
0.07208606,
|
||||
0.0734002,
|
||||
0.07471434,
|
||||
0.07602848,
|
||||
0.07734262,
|
||||
0.07865676,
|
||||
0.0799709,
|
||||
0.08128504,
|
||||
0.08259918,
|
||||
0.08391332,
|
||||
0.08522746,
|
||||
0.0865416,
|
||||
0.08785574,
|
||||
0.08916988,
|
||||
0.09048402,
|
||||
0.09179816,
|
||||
0.0931123,
|
||||
0.09442644,
|
||||
0.09574058,
|
||||
0.09705472,
|
||||
0.09836886,
|
||||
0.099683,
|
||||
])
|
||||
|
||||
hist_bins = 50
|
||||
|
||||
|
||||
def amp_info(wav_file_path):
|
||||
"""
|
||||
Returns the amplitude info of the wav file.
|
||||
"""
|
||||
stats = sox.file_info.stat(wav_file_path)
|
||||
amp_rms = stats['RMS amplitude']
|
||||
amp_max = stats['Maximum amplitude']
|
||||
amp_mean = stats['Mean amplitude']
|
||||
length = stats['Length (seconds)']
|
||||
|
||||
return {
|
||||
'amp_rms': amp_rms,
|
||||
'amp_max': amp_max,
|
||||
'amp_mean': amp_mean,
|
||||
'length': length,
|
||||
'basename': os.path.basename(wav_file_path),
|
||||
}
|
||||
|
||||
|
||||
def statistic_amplitude(src_wav_dir):
|
||||
"""
|
||||
Returns the amplitude info of the wav file.
|
||||
"""
|
||||
wav_lst = glob(os.path.join(src_wav_dir, '*.wav'))
|
||||
with ProcessPoolExecutor(max_workers=8) as executor, tqdm(
|
||||
total=len(wav_lst)) as progress:
|
||||
futures = []
|
||||
for wav_file_path in wav_lst:
|
||||
future = executor.submit(amp_info, wav_file_path)
|
||||
future.add_done_callback(lambda p: progress.update())
|
||||
futures.append(future)
|
||||
|
||||
amp_info_lst = [future.result() for future in futures]
|
||||
|
||||
amp_info_lst = sorted(amp_info_lst, key=lambda x: x['amp_rms'])
|
||||
|
||||
logging.info('Average amplitude RMS : {}'.format(
|
||||
np.mean([x['amp_rms'] for x in amp_info_lst])))
|
||||
|
||||
return amp_info_lst
|
||||
|
||||
|
||||
def volume_normalize(src_wav_dir, out_wav_dir):
|
||||
logging.info('Volume statistic proceeding...')
|
||||
amp_info_lst = statistic_amplitude(src_wav_dir)
|
||||
logging.info('Volume statistic done.')
|
||||
|
||||
rms_amp_lst = [x['amp_rms'] for x in amp_info_lst]
|
||||
src_hist, src_bins = np.histogram(
|
||||
rms_amp_lst, bins=hist_bins, density=True)
|
||||
src_hist = src_hist / np.sum(src_hist)
|
||||
src_hist = np.cumsum(src_hist)
|
||||
src_hist = np.insert(src_hist, 0, 0.0)
|
||||
|
||||
logging.info('Volume normalization proceeding...')
|
||||
for amp_info in tqdm(amp_info_lst):
|
||||
rms_amp = amp_info['amp_rms']
|
||||
rms_amp = np.clip(rms_amp, src_bins[0], src_bins[-1])
|
||||
|
||||
src_idx = np.where(rms_amp >= src_bins)[0][-1]
|
||||
src_pos = src_hist[src_idx]
|
||||
anchor_idx = np.where(src_pos >= anchor_hist)[0][-1]
|
||||
|
||||
if src_idx == hist_bins or anchor_idx == hist_bins:
|
||||
rms_amp = anchor_bins[-1]
|
||||
else:
|
||||
rms_amp = (rms_amp - src_bins[src_idx]) / (
|
||||
src_bins[src_idx + 1] - src_bins[src_idx]) * (
|
||||
anchor_bins[anchor_idx + 1]
|
||||
- anchor_bins[anchor_idx]) + anchor_bins[anchor_idx]
|
||||
|
||||
scale = rms_amp / amp_info['amp_rms']
|
||||
|
||||
# FIXME: This is a hack to avoid sound clipping.
|
||||
sr, data = wavfile.read(
|
||||
os.path.join(src_wav_dir, amp_info['basename']))
|
||||
wavfile.write(
|
||||
os.path.join(out_wav_dir, amp_info['basename']),
|
||||
sr,
|
||||
(data * scale).astype(np.int16),
|
||||
)
|
||||
logging.info('Volume normalization done.')
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def interp_f0(f0_data):
|
||||
"""
|
||||
linear interpolation
|
||||
"""
|
||||
f0_data[f0_data < 1] = 0
|
||||
xp = np.nonzero(f0_data)
|
||||
yp = f0_data[xp]
|
||||
x = np.arange(f0_data.size)
|
||||
contour_f0 = np.interp(x, xp[0], yp).astype(np.float32)
|
||||
return contour_f0
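A toy sketch of the interpolation above (values invented for illustration): unvoiced frames (f0 == 0) are filled linearly from the neighbouring voiced frames.
import numpy as np

f0 = np.array([0.0, 100.0, 0.0, 0.0, 130.0, 0.0], dtype=np.float32)
print(interp_f0(f0.copy()))  # [100. 100. 110. 120. 130. 130.]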
|
||||
|
||||
|
||||
def frame_nccf(x, y):
|
||||
norm_coef = (np.sum(x**2.0) * np.sum(y**2.0) + 1e-30)**0.5
|
||||
return (np.sum(x * y) / norm_coef + 1.0) / 2.0
|
||||
|
||||
|
||||
def get_nccf(pcm_data, f0, min_f0=40, max_f0=800, fs=160, sr=16000):
|
||||
if pcm_data.dtype == np.int16:
|
||||
pcm_data = pcm_data.astype(np.float32) / 32768
|
||||
frame_len = int(sr / 200)
|
||||
frame_num = int(len(pcm_data) // fs)
|
||||
frame_num = min(frame_num, len(f0))
|
||||
|
||||
pad_len = int(sr / min_f0) + frame_len
|
||||
|
||||
pad_zeros = np.zeros([pad_len], dtype=np.float32)
|
||||
data = np.hstack((pad_zeros, pcm_data.astype(np.float32), pad_zeros))
|
||||
|
||||
nccf = np.zeros((frame_num), dtype=np.float32)
|
||||
|
||||
for i in range(frame_num):
|
||||
curr_f0 = np.clip(f0[i], min_f0, max_f0)
|
||||
lag = int(sr / curr_f0 + 0.5)
|
||||
j = i * fs + pad_len - frame_len // 2
|
||||
|
||||
l_data = data[j:j + frame_len]
|
||||
l_data -= l_data.mean()
|
||||
|
||||
r_data = data[j + lag:j + lag + frame_len]
|
||||
r_data -= r_data.mean()
|
||||
|
||||
nccf[i] = frame_nccf(l_data, r_data)
|
||||
|
||||
return nccf
|
||||
|
||||
|
||||
def smooth(data, win_len):
|
||||
if win_len % 2 == 0:
|
||||
win_len += 1
|
||||
hwin = win_len // 2
|
||||
win = np.hanning(win_len)
|
||||
win /= win.sum()
|
||||
data = data.reshape([-1])
|
||||
pad_data = np.pad(data, hwin, mode='edge')
|
||||
|
||||
for i in range(data.shape[0]):
|
||||
data[i] = np.dot(win, pad_data[i:i + win_len])
|
||||
|
||||
return data.reshape([-1, 1])
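A minimal sketch (toy values): smooth applies a normalized Hanning window of the given length and keeps the (N, 1) column shape expected by the callers below.
import numpy as np

contour = np.array([0.0, 10.0, 0.0, 10.0, 0.0], dtype=np.float32)
smoothed = smooth(contour.copy(), 5)
print(smoothed.shape)  # (5, 1); values are pulled towards the local weighted mean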
|
||||
|
||||
|
||||
# support: rapt, swipe
|
||||
# unsupport: reaper, world(DIO)
|
||||
def RAPT_FUNC(v1, v2, v3, v4, v5):
|
||||
return pysptk.sptk.rapt(
|
||||
v1.astype(np.float32), fs=v2, hopsize=v3, min=v4, max=v5)
|
||||
|
||||
|
||||
def SWIPE_FUNC(v1, v2, v3, v4, v5):
|
||||
return pysptk.sptk.swipe(
|
||||
v1.astype(np.float64), fs=v2, hopsize=v3, min=v4, max=v5)
|
||||
|
||||
|
||||
def PYIN_FUNC(v1, v2, v3, v4, v5):
|
||||
f0_mel = librosa.pyin(
|
||||
v1.astype(np.float32), sr=v2, frame_length=v3 * 4, fmin=v4, fmax=v5)[0]
|
||||
f0_mel = np.where(np.isnan(f0_mel), 0.0, f0_mel)
|
||||
return f0_mel
|
||||
|
||||
|
||||
def get_pitch(pcm_data, sampling_rate=16000, hop_length=160):
|
||||
log_f0_list = []
|
||||
uv_list = []
|
||||
low, high = 40, 800
|
||||
|
||||
cali_f0 = pysptk.sptk.rapt(
|
||||
pcm_data.astype(np.float32),
|
||||
fs=sampling_rate,
|
||||
hopsize=hop_length,
|
||||
min=low,
|
||||
max=high,
|
||||
)
|
||||
f0_range = np.sort(np.unique(cali_f0))
|
||||
|
||||
if len(f0_range) > 20:
|
||||
low = max(f0_range[10] - 50, low)
|
||||
high = min(f0_range[-10] + 50, high)
|
||||
|
||||
func_dict = {'rapt': RAPT_FUNC, 'swipe': SWIPE_FUNC}
|
||||
|
||||
for func_name in func_dict:
|
||||
f0 = func_dict[func_name](pcm_data, sampling_rate, hop_length, low,
|
||||
high)
|
||||
uv = f0 > 0
|
||||
|
||||
if len(f0) < 10 or f0.max() < low:
|
||||
logging.error('{} method: computed F0 is too low.'.format(func_name))
|
||||
continue
|
||||
else:
|
||||
f0 = np.clip(f0, 1e-30, high)
|
||||
log_f0 = np.log(f0)
|
||||
contour_log_f0 = interp_f0(log_f0)
|
||||
|
||||
log_f0_list.append(contour_log_f0)
|
||||
uv_list.append(uv)
|
||||
|
||||
if len(log_f0_list) == 0:
|
||||
logging.error('F0 estimation failed.')
|
||||
return None
|
||||
|
||||
min_len = float('inf')
|
||||
for log_f0 in log_f0_list:
|
||||
min_len = min(min_len, log_f0.shape[0])
|
||||
|
||||
multi_log_f0 = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
|
||||
multi_uv = np.zeros([len(log_f0_list), min_len], dtype=np.float32)
|
||||
|
||||
for i in range(len(log_f0_list)):
|
||||
multi_log_f0[i, :] = log_f0_list[i][:min_len]
|
||||
multi_uv[i, :] = uv_list[i][:min_len]
|
||||
|
||||
log_f0 = smooth(np.median(multi_log_f0, axis=0), 5)
|
||||
uv = (smooth(np.median(multi_uv, axis=0), 5) > 0.5).astype(np.float32)
|
||||
|
||||
f0 = np.exp(log_f0)
|
||||
|
||||
min_len = min(f0.shape[0], uv.shape[0])
|
||||
|
||||
return f0[:min_len], uv[:min_len], f0[:min_len] * uv[:min_len]
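A hedged usage sketch (assumption): run the multi-tracker F0 estimation on a synthetic 220 Hz tone; get_pitch returns the interpolated contour, the voiced/unvoiced flags and their product.
import numpy as np

sr, hop = 16000, 160
t = np.arange(sr) / sr  # one second of audio
tone = (0.5 * np.sin(2 * np.pi * 220.0 * t) * 32767).astype(np.int16)
f0, uv, voiced_f0 = get_pitch(tone, sampling_rate=sr, hop_length=hop)
print(float(np.median(f0[uv.reshape(-1) > 0.5])))  # should sit close to 220 Hz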
|
||||
|
||||
|
||||
def get_energy(pcm_data, hop_length, win_length, n_fft):
|
||||
D = _stft(pcm_data, hop_length, win_length, n_fft)
|
||||
S, _ = librosa.magphase(D)
|
||||
energy = np.sqrt(np.sum(S**2, axis=0))
|
||||
|
||||
return energy.reshape((-1, 1))
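A minimal sketch (toy signal): the per-frame energy is the L2 norm of each STFT column, so scaling the waveform scales the energy accordingly.
import numpy as np

sr = 16000
quiet = (0.1 * np.sin(2 * np.pi * 440.0 * np.arange(sr) / sr)).astype(np.float32)
loud = 10.0 * quiet
e_quiet = get_energy(quiet, hop_length=160, win_length=400, n_fft=512)
e_loud = get_energy(loud, hop_length=160, win_length=400, n_fft=512)
assert e_loud.mean() > e_quiet.mean()  # louder signal -> larger frame energy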
|
||||
|
||||
|
||||
def align_length(in_data, tgt_data, basename=None):
|
||||
if in_data is None or tgt_data is None:
|
||||
logging.error('{}: Input data is None.'.format(basename))
|
||||
return None
|
||||
in_len = in_data.shape[0]
|
||||
tgt_len = tgt_data.shape[0]
|
||||
if abs(in_len - tgt_len) > 20:
|
||||
logging.error(
|
||||
'{}: input data length differs from target data length by more than 20 frames.'
|
||||
.format(basename))
|
||||
return None
|
||||
|
||||
if in_len < tgt_len:
|
||||
out_data = np.pad(
|
||||
in_data, ((0, tgt_len - in_len), (0, 0)),
|
||||
'constant',
|
||||
constant_values=0.0)
|
||||
else:
|
||||
out_data = in_data[:tgt_len]
|
||||
|
||||
return out_data
|
||||
|
||||
|
||||
def compute_mean(data_list, dims=80):
|
||||
mean_vector = np.zeros((1, dims))
|
||||
all_frame_number = 0
|
||||
|
||||
for data in tqdm(data_list):
|
||||
if data is None:
|
||||
continue
|
||||
features = data.reshape((-1, dims))
|
||||
current_frame_number = np.shape(features)[0]
|
||||
mean_vector += np.sum(features[:, :], axis=0)
|
||||
all_frame_number += current_frame_number
|
||||
|
||||
mean_vector /= float(all_frame_number)
|
||||
return mean_vector
|
||||
|
||||
|
||||
def compute_std(data_list, mean_vector, dims=80):
|
||||
std_vector = np.zeros((1, dims))
|
||||
all_frame_number = 0
|
||||
|
||||
for data in tqdm(data_list):
|
||||
if data is None:
|
||||
continue
|
||||
features = data.reshape((-1, dims))
|
||||
current_frame_number = np.shape(features)[0]
|
||||
mean_matrix = np.tile(mean_vector, (current_frame_number, 1))
|
||||
std_vector += np.sum((features[:, :] - mean_matrix)**2, axis=0)
|
||||
all_frame_number += current_frame_number
|
||||
|
||||
std_vector /= float(all_frame_number)
|
||||
std_vector = std_vector**0.5
|
||||
return std_vector
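A toy sketch of the running statistics above: compute_mean and compute_std accumulate over variable-length sequences, here 1-dimensional energy tracks with invented values.
import numpy as np

tracks = [np.array([1.0, 2.0, 3.0]), np.array([3.0, 5.0])]
mean = compute_mean(tracks, dims=1)
std = compute_std(tracks, mean, dims=1)
print(mean, std)  # [[2.8]] and roughly [[1.33]] over the 5 pooled frames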
|
||||
|
||||
|
||||
def f0_norm_mean_std(x, mean, std):
|
||||
zero_idxs = np.where(x == 0.0)[0]
|
||||
x = (x - mean) / std
|
||||
x[zero_idxs] = 0.0
|
||||
return x
|
||||
|
||||
|
||||
def norm_mean_std(x, mean, std):
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
def parse_interval_file(file_path, sampling_rate, hop_length):
|
||||
with open(file_path, 'r') as f:
|
||||
lines = f.readlines()
|
||||
# second
|
||||
frame_intervals = 1.0 * hop_length / sampling_rate
|
||||
skip_lines = 12
|
||||
dur_list = []
|
||||
phone_list = []
|
||||
|
||||
line_index = skip_lines
|
||||
|
||||
while line_index < len(lines):
|
||||
phone_begin = float(lines[line_index])
|
||||
phone_end = float(lines[line_index + 1])
|
||||
phone = lines[line_index + 2].strip()[1:-1]
|
||||
dur_list.append(
|
||||
int(round((phone_end - phone_begin) / frame_intervals)))
|
||||
phone_list.append(phone)
|
||||
line_index += 3
|
||||
|
||||
if len(dur_list) == 0 or len(phone_list) == 0:
|
||||
return None
|
||||
|
||||
return np.array(dur_list), phone_list
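A hedged sketch of the interval layout this parser assumes (12 header lines, then begin/end/quoted-phone triplets); the file content below is invented and written to a hypothetical temporary path.
header = ['header line\n'] * 12          # placeholder header, skipped by the parser
body = ['0.00\n', '0.10\n', '"sil"\n',
        '0.10\n', '0.30\n', '"a1"\n']
with open('/tmp/toy.interval', 'w') as f:  # hypothetical path
    f.writelines(header + body)
durs, phones = parse_interval_file('/tmp/toy.interval', 16000, 160)
print(durs, phones)  # [10 20] ['sil', 'a1'] at a 10 ms frame shift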
|
||||
|
||||
|
||||
def average_by_duration(x, durs):
|
||||
if x is None or durs is None:
|
||||
return None
|
||||
durs_cum = np.cumsum(np.pad(durs, (1, 0), 'constant'))
|
||||
|
||||
# average over each symbol's duration
|
||||
x_symbol = np.zeros((durs.shape[0], ), dtype=np.float32)
|
||||
for idx, start, end in zip(
|
||||
range(durs.shape[0]), durs_cum[:-1], durs_cum[1:]):
|
||||
values = x[start:end][np.where(x[start:end] != 0.0)[0]]
|
||||
x_symbol[idx] = np.mean(values) if len(values) > 0 else 0.0
|
||||
|
||||
return x_symbol.astype(np.float32)
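A toy sketch (invented values): frame-level features are pooled per phone according to the duration in frames, and zero frames are excluded from the mean.
import numpy as np

frames = np.array([1.0, 3.0, 0.0, 5.0, 7.0, 9.0], dtype=np.float32)
durs = np.array([3, 3])                   # two phones, three frames each
print(average_by_duration(frames, durs))  # [2. 7.]; the zero frame is skipped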
|
||||
|
||||
|
||||
def encode_16bits(x):
|
||||
if x.min() > -1.0 and x.max() < 1.0:
|
||||
return np.clip(x * 2**15, -(2**15), 2**15 - 1).astype(np.int16)
|
||||
else:
|
||||
return x
|
||||
177
modelscope/models/audio/tts/kantts/preprocess/data_process.py
Normal file
@@ -0,0 +1,177 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import codecs
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import yaml
|
||||
|
||||
from modelscope import __version__
|
||||
from modelscope.models.audio.tts.kantts.datasets.dataset import (AmDataset,
|
||||
VocDataset)
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .audio_processor.audio_processor import AudioProcessor
|
||||
from .fp_processor import FpProcessor, is_fp_line
|
||||
from .languages import languages
|
||||
from .script_convertor.text_script_convertor import TextScriptConvertor
|
||||
|
||||
ROOT_PATH = os.path.dirname(os.path.dirname(
|
||||
os.path.abspath(__file__))) # NOQA: E402
|
||||
sys.path.insert(0, os.path.dirname(ROOT_PATH)) # NOQA: E402
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
def gen_metafile(
|
||||
voice_output_dir,
|
||||
fp_enable=False,
|
||||
badlist=None,
|
||||
split_ratio=0.98,
|
||||
):
|
||||
|
||||
voc_train_meta = os.path.join(voice_output_dir, 'train.lst')
|
||||
voc_valid_meta = os.path.join(voice_output_dir, 'valid.lst')
|
||||
if not os.path.exists(voc_train_meta) or not os.path.exists(
|
||||
voc_valid_meta):
|
||||
VocDataset.gen_metafile(
|
||||
os.path.join(voice_output_dir, 'wav'),
|
||||
voice_output_dir,
|
||||
split_ratio,
|
||||
)
|
||||
logging.info('Voc metafile generated.')
|
||||
|
||||
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
|
||||
am_train_meta = os.path.join(voice_output_dir, 'am_train.lst')
|
||||
am_valid_meta = os.path.join(voice_output_dir, 'am_valid.lst')
|
||||
if not os.path.exists(am_train_meta) or not os.path.exists(am_valid_meta):
|
||||
AmDataset.gen_metafile(
|
||||
raw_metafile,
|
||||
voice_output_dir,
|
||||
am_train_meta,
|
||||
am_valid_meta,
|
||||
badlist,
|
||||
split_ratio,
|
||||
)
|
||||
logging.info('AM metafile generated.')
|
||||
|
||||
if fp_enable:
|
||||
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
|
||||
am_train_meta = os.path.join(voice_output_dir, 'am_fpadd_train.lst')
|
||||
am_valid_meta = os.path.join(voice_output_dir, 'am_fpadd_valid.lst')
|
||||
if not os.path.exists(am_train_meta) or not os.path.exists(
|
||||
am_valid_meta):
|
||||
AmDataset.gen_metafile(
|
||||
fpadd_metafile,
|
||||
voice_output_dir,
|
||||
am_train_meta,
|
||||
am_valid_meta,
|
||||
badlist,
|
||||
split_ratio,
|
||||
)
|
||||
logging.info('AM fpaddmetafile generated.')
|
||||
|
||||
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
|
||||
am_train_meta = os.path.join(voice_output_dir, 'am_fprm_train.lst')
|
||||
am_valid_meta = os.path.join(voice_output_dir, 'am_fprm_valid.lst')
|
||||
if not os.path.exists(am_train_meta) or not os.path.exists(
|
||||
am_valid_meta):
|
||||
AmDataset.gen_metafile(
|
||||
fprm_metafile,
|
||||
voice_output_dir,
|
||||
am_train_meta,
|
||||
am_valid_meta,
|
||||
badlist,
|
||||
split_ratio,
|
||||
)
|
||||
logging.info('AM fprmmetafile generated.')
|
||||
|
||||
|
||||
def process_data(
|
||||
voice_input_dir,
|
||||
voice_output_dir,
|
||||
language_dir,
|
||||
audio_config,
|
||||
speaker_name=None,
|
||||
targetLang='PinYin',
|
||||
skip_script=False,
|
||||
):
|
||||
foreignLang = 'EnUS'
|
||||
emo_tag_path = None
|
||||
|
||||
phoneset_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['phoneset_path'])
|
||||
posset_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['posset_path'])
|
||||
f2t_map_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['f2t_map_path'])
|
||||
s2p_map_path = os.path.join(language_dir, targetLang,
|
||||
languages[targetLang]['s2p_map_path'])
|
||||
|
||||
logging.info(f'phoneset_path={phoneset_path}')
|
||||
|
||||
if speaker_name is None:
|
||||
speaker_name = os.path.basename(voice_input_dir)
|
||||
|
||||
if audio_config is not None:
|
||||
with open(audio_config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime())
|
||||
config['modelscope_version'] = __version__
|
||||
|
||||
with open(os.path.join(voice_output_dir, 'audio_config.yaml'), 'w') as f:
|
||||
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
|
||||
|
||||
if skip_script:
|
||||
logging.info('Skip script conversion')
|
||||
raw_metafile = None
|
||||
# Script processor
|
||||
if not skip_script:
|
||||
tsc = TextScriptConvertor(
|
||||
phoneset_path,
|
||||
posset_path,
|
||||
targetLang,
|
||||
foreignLang,
|
||||
f2t_map_path,
|
||||
s2p_map_path,
|
||||
emo_tag_path,
|
||||
speaker_name,
|
||||
)
|
||||
tsc.process(
|
||||
os.path.join(voice_input_dir, 'prosody', 'prosody.txt'),
|
||||
os.path.join(voice_output_dir, 'Script.xml'),
|
||||
os.path.join(voice_output_dir, 'raw_metafile.txt'),
|
||||
)
|
||||
raw_metafile = os.path.join(voice_output_dir, 'raw_metafile.txt')
|
||||
prosody = os.path.join(voice_input_dir, 'prosody', 'prosody.txt')
|
||||
|
||||
# FP processor
|
||||
with codecs.open(prosody, 'r', 'utf-8') as f:
|
||||
lines = f.readlines()
|
||||
fp_enable = is_fp_line(lines[1])
|
||||
|
||||
if fp_enable:
|
||||
FP = FpProcessor()
|
||||
|
||||
FP.process(
|
||||
voice_output_dir,
|
||||
prosody,
|
||||
raw_metafile,
|
||||
)
|
||||
logging.info('Processing fp done.')
|
||||
|
||||
# Audio processor
|
||||
ap = AudioProcessor(config['audio_config'])
|
||||
ap.process(
|
||||
voice_input_dir,
|
||||
voice_output_dir,
|
||||
raw_metafile,
|
||||
)
|
||||
|
||||
logging.info('Processing done.')
|
||||
|
||||
# Generate Voc&AM metafile
|
||||
gen_metafile(voice_output_dir, fp_enable, ap.badcase_list)
|
||||
128
modelscope/models/audio/tts/kantts/preprocess/fp_processor.py
Normal file
@@ -0,0 +1,128 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import random
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
def is_fp_line(line):
|
||||
fp_category_list = ['FP', 'I', 'N', 'Q']
|
||||
elements = line.strip().split(' ')
|
||||
res = True
|
||||
for ele in elements:
|
||||
if ele not in fp_category_list:
|
||||
res = False
|
||||
break
|
||||
return res
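A toy sketch: a prosody line counts as an fp label line only when every space-separated token is one of 'FP', 'I', 'N', 'Q'.
print(is_fp_line('N N FP N'))   # True
print(is_fp_line('ni3 hao3'))   # False, ordinary pinyin text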
|
||||
|
||||
|
||||
class FpProcessor:
|
||||
|
||||
def __init__(self):
|
||||
self.res = []
|
||||
|
||||
def addfp(self, voice_output_dir, prosody, raw_metafile_lines):
|
||||
|
||||
fp_category_list = ['FP', 'I', 'N']
|
||||
|
||||
f = open(prosody)
|
||||
prosody_lines = f.readlines()
|
||||
f.close()
|
||||
|
||||
idx = ''
|
||||
fp = ''
|
||||
fp_label_dict = {}
|
||||
for i in range(len(prosody_lines)):
|
||||
if i % 5 == 0:
|
||||
idx = prosody_lines[i].strip().split('\t')[0]
|
||||
elif i % 5 == 1: # according to prosody.txt
|
||||
fp = prosody_lines[i].strip().split('\t')[0].split(' ')
|
||||
for label in fp:
|
||||
if label not in fp_category_list:
|
||||
logging.warning('fp label not in fp_category_list')
|
||||
break
|
||||
fp_label_dict[idx] = fp
|
||||
|
||||
fpadd_metafile = os.path.join(voice_output_dir, 'fpadd_metafile.txt')
|
||||
f_out = open(fpadd_metafile, 'w')
|
||||
for line in raw_metafile_lines:
|
||||
tokens = line.strip().split('\t')
|
||||
if len(tokens) == 2:
|
||||
uttname = tokens[0]
|
||||
symbol_sequences = tokens[1].split(' ')
|
||||
|
||||
error_flag = False
|
||||
idx = 0
|
||||
out_str = uttname + '\t'
|
||||
|
||||
for this_symbol_sequence in symbol_sequences:
|
||||
emotion = this_symbol_sequence.split('$')[4]
|
||||
this_symbol_sequence = this_symbol_sequence.replace(
|
||||
emotion, 'emotion_neutral')
|
||||
|
||||
if idx < len(fp_label_dict[uttname]):
|
||||
if fp_label_dict[uttname][idx] == 'FP':
|
||||
if 'none' not in this_symbol_sequence:
|
||||
this_symbol_sequence = this_symbol_sequence.replace(
|
||||
'emotion_neutral', 'emotion_disgust')
|
||||
syllable_label = this_symbol_sequence.split('$')[2]
|
||||
if syllable_label == 's_both' or syllable_label == 's_end':
|
||||
idx += 1
|
||||
elif idx > len(fp_label_dict[uttname]):
|
||||
logging.warning(uttname + ' not match')
|
||||
error_flag = True
|
||||
out_str = out_str + this_symbol_sequence + ' '
|
||||
|
||||
if idx != len(fp_label_dict[uttname]):
|
||||
logging.warning('{}: consumed {} fp labels, expected {}'.format(
uttname, idx, len(fp_label_dict[uttname])))
|
||||
|
||||
if not error_flag:
|
||||
f_out.write(out_str.strip() + '\n')
|
||||
f_out.close()
|
||||
return fpadd_metafile
|
||||
|
||||
def removefp(self, voice_output_dir, fpadd_metafile, raw_metafile_lines):
|
||||
|
||||
f = open(fpadd_metafile)
|
||||
fpadd_metafile_lines = f.readlines()
|
||||
f.close()
|
||||
|
||||
fprm_metafile = os.path.join(voice_output_dir, 'fprm_metafile.txt')
|
||||
f_out = open(fprm_metafile, 'w')
|
||||
for i in range(len(raw_metafile_lines)):
|
||||
tokens = raw_metafile_lines[i].strip().split('\t')
|
||||
symbol_sequences = tokens[1].split(' ')
|
||||
fpadd_tokens = fpadd_metafile_lines[i].strip().split('\t')
|
||||
fpadd_symbol_sequences = fpadd_tokens[1].split(' ')
|
||||
|
||||
error_flag = False
|
||||
out_str = tokens[0] + '\t'
|
||||
idx = 0
|
||||
length = len(symbol_sequences)
|
||||
while idx < length:
|
||||
if '$emotion_disgust' in fpadd_symbol_sequences[idx]:
|
||||
if idx + 1 < length and 'none' in fpadd_symbol_sequences[
|
||||
idx + 1]:
|
||||
idx = idx + 2
|
||||
else:
|
||||
idx = idx + 1
|
||||
continue
|
||||
out_str = out_str + symbol_sequences[idx] + ' '
|
||||
idx = idx + 1
|
||||
|
||||
if not error_flag:
|
||||
f_out.write(out_str.strip() + '\n')
|
||||
f_out.close()
|
||||
|
||||
def process(self, voice_output_dir, prosody, raw_metafile):
|
||||
|
||||
with open(raw_metafile, 'r') as f:
|
||||
lines = f.readlines()
|
||||
random.shuffle(lines)
|
||||
|
||||
fpadd_metafile = self.addfp(voice_output_dir, prosody, lines)
|
||||
self.removefp(voice_output_dir, fpadd_metafile, lines)
|
||||
@@ -0,0 +1,46 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
languages = {
|
||||
'PinYin': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'ZhHK': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'WuuShanghai': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'Sichuan': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': 'En2ChPhoneMap.txt',
|
||||
's2p_map_path': 'py2phoneMap.txt',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'EnGB': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': '',
|
||||
's2p_map_path': '',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
},
|
||||
'EnUS': {
|
||||
'phoneset_path': 'PhoneSet.xml',
|
||||
'posset_path': 'PosSet.xml',
|
||||
'f2t_map_path': '',
|
||||
's2p_map_path': '',
|
||||
'tonelist_path': 'tonelist.txt',
|
||||
}
|
||||
}
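A hedged sketch mirroring the lookup done in process_data: resolve the language resources for a target language; language_dir is a placeholder directory, not a shipped path.
import os

language_dir = '/path/to/resource/languages'   # placeholder, not a real path
lang = 'PinYin'
phoneset_path = os.path.join(language_dir, lang, languages[lang]['phoneset_path'])
print(phoneset_path)  # /path/to/resource/languages/PinYin/PhoneSet.xml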
|
||||
@@ -0,0 +1,242 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class Tone(Enum):
|
||||
UnAssigned = -1
|
||||
NoneTone = 0
|
||||
YinPing = 1 # ZhHK: YinPingYinRu EnUS: primary stress
|
||||
YangPing = 2 # ZhHK: YinShang EnUS: secondary stress
|
||||
ShangSheng = 3 # ZhHK: YinQuZhongRu
|
||||
QuSheng = 4 # ZhHK: YangPing
|
||||
QingSheng = 5 # ZhHK: YangShang
|
||||
YangQuYangRu = 6 # ZhHK: YangQuYangRu
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(Tone, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['UnAssigned', '-1']:
|
||||
return Tone.UnAssigned
|
||||
elif in_str in ['NoneTone', '0']:
|
||||
return Tone.NoneTone
|
||||
elif in_str in ['YinPing', '1']:
|
||||
return Tone.YinPing
|
||||
elif in_str in ['YangPing', '2']:
|
||||
return Tone.YangPing
|
||||
elif in_str in ['ShangSheng', '3']:
|
||||
return Tone.ShangSheng
|
||||
elif in_str in ['QuSheng', '4']:
|
||||
return Tone.QuSheng
|
||||
elif in_str in ['QingSheng', '5']:
|
||||
return Tone.QingSheng
|
||||
elif in_str in ['YangQuYangRu', '6']:
|
||||
return Tone.YangQuYangRu
|
||||
else:
|
||||
return Tone.NoneTone
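A small sketch: Tone.parse accepts either the enum member name or its numeric string, and falls back to NoneTone for anything it does not recognise.
print(Tone.parse('3'))           # Tone.ShangSheng
print(Tone.parse('QuSheng'))     # Tone.QuSheng
print(Tone.parse('not-a-tone'))  # Tone.NoneTone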
|
||||
|
||||
|
||||
class BreakLevel(Enum):
|
||||
UnAssigned = -1
|
||||
L0 = 0
|
||||
L1 = 1
|
||||
L2 = 2
|
||||
L3 = 3
|
||||
L4 = 4
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(BreakLevel, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['UnAssigned', '-1']:
|
||||
return BreakLevel.UnAssigned
|
||||
elif in_str in ['L0', '0']:
|
||||
return BreakLevel.L0
|
||||
elif in_str in ['L1', '1']:
|
||||
return BreakLevel.L1
|
||||
elif in_str in ['L2', '2']:
|
||||
return BreakLevel.L2
|
||||
elif in_str in ['L3', '3']:
|
||||
return BreakLevel.L3
|
||||
elif in_str in ['L4', '4']:
|
||||
return BreakLevel.L4
|
||||
else:
|
||||
return BreakLevel.UnAssigned
|
||||
|
||||
|
||||
class SentencePurpose(Enum):
|
||||
Declarative = 0
|
||||
Interrogative = 1
|
||||
Exclamatory = 2
|
||||
Imperative = 3
|
||||
|
||||
|
||||
class Language(Enum):
|
||||
Neutral = 0
|
||||
EnUS = 1033
|
||||
EnGB = 2057
|
||||
ZhCN = 2052
|
||||
PinYin = 2053
|
||||
WuuShanghai = 2054
|
||||
Sichuan = 2055
|
||||
ZhHK = 3076
|
||||
ZhEn = ZhCN | EnUS
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(Language, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['Neutral', '0']:
|
||||
return Language.Neutral
|
||||
elif in_str in ['EnUS', '1033']:
|
||||
return Language.EnUS
|
||||
elif in_str in ['EnGB', '2057']:
|
||||
return Language.EnGB
|
||||
elif in_str in ['ZhCN', '2052']:
|
||||
return Language.ZhCN
|
||||
elif in_str in ['PinYin', '2053']:
|
||||
return Language.PinYin
|
||||
elif in_str in ['WuuShanghai', '2054']:
|
||||
return Language.WuuShanghai
|
||||
elif in_str in ['Sichuan', '2055']:
|
||||
return Language.Sichuan
|
||||
elif in_str in ['ZhHK', '3076']:
|
||||
return Language.ZhHK
|
||||
elif in_str in ['ZhEn', '2052|1033']:
|
||||
return Language.ZhEn
|
||||
else:
|
||||
return Language.Neutral
|
||||
|
||||
|
||||
"""
|
||||
Phone Types
|
||||
"""
|
||||
|
||||
|
||||
class PhoneCVType(Enum):
|
||||
NULL = -1
|
||||
Consonant = 1
|
||||
Vowel = 2
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneCVType, cls).__new__(cls, in_str)
|
||||
|
||||
if in_str in ['consonant', 'Consonant']:
|
||||
return PhoneCVType.Consonant
|
||||
elif in_str in ['vowel', 'Vowel']:
|
||||
return PhoneCVType.Vowel
|
||||
else:
|
||||
return PhoneCVType.NULL
|
||||
|
||||
|
||||
class PhoneIFType(Enum):
|
||||
NULL = -1
|
||||
Initial = 1
|
||||
Final = 2
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneIFType, cls).__new__(cls, in_str)
|
||||
if in_str in ['initial', 'Initial']:
|
||||
return PhoneIFType.Initial
|
||||
elif in_str in ['final', 'Final']:
|
||||
return PhoneIFType.Final
|
||||
else:
|
||||
return PhoneIFType.NULL
|
||||
|
||||
|
||||
class PhoneUVType(Enum):
|
||||
NULL = -1
|
||||
Voiced = 1
|
||||
UnVoiced = 2
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneUVType, cls).__new__(cls, in_str)
|
||||
if in_str in ['voiced', 'Voiced']:
|
||||
return PhoneUVType.Voiced
|
||||
elif in_str in ['unvoiced', 'UnVoiced']:
|
||||
return PhoneUVType.UnVoiced
|
||||
else:
|
||||
return PhoneUVType.NULL
|
||||
|
||||
|
||||
class PhoneAPType(Enum):
|
||||
NULL = -1
|
||||
DoubleLips = 1
|
||||
LipTooth = 2
|
||||
FrontTongue = 3
|
||||
CentralTongue = 4
|
||||
BackTongue = 5
|
||||
Dorsal = 6
|
||||
Velar = 7
|
||||
Low = 8
|
||||
Middle = 9
|
||||
High = 10
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneAPType, cls).__new__(cls, in_str)
|
||||
if in_str in ['doublelips', 'DoubleLips']:
|
||||
return PhoneAPType.DoubleLips
|
||||
elif in_str in ['liptooth', 'LipTooth']:
|
||||
return PhoneAPType.LipTooth
|
||||
elif in_str in ['fronttongue', 'FrontTongue']:
|
||||
return PhoneAPType.FrontTongue
|
||||
elif in_str in ['centraltongue', 'CentralTongue']:
|
||||
return PhoneAPType.CentralTongue
|
||||
elif in_str in ['backtongue', 'BackTongue']:
|
||||
return PhoneAPType.BackTongue
|
||||
elif in_str in ['dorsal', 'Dorsal']:
|
||||
return PhoneAPType.Dorsal
|
||||
elif in_str in ['velar', 'Velar']:
|
||||
return PhoneAPType.Velar
|
||||
elif in_str in ['low', 'Low']:
|
||||
return PhoneAPType.Low
|
||||
elif in_str in ['middle', 'Middle']:
|
||||
return PhoneAPType.Middle
|
||||
elif in_str in ['high', 'High']:
|
||||
return PhoneAPType.High
|
||||
else:
|
||||
return PhoneAPType.NULL
|
||||
|
||||
|
||||
class PhoneAMType(Enum):
|
||||
NULL = -1
|
||||
Stop = 1
|
||||
Affricate = 2
|
||||
Fricative = 3
|
||||
Nasal = 4
|
||||
Lateral = 5
|
||||
Open = 6
|
||||
Close = 7
|
||||
|
||||
@classmethod
|
||||
def parse(cls, in_str):
|
||||
if not isinstance(in_str, str):
|
||||
return super(PhoneAMType, cls).__new__(cls, in_str)
|
||||
if in_str in ['stop', 'Stop']:
|
||||
return PhoneAMType.Stop
|
||||
elif in_str in ['affricate', 'Affricate']:
|
||||
return PhoneAMType.Affricate
|
||||
elif in_str in ['fricative', 'Fricative']:
|
||||
return PhoneAMType.Fricative
|
||||
elif in_str in ['nasal', 'Nasal']:
|
||||
return PhoneAMType.Nasal
|
||||
elif in_str in ['lateral', 'Lateral']:
|
||||
return PhoneAMType.Lateral
|
||||
elif in_str in ['open', 'Open']:
|
||||
return PhoneAMType.Open
|
||||
elif in_str in ['close', 'Close']:
|
||||
return PhoneAMType.Close
|
||||
else:
|
||||
return PhoneAMType.NULL
|
||||
@@ -0,0 +1,48 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .core_types import (PhoneAMType, PhoneAPType, PhoneCVType, PhoneIFType,
|
||||
PhoneUVType)
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Phone(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_id = None
|
||||
self.m_name = None
|
||||
self.m_cv_type = PhoneCVType.NULL
|
||||
self.m_if_type = PhoneIFType.NULL
|
||||
self.m_uv_type = PhoneUVType.NULL
|
||||
self.m_ap_type = PhoneAPType.NULL
|
||||
self.m_am_type = PhoneAMType.NULL
|
||||
self.m_bnd = False
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self, phone_node):
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
id_node = phone_node.find(ns + 'id')
|
||||
self.m_id = int(id_node.text)
|
||||
|
||||
name_node = phone_node.find(ns + 'name')
|
||||
self.m_name = name_node.text
|
||||
|
||||
cv_node = phone_node.find(ns + 'cv')
|
||||
self.m_cv_type = PhoneCVType.parse(cv_node.text)
|
||||
|
||||
if_node = phone_node.find(ns + 'if')
|
||||
self.m_if_type = PhoneIFType.parse(if_node.text)
|
||||
|
||||
uv_node = phone_node.find(ns + 'uv')
|
||||
self.m_uv_type = PhoneUVType.parse(uv_node.text)
|
||||
|
||||
ap_node = phone_node.find(ns + 'ap')
|
||||
self.m_ap_type = PhoneAPType.parse(ap_node.text)
|
||||
|
||||
am_node = phone_node.find(ns + 'am')
|
||||
self.m_am_type = PhoneAMType.parse(am_node.text)
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .phone import Phone
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class PhoneSet(XmlObj):
|
||||
|
||||
def __init__(self, phoneset_path):
|
||||
self.m_phone_list = []
|
||||
self.m_id_map = {}
|
||||
self.m_name_map = {}
|
||||
self.load(phoneset_path)
|
||||
|
||||
def load(self, file_path):
|
||||
# alibaba tts xml namespace
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
phoneset_root = ET.parse(file_path).getroot()
|
||||
for phone_node in phoneset_root.findall(ns + 'phone'):
|
||||
phone = Phone()
|
||||
phone.load(phone_node)
|
||||
self.m_phone_list.append(phone)
|
||||
if phone.m_id in self.m_id_map:
|
||||
logging.error('PhoneSet.Load: duplicate id: %d', phone.m_id)
|
||||
self.m_id_map[phone.m_id] = phone
|
||||
|
||||
if phone.m_name in self.m_name_map:
|
||||
logging.error('PhoneSet.Load: duplicate name: %s',
|
||||
phone.m_name)
|
||||
self.m_name_map[phone.m_name] = phone
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Pos(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_id = None
|
||||
self.m_name = None
|
||||
self.m_desc = None
|
||||
self.m_level = 1
|
||||
self.m_parent = None
|
||||
self.m_sub_pos_list = []
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self, pos_node):
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
id_node = pos_node.find(ns + 'id')
|
||||
self.m_id = int(id_node.text)
|
||||
|
||||
name_node = pos_node.find(ns + 'name')
|
||||
self.m_name = name_node.text
|
||||
|
||||
desc_node = pos_node.find(ns + 'desc')
|
||||
self.m_desc = desc_node.text
|
||||
|
||||
sub_node = pos_node.find(ns + 'sub')
|
||||
if sub_node is not None:
|
||||
for sub_pos_node in sub_node.findall(ns + 'pos'):
|
||||
sub_pos = Pos()
|
||||
sub_pos.load(sub_pos_node)
|
||||
sub_pos.m_parent = self
|
||||
sub_pos.m_level = self.m_level + 1
|
||||
self.m_sub_pos_list.append(sub_pos)
|
||||
|
||||
return
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import logging
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .pos import Pos
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class PosSet(XmlObj):
|
||||
|
||||
def __init__(self, posset_path):
|
||||
self.m_pos_list = []
|
||||
self.m_id_map = {}
|
||||
self.m_name_map = {}
|
||||
self.load(posset_path)
|
||||
|
||||
def load(self, file_path):
|
||||
# alibaba tts xml namespace
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
posset_root = ET.parse(file_path).getroot()
|
||||
for pos_node in posset_root.findall(ns + 'pos'):
|
||||
pos = Pos()
|
||||
pos.load(pos_node)
|
||||
self.m_pos_list.append(pos)
|
||||
if pos.m_id in self.m_id_map:
|
||||
logging.error('PosSet.Load: duplicate id: %d', pos.m_id)
|
||||
self.m_id_map[pos.m_id] = pos
|
||||
|
||||
if pos.m_name in self.m_name_map:
|
||||
logging.error('PosSet.Load: duplicate name: %s',
|
||||
pos.m_name)
|
||||
self.m_name_map[pos.m_name] = pos
|
||||
|
||||
if len(pos.m_sub_pos_list) > 0:
|
||||
for sub_pos in pos.m_sub_pos_list:
|
||||
self.m_pos_list.append(sub_pos)
|
||||
if sub_pos.m_id in self.m_id_map:
|
||||
logging.error('PosSet.Load: duplicate id: %d',
|
||||
sub_pos.m_id)
|
||||
self.m_id_map[sub_pos.m_id] = sub_pos
|
||||
|
||||
if sub_pos.m_name in self.m_name_map:
|
||||
logging.error('PosSet.Load: duplicate name: %s',
|
||||
sub_pos.m_name)
|
||||
self.m_name_map[sub_pos.m_name] = sub_pos
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
@@ -0,0 +1,35 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
from xml.dom import minidom
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Script(XmlObj):
|
||||
|
||||
def __init__(self, phoneset, posset):
|
||||
self.m_phoneset = phoneset
|
||||
self.m_posset = posset
|
||||
self.m_items = []
|
||||
|
||||
def save(self, outputXMLPath):
|
||||
root = ET.Element('script')
|
||||
|
||||
root.set('uttcount', str(len(self.m_items)))
|
||||
root.set('xmlns', 'http://schemas.alibaba-inc.com/tts')
|
||||
for item in self.m_items:
|
||||
item.save(root)
|
||||
|
||||
xmlstr = minidom.parseString(ET.tostring(root)).toprettyxml(
|
||||
indent=' ', encoding='utf-8')
|
||||
with open(outputXMLPath, 'wb') as f:
|
||||
f.write(xmlstr)
|
||||
|
||||
def save_meta_file(self):
|
||||
meta_lines = []
|
||||
|
||||
for item in self.m_items:
|
||||
meta_lines.append(item.save_metafile())
|
||||
|
||||
return meta_lines
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class ScriptItem(XmlObj):
|
||||
|
||||
def __init__(self, phoneset, posset):
|
||||
if phoneset is None or posset is None:
|
||||
raise Exception('ScriptItem.__init__: phoneset or posset is None')
|
||||
self.m_phoneset = phoneset
|
||||
self.m_posset = posset
|
||||
|
||||
self.m_id = None
|
||||
self.m_text = ''
|
||||
self.m_scriptSentence_list = []
|
||||
self.m_status = None
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self, parent_node):
|
||||
utterance_node = ET.SubElement(parent_node, 'utterance')
|
||||
utterance_node.set('id', self.m_id)
|
||||
|
||||
text_node = ET.SubElement(utterance_node, 'text')
|
||||
text_node.text = self.m_text
|
||||
|
||||
for sentence in self.m_scriptSentence_list:
|
||||
sentence.save(utterance_node)
|
||||
|
||||
def save_metafile(self):
|
||||
meta_line = self.m_id + '\t'
|
||||
|
||||
for sentence in self.m_scriptSentence_list:
|
||||
meta_line += sentence.save_metafile()
|
||||
|
||||
return meta_line
|
||||
@@ -0,0 +1,185 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class WrittenSentence(XmlObj):
|
||||
|
||||
def __init__(self, posset):
|
||||
self.m_written_word_list = []
|
||||
self.m_written_mark_list = []
|
||||
self.m_posset = posset
|
||||
self.m_align_list = []
|
||||
self.m_alignCursor = 0
|
||||
self.m_accompanyIndex = 0
|
||||
self.m_sequence = ''
|
||||
self.m_text = ''
|
||||
|
||||
def add_host(self, writtenWord):
|
||||
self.m_written_word_list.append(writtenWord)
|
||||
self.m_align_list.append(self.m_alignCursor)
|
||||
|
||||
def load_host(self):
|
||||
pass
|
||||
|
||||
def save_host(self):
|
||||
pass
|
||||
|
||||
def add_accompany(self, writtenMark):
|
||||
self.m_written_mark_list.append(writtenMark)
|
||||
self.m_alignCursor += 1
|
||||
self.m_accompanyIndex += 1
|
||||
|
||||
def save_accompany(self):
|
||||
pass
|
||||
|
||||
def load_accompany(self):
|
||||
pass
|
||||
|
||||
# Get the mark span corresponding to specific spoken word
|
||||
def get_accompany_span(self, host_index):
|
||||
if host_index == -1:
|
||||
return (0, self.m_align_list[0])
|
||||
|
||||
accompany_begin = self.m_align_list[host_index]
|
||||
accompany_end = (
|
||||
self.m_align_list[host_index + 1]
|
||||
if host_index + 1 < len(self.m_written_word_list) else len(
|
||||
self.m_written_mark_list))
|
||||
|
||||
return (accompany_begin, accompany_end)
|
||||
|
||||
def get_elements(self):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(-1)
|
||||
res_lst = [
|
||||
self.m_written_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
]
|
||||
|
||||
for j in range(len(self.m_written_word_list)):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(j)
|
||||
res_lst.extend([self.m_written_word_list[j]])
|
||||
res_lst.extend([
|
||||
self.m_written_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
])
|
||||
|
||||
return res_lst
|
||||
|
||||
def build_sequence(self):
|
||||
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
def build_text(self):
|
||||
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
|
||||
class SpokenSentence(XmlObj):
|
||||
|
||||
def __init__(self, phoneset):
|
||||
self.m_spoken_word_list = []
|
||||
self.m_spoken_mark_list = []
|
||||
self.m_phoneset = phoneset
|
||||
self.m_align_list = []
|
||||
self.m_alignCursor = 0
|
||||
self.m_accompanyIndex = 0
|
||||
self.m_sequence = ''
|
||||
self.m_text = ''
|
||||
|
||||
def __len__(self):
|
||||
return len(self.m_spoken_word_list)
|
||||
|
||||
def add_host(self, spokenWord):
|
||||
self.m_spoken_word_list.append(spokenWord)
|
||||
self.m_align_list.append(self.m_alignCursor)
|
||||
|
||||
def save_host(self):
|
||||
pass
|
||||
|
||||
def load_host(self):
|
||||
pass
|
||||
|
||||
def add_accompany(self, spokenMark):
|
||||
self.m_spoken_mark_list.append(spokenMark)
|
||||
self.m_alignCursor += 1
|
||||
self.m_accompanyIndex += 1
|
||||
|
||||
def save_accompany(self):
|
||||
pass
|
||||
|
||||
# Get the mark span corresponding to specific spoken word
|
||||
def get_accompany_span(self, host_index):
|
||||
if host_index == -1:
|
||||
return (0, self.m_align_list[0])
|
||||
|
||||
accompany_begin = self.m_align_list[host_index]
|
||||
accompany_end = (
|
||||
self.m_align_list[host_index + 1]
|
||||
if host_index + 1 < len(self.m_spoken_word_list) else len(
|
||||
self.m_spoken_mark_list))
|
||||
|
||||
return (accompany_begin, accompany_end)
|
||||
|
||||
def get_elements(self):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(-1)
|
||||
res_lst = [
|
||||
self.m_spoken_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
]
|
||||
|
||||
for j in range(len(self.m_spoken_word_list)):
|
||||
accompany_begin, accompany_end = self.get_accompany_span(j)
|
||||
res_lst.extend([self.m_spoken_word_list[j]])
|
||||
res_lst.extend([
|
||||
self.m_spoken_mark_list[i]
|
||||
for i in range(accompany_begin, accompany_end)
|
||||
])
|
||||
|
||||
return res_lst
|
||||
|
||||
def load_accompany(self):
|
||||
pass
|
||||
|
||||
def build_sequence(self):
|
||||
self.m_sequence = ' '.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
def build_text(self):
|
||||
self.m_text = ''.join([str(ele) for ele in self.get_elements()])
|
||||
|
||||
def save(self, parent_node):
|
||||
spoken_node = ET.SubElement(parent_node, 'spoken')
|
||||
spoken_node.set('wordcount', str(len(self.m_spoken_word_list)))
|
||||
|
||||
text_node = ET.SubElement(spoken_node, 'text')
|
||||
text_node.text = self.m_sequence
|
||||
|
||||
for word in self.m_spoken_word_list:
|
||||
word.save(spoken_node)
|
||||
|
||||
def save_metafile(self):
|
||||
meta_line_list = [
|
||||
word.save_metafile() for word in self.m_spoken_word_list
|
||||
]
|
||||
|
||||
return ' '.join(meta_line_list)
|
||||
|
||||
|
||||
class ScriptSentence(XmlObj):
|
||||
|
||||
def __init__(self, phoneset, posset):
|
||||
self.m_phoneset = phoneset
|
||||
self.m_posset = posset
|
||||
self.m_writtenSentence = WrittenSentence(posset)
|
||||
self.m_spokenSentence = SpokenSentence(phoneset)
|
||||
self.m_text = ''
|
||||
|
||||
def save(self, parent_node):
|
||||
if len(self.m_spokenSentence) > 0:
|
||||
self.m_spokenSentence.save(parent_node)
|
||||
|
||||
def save_metafile(self):
|
||||
if len(self.m_spokenSentence) > 0:
|
||||
return self.m_spokenSentence.save_metafile()
|
||||
else:
|
||||
return ''
|
||||
@@ -0,0 +1,120 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .core_types import Language
|
||||
from .syllable import SyllableList
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class WrittenWord(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_name = None
|
||||
self.m_POS = None
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
|
||||
class WrittenMark(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_punctuation = None
|
||||
|
||||
def __str__(self):
|
||||
return self.m_punctuation
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
|
||||
class SpokenWord(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_name = None
|
||||
self.m_language = None
|
||||
self.m_syllable_list = []
|
||||
self.m_breakText = '1'
|
||||
self.m_POS = '0'
|
||||
|
||||
def __str__(self):
|
||||
return self.m_name
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self, parent_node):
|
||||
|
||||
word_node = ET.SubElement(parent_node, 'word')
|
||||
|
||||
name_node = ET.SubElement(word_node, 'name')
|
||||
name_node.text = self.m_name
|
||||
|
||||
if (len(self.m_syllable_list) > 0
|
||||
and self.m_syllable_list[0].m_language != Language.Neutral):
|
||||
language_node = ET.SubElement(word_node, 'lang')
|
||||
language_node.text = self.m_syllable_list[0].m_language.name
|
||||
|
||||
SyllableList(self.m_syllable_list).save(word_node)
|
||||
|
||||
break_node = ET.SubElement(word_node, 'break')
|
||||
break_node.text = self.m_breakText
|
||||
|
||||
POS_node = ET.SubElement(word_node, 'POS')
|
||||
POS_node.text = self.m_POS
|
||||
|
||||
return
|
||||
|
||||
def save_metafile(self):
|
||||
word_phone_cnt = sum(
|
||||
[syllable.phone_count() for syllable in self.m_syllable_list])
|
||||
word_syllable_cnt = len(self.m_syllable_list)
|
||||
single_syllable_word = word_syllable_cnt == 1
|
||||
meta_line_list = []
|
||||
|
||||
for idx, syll in enumerate(self.m_syllable_list):
|
||||
if word_phone_cnt == 1:
|
||||
word_pos = 'word_both'
|
||||
elif idx == 0:
|
||||
word_pos = 'word_begin'
|
||||
elif idx == len(self.m_syllable_list) - 1:
|
||||
word_pos = 'word_end'
|
||||
else:
|
||||
word_pos = 'word_middle'
|
||||
meta_line_list.append(
|
||||
syll.save_metafile(
|
||||
word_pos, single_syllable_word=single_syllable_word))
|
||||
|
||||
if self.m_breakText != '0' and self.m_breakText is not None:
|
||||
meta_line_list.append('{{#{}$tone_none$s_none$word_none}}'.format(
|
||||
self.m_breakText))
|
||||
|
||||
return ' '.join(meta_line_list)
|
||||
|
||||
|
||||
class SpokenMark(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_breakLevel = None
|
||||
|
||||
def break_level2text(self):
|
||||
return '#' + str(self.m_breakLevel.value)
|
||||
|
||||
def __str__(self):
|
||||
return self.break_level2text()
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from .xml_obj import XmlObj
|
||||
|
||||
|
||||
class Syllable(XmlObj):
|
||||
|
||||
def __init__(self):
|
||||
self.m_phone_list = []
|
||||
self.m_tone = None
|
||||
self.m_language = None
|
||||
self.m_breaklevel = None
|
||||
|
||||
def pronunciation_text(self):
|
||||
return ' '.join([str(phone) for phone in self.m_phone_list])
|
||||
|
||||
def phone_count(self):
|
||||
return len(self.m_phone_list)
|
||||
|
||||
def tone_text(self):
|
||||
return str(self.m_tone.value)
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def get_phone_meta(self,
|
||||
phone_name,
|
||||
word_pos,
|
||||
syll_pos,
|
||||
tone_text,
|
||||
single_syllable_word=False):
|
||||
# Special case: for a single-syllable word, the last phone's word_pos should be "word_end"
|
||||
if word_pos == 'word_begin' and syll_pos == 's_end' and single_syllable_word:
|
||||
word_pos = 'word_end'
|
||||
elif word_pos == 'word_begin' and syll_pos not in [
|
||||
's_begin',
|
||||
's_both',
|
||||
]: # FIXME: keep accord with Engine logic
|
||||
word_pos = 'word_middle'
|
||||
elif word_pos == 'word_end' and syll_pos not in ['s_end', 's_both']:
|
||||
word_pos = 'word_middle'
|
||||
else:
|
||||
pass
|
||||
|
||||
return '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos,
|
||||
word_pos)
|
||||
|
||||
def save_metafile(self, word_pos, single_syllable_word=False):
|
||||
syllable_phone_cnt = len(self.m_phone_list)
|
||||
|
||||
meta_line_list = []
|
||||
|
||||
for idx, phone in enumerate(self.m_phone_list):
|
||||
if syllable_phone_cnt == 1:
|
||||
syll_pos = 's_both'
|
||||
elif idx == 0:
|
||||
syll_pos = 's_begin'
|
||||
elif idx == len(self.m_phone_list) - 1:
|
||||
syll_pos = 's_end'
|
||||
else:
|
||||
syll_pos = 's_middle'
|
||||
meta_line_list.append(
|
||||
self.get_phone_meta(
|
||||
phone,
|
||||
word_pos,
|
||||
syll_pos,
|
||||
self.tone_text(),
|
||||
single_syllable_word=single_syllable_word,
|
||||
))
|
||||
|
||||
return ' '.join(meta_line_list)
|
||||
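For reference, the phone meta token built by get_phone_meta above is a single brace-delimited string. The following minimal sketch (the phone, tone and position values are made up, not taken from a real phone set) shows the exact layout produced by the format string:

# Minimal sketch of the token layout produced by get_phone_meta (hypothetical values):
phone_name, tone_text, syll_pos, word_pos = 'a', '4', 's_begin', 'word_begin'
token = '{{{}$tone{}${}${}}}'.format(phone_name, tone_text, syll_pos, word_pos)
assert token == '{a$tone4$s_begin$word_begin}'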
|
||||
|
||||
class SyllableList(XmlObj):
|
||||
|
||||
def __init__(self, syllables):
|
||||
self.m_syllable_list = syllables
|
||||
|
||||
def __len__(self):
|
||||
return len(self.m_syllable_list)
|
||||
|
||||
def __getitem__(self, index):
|
||||
return self.m_syllable_list[index]
|
||||
|
||||
def pronunciation_text(self):
|
||||
return ' - '.join([
|
||||
syllable.pronunciation_text() for syllable in self.m_syllable_list
|
||||
])
|
||||
|
||||
def tone_text(self):
|
||||
return ''.join(
|
||||
[syllable.tone_text() for syllable in self.m_syllable_list])
|
||||
|
||||
def save(self, parent_node):
|
||||
syllable_node = ET.SubElement(parent_node, 'syllable')
|
||||
syllable_node.set('syllcount', str(len(self.m_syllable_list)))
|
||||
|
||||
phone_node = ET.SubElement(syllable_node, 'phone')
|
||||
phone_node.text = self.pronunciation_text()
|
||||
|
||||
tone_node = ET.SubElement(syllable_node, 'tone')
|
||||
tone_node.text = self.tone_text()
|
||||
|
||||
return
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
@@ -0,0 +1,322 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import re
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .core_types import Language, PhoneCVType, Tone
|
||||
from .syllable import Syllable
|
||||
from .utils import NgBreakPattern
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class DefaultSyllableFormatter:
|
||||
|
||||
def __init__(self):
|
||||
return
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
logging.warning('Using DefaultSyllableFormatter dry run: %s', pronText)
|
||||
return True
|
||||
|
||||
|
||||
RegexNg2en = re.compile(NgBreakPattern)
|
||||
RegexQingSheng = re.compile(r'([1-5]5)')
|
||||
RegexPron = re.compile(r'(?P<Pron>[a-z]+)(?P<Tone>[1-6])')
|
||||
|
||||
|
||||
class ZhCNSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def normalize_pron(self, pronText):
|
||||
# Replace Qing Sheng
|
||||
newPron = pronText.replace('6', '2')
|
||||
newPron = re.sub(RegexQingSheng, '5', newPron)
|
||||
|
||||
# FIXME(Jin): ng case overrides newPron
|
||||
match = RegexNg2en.search(newPron)
|
||||
if match:
|
||||
newPron = 'en' + match.group('break')
|
||||
|
||||
return newPron
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('ZhCNSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
pronText = self.normalize_pron(pronText)
|
||||
|
||||
if pronText in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pronText].split(' ')
|
||||
if len(phone_list) == 3:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(
|
||||
pronText[-1]) # FIXME(Jin): assume tone is the last char
|
||||
syll.m_language = Language.ZhCN
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'ZhCNSyllableFormatter.Format: invalid pronText: %s',
|
||||
pronText)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'ZhCNSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
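The normalize_pron step above folds non-standard tone marks onto the five supported tones and rewrites a leading "ng" syllable as "en". The snippet below is a standalone re-creation of that logic with hypothetical inputs, duplicating the module-level regexes rather than importing the class:

import re

# Mirrors RegexQingSheng and RegexNg2en defined above.
_qingsheng = re.compile(r'([1-5]5)')
_ng2en = re.compile(r'^ng(?P<break>\d)')

def normalize_pron_sketch(pron):
    pron = pron.replace('6', '2')           # tone 6 folded into tone 2
    pron = _qingsheng.sub('5', pron)        # 'ma15' -> 'ma5' (neutral tone)
    match = _ng2en.search(pron)
    if match:
        pron = 'en' + match.group('break')  # 'ng5' -> 'en5'
    return pron

assert normalize_pron_sketch('ma15') == 'ma5'
assert normalize_pron_sketch('hao6') == 'hao2'
assert normalize_pron_sketch('ng5') == 'en5'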
|
||||
|
||||
class PinYinSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def normalize_pron(self, pronText):
|
||||
newPron = pronText.replace('6', '2')
|
||||
newPron = re.sub(RegexQingSheng, '5', newPron)
|
||||
|
||||
# FIXME(Jin): ng case overrides newPron
|
||||
match = RegexNg2en.search(newPron)
|
||||
if match:
|
||||
newPron = 'en' + match.group('break')
|
||||
|
||||
return newPron
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('PinYinSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
pronText = self.normalize_pron(pronText)
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'PinYinSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.PinYin
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'PinYinSyllableFormatter.Format: invalid phone: %s', pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'PinYinSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class ZhHKSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('ZhHKSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'ZhHKSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.ZhHK
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'ZhHKSyllableFormatter.Format: invalid phone: %s', pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'ZhHKSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class WuuShanghaiSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('WuuShanghaiSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'WuuShanghaiSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.WuuShanghai
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'WuuShanghaiSyllableFormatter.Format: invalid phone: %s',
|
||||
pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'WuuShanghaiSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class SichuanSyllableFormatter:
|
||||
|
||||
def __init__(self, sy2ph_map):
|
||||
self.m_sy2ph_map = sy2ph_map
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('SichuanSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
|
||||
match = RegexPron.search(pronText)
|
||||
if match:
|
||||
pron = match.group('Pron')
|
||||
tone = match.group('Tone')
|
||||
else:
|
||||
logging.error(
|
||||
'SichuanSyllableFormatter.Format: pronunciation is not valid: %s',
|
||||
pronText,
|
||||
)
|
||||
return False
|
||||
|
||||
if pron in self.m_sy2ph_map:
|
||||
phone_list = self.m_sy2ph_map[pron].split(' ')
|
||||
if len(phone_list) in [1, 2]:
|
||||
syll = Syllable()
|
||||
for phone in phone_list:
|
||||
syll.m_phone_list.append(phone)
|
||||
syll.m_tone = Tone.parse(tone)
|
||||
syll.m_language = Language.Sichuan
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
else:
|
||||
logging.error(
|
||||
'SichuanSyllableFormatter.Format: invalid phone: %s', pron)
|
||||
return False
|
||||
else:
|
||||
logging.error(
|
||||
'SichuanSyllableFormatter.Format: syllable to phone map missing key: %s',
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
class EnXXSyllableFormatter:
|
||||
|
||||
def __init__(self, language):
|
||||
self.m_f2t_map = None
|
||||
self.m_language = language
|
||||
|
||||
def normalize_pron(self, pronText):
|
||||
newPron = pronText.replace('#', '.')
|
||||
newPron = (
|
||||
newPron.replace('03',
|
||||
'0').replace('13',
|
||||
'1').replace('23',
|
||||
'2').replace('3', ''))
|
||||
newPron = newPron.replace('2', '0')
|
||||
|
||||
return newPron
|
||||
|
||||
def format(self, phoneset, pronText, syllable_list):
|
||||
if phoneset is None or syllable_list is None or pronText is None:
|
||||
logging.error('EnXXSyllableFormatter.Format: invalid input')
|
||||
return False
|
||||
pronText = self.normalize_pron(pronText)
|
||||
|
||||
syllables = [ele.strip() for ele in pronText.split('.')]
|
||||
|
||||
for i in range(len(syllables)):
|
||||
syll = Syllable()
|
||||
syll.m_language = self.m_language
|
||||
syll.m_tone = Tone.parse('0')
|
||||
|
||||
phones = re.split(r'[\s]+', syllables[i])
|
||||
|
||||
for j in range(len(phones)):
|
||||
phoneName = phones[j].lower()
|
||||
toneName = '0'
|
||||
|
||||
if '0' in phoneName or '1' in phoneName or '2' in phoneName:
|
||||
toneName = phoneName[-1]
|
||||
phoneName = phoneName[:-1]
|
||||
|
||||
phoneName_lst = None
|
||||
if self.m_f2t_map is not None:
|
||||
phoneName_lst = self.m_f2t_map.get(phoneName, None)
|
||||
if phoneName_lst is None:
|
||||
phoneName_lst = [phoneName]
|
||||
|
||||
for new_phoneName in phoneName_lst:
|
||||
phone_obj = phoneset.m_name_map.get(new_phoneName, None)
|
||||
if phone_obj is None:
|
||||
logging.error(
|
||||
'EnXXSyllableFormatter.Format: phone %s not found',
|
||||
new_phoneName,
|
||||
)
|
||||
return False
|
||||
phone_obj.m_name = new_phoneName
|
||||
syll.m_phone_list.append(phone_obj)
|
||||
if phone_obj.m_cv_type == PhoneCVType.Vowel:
|
||||
syll.m_tone = Tone.parse(toneName)
|
||||
|
||||
if j == len(phones) - 1:
|
||||
phone_obj.m_bnd = True
|
||||
syllable_list.append(syll)
|
||||
return True
|
||||
@@ -0,0 +1,112 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import codecs
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
WordPattern = r'((?P<Word>\w+)(\(\w+\))?)'
|
||||
BreakPattern = r'(?P<Break>(\*?#(?P<BreakLevel>[0-4])))'
|
||||
MarkPattern = r'(?P<Mark>[、,。!?:“”《》·])'
|
||||
POSPattern = r'(?P<POS>(\*?\|(?P<POSClass>[1-9])))'
|
||||
PhraseTonePattern = r'(?P<PhraseTone>(\*?%([L|H])))'
|
||||
|
||||
NgBreakPattern = r'^ng(?P<break>\d)'
|
||||
|
||||
RegexWord = re.compile(WordPattern + r'\s*')
|
||||
RegexBreak = re.compile(BreakPattern + r'\s*')
|
||||
RegexID = re.compile(r'^(?P<ID>[a-zA-Z\-_0-9\.]+)\s*')
|
||||
RegexSentence = re.compile(r'({}|{}|{}|{}|{})\s*'.format(
|
||||
WordPattern, BreakPattern, MarkPattern, POSPattern, PhraseTonePattern))
|
||||
RegexForeignLang = re.compile(r'[A-Z@]')
|
||||
RegexSpace = re.compile(r'^\s*')
|
||||
RegexNeutralTone = re.compile(r'[1-5]5')
|
||||
|
||||
|
||||
def do_character_normalization(line):
|
||||
return unicodedata.normalize('NFKC', line)
|
||||
|
||||
|
||||
def do_prosody_text_normalization(line):
|
||||
tokens = line.split('\t')
|
||||
text = tokens[1]
|
||||
# Remove punctuations
|
||||
text = text.replace(u'。', ' ')
|
||||
text = text.replace(u'、', ' ')
|
||||
text = text.replace(u'“', ' ')
|
||||
text = text.replace(u'”', ' ')
|
||||
text = text.replace(u'‘', ' ')
|
||||
text = text.replace(u'’', ' ')
|
||||
text = text.replace(u'|', ' ')
|
||||
text = text.replace(u'《', ' ')
|
||||
text = text.replace(u'》', ' ')
|
||||
text = text.replace(u'【', ' ')
|
||||
text = text.replace(u'】', ' ')
|
||||
text = text.replace(u'—', ' ')
|
||||
text = text.replace(u'―', ' ')
|
||||
text = text.replace('.', ' ')
|
||||
text = text.replace('!', ' ')
|
||||
text = text.replace('?', ' ')
|
||||
text = text.replace('(', ' ')
|
||||
text = text.replace(')', ' ')
|
||||
text = text.replace('[', ' ')
|
||||
text = text.replace(']', ' ')
|
||||
text = text.replace('{', ' ')
|
||||
text = text.replace('}', ' ')
|
||||
text = text.replace('~', ' ')
|
||||
text = text.replace(':', ' ')
|
||||
text = text.replace(';', ' ')
|
||||
text = text.replace('+', ' ')
|
||||
text = text.replace(',', ' ')
|
||||
# text = text.replace('·', ' ')
|
||||
text = text.replace('"', ' ')
|
||||
text = text.replace(
|
||||
'-',
|
||||
'') # don't replace with a space because compound words like two-year-old
|
||||
text = text.replace(
|
||||
"'", '') # don't replace by space because English word like that's
|
||||
|
||||
# Replace break
|
||||
text = text.replace('/', '#2')
|
||||
text = text.replace('%', '#3')
|
||||
# Remove useless spaces surround #2 #3 #4
|
||||
text = re.sub(r'(#\d)[ ]+', r'\1', text)
|
||||
text = re.sub(r'[ ]+(#\d)', r'\1', text)
|
||||
# Replace space by #1
|
||||
text = re.sub('[ ]+', '#1', text)
|
||||
|
||||
# Remove break at the end of the text
|
||||
text = re.sub(r'#\d$', '', text)
|
||||
|
||||
# Add #1 between target language and foreign language
|
||||
text = re.sub(r"([a-zA-Z])([^a-zA-Z\d\#\s\'\%\/\-])", r'\1#1\2', text)
|
||||
text = re.sub(r"([^a-zA-Z\d\#\s\'\%\/\-])([a-zA-Z])", r'\1#1\2', text)
|
||||
|
||||
return tokens[0] + '\t' + text
|
||||
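To illustrate the break handling above: '/' is treated as a #2 break, '%' as #3, remaining spaces become #1, and a trailing break is dropped. The line below is a made-up example (the utterance id and text are hypothetical):

line = 'id_0001\t今天 天气/很好%'
print(do_prosody_text_normalization(line))
# -> 'id_0001\t今天#1天气#2很好'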
|
||||
|
||||
def is_fp_line(line):
|
||||
fp_category_list = ['FP', 'I', 'N', 'Q']
|
||||
elements = line.strip().split(' ')
|
||||
res = True
|
||||
for ele in elements:
|
||||
if ele not in fp_category_list:
|
||||
res = False
|
||||
break
|
||||
return res
|
||||
|
||||
|
||||
def format_prosody(src_prosody):
|
||||
formatted_lines = []
|
||||
with codecs.open(src_prosody, 'r', 'utf-8') as f:
|
||||
lines = f.readlines()
|
||||
fp_enable = is_fp_line(lines[1])
|
||||
|
||||
for i in range(0, len(lines)):
|
||||
line = do_character_normalization(lines[i])
|
||||
if fp_enable:
|
||||
if i % 5 == 1 or i % 5 == 2 or i % 5 == 3:
|
||||
continue
|
||||
if len(line.strip().split('\t')) == 2:
|
||||
line = do_prosody_text_normalization(line)
|
||||
formatted_lines.append(line)
|
||||
return formatted_lines
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
|
||||
class XmlObj:
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def load(self):
|
||||
pass
|
||||
|
||||
def save(self):
|
||||
pass
|
||||
|
||||
def load_data(self):
|
||||
pass
|
||||
|
||||
def save_data(self):
|
||||
pass
|
||||
@@ -0,0 +1,463 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .core.core_types import BreakLevel, Language
|
||||
from .core.phone_set import PhoneSet
|
||||
from .core.pos_set import PosSet
|
||||
from .core.script import Script
|
||||
from .core.script_item import ScriptItem
|
||||
from .core.script_sentence import ScriptSentence
|
||||
from .core.script_word import SpokenMark, SpokenWord, WrittenMark, WrittenWord
|
||||
from .core.utils import (RegexForeignLang, RegexID, RegexSentence,
|
||||
format_prosody)
|
||||
|
||||
from .core.utils import RegexNeutralTone # isort:skip
|
||||
|
||||
from .core.syllable_formatter import ( # isort:skip
|
||||
EnXXSyllableFormatter, PinYinSyllableFormatter, # isort:skip
|
||||
SichuanSyllableFormatter, # isort:skip
|
||||
WuuShanghaiSyllableFormatter, ZhCNSyllableFormatter, # isort:skip
|
||||
ZhHKSyllableFormatter) # isort:skip
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
|
||||
class TextScriptConvertor:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
phoneset_path,
|
||||
posset_path,
|
||||
target_lang,
|
||||
foreign_lang,
|
||||
f2t_map_path,
|
||||
s2p_map_path,
|
||||
m_emo_tag_path,
|
||||
m_speaker,
|
||||
):
|
||||
self.m_f2p_map = {}
|
||||
self.m_s2p_map = {}
|
||||
self.m_phoneset = PhoneSet(phoneset_path)
|
||||
self.m_posset = PosSet(posset_path)
|
||||
self.m_target_lang = Language.parse(target_lang)
|
||||
self.m_foreign_lang = Language.parse(foreign_lang)
|
||||
self.m_emo_tag_path = m_emo_tag_path
|
||||
self.m_speaker = m_speaker
|
||||
|
||||
self.load_f2tmap(f2t_map_path)
|
||||
self.load_s2pmap(s2p_map_path)
|
||||
|
||||
self.m_target_lang_syllable_formatter = self.init_syllable_formatter(
|
||||
self.m_target_lang)
|
||||
self.m_foreign_lang_syllable_formatter = self.init_syllable_formatter(
|
||||
self.m_foreign_lang)
|
||||
|
||||
def parse_sentence(self, sentence, line_num):
|
||||
script_item = ScriptItem(self.m_phoneset, self.m_posset)
|
||||
script_sentence = ScriptSentence(self.m_phoneset, self.m_posset)
|
||||
script_item.m_scriptSentence_list.append(script_sentence)
|
||||
|
||||
written_sentence = script_sentence.m_writtenSentence
|
||||
spoken_sentence = script_sentence.m_spokenSentence
|
||||
|
||||
position = 0
|
||||
|
||||
sentence = sentence.strip()
|
||||
|
||||
# Get ID
|
||||
match = re.search(RegexID, sentence)
|
||||
if match is None:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_sentence:invalid line: %s,\
|
||||
line ID is needed',
|
||||
line_num,
|
||||
)
|
||||
return None
|
||||
else:
|
||||
sentence_id = match.group('ID')
|
||||
script_item.m_id = sentence_id
|
||||
position += match.end()
|
||||
|
||||
prevSpokenWord = SpokenWord()
|
||||
|
||||
prevWord = False
|
||||
lastBreak = False
|
||||
|
||||
for m in re.finditer(RegexSentence, sentence[position:]):
|
||||
if m is None:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_sentence:\
|
||||
invalid line: %s, there is no matched pattern',
|
||||
line_num,
|
||||
)
|
||||
return None
|
||||
|
||||
if m.group('Word') is not None:
|
||||
wordName = m.group('Word')
|
||||
written_word = WrittenWord()
|
||||
written_word.m_name = wordName
|
||||
written_sentence.add_host(written_word)
|
||||
|
||||
spoken_word = SpokenWord()
|
||||
spoken_word.m_name = wordName
|
||||
prevSpokenWord = spoken_word
|
||||
prevWord = True
|
||||
lastBreak = False
|
||||
elif m.group('Break') is not None:
|
||||
breakText = m.group('BreakLevel')
|
||||
if len(breakText) == 0:
|
||||
breakLevel = BreakLevel.L1
|
||||
else:
|
||||
breakLevel = BreakLevel.parse(breakText)
|
||||
if prevWord:
|
||||
prevSpokenWord.m_breakText = breakText
|
||||
spoken_sentence.add_host(prevSpokenWord)
|
||||
|
||||
if breakLevel != BreakLevel.L1:
|
||||
spokenMark = SpokenMark()
|
||||
spokenMark.m_breakLevel = breakLevel
|
||||
spoken_sentence.add_accompany(spokenMark)
|
||||
|
||||
lastBreak = True
|
||||
|
||||
elif m.group('PhraseTone') is not None:
|
||||
pass
|
||||
elif m.group('POS') is not None:
|
||||
POSClass = m.group('POSClass')
|
||||
if prevWord:
|
||||
prevSpokenWord.m_POS = POSClass  # SpokenWord defines the attribute as m_POS
|
||||
prevWord = False
|
||||
elif m.group('Mark') is not None:
|
||||
markText = m.group('Mark')
|
||||
|
||||
writtenMark = WrittenMark()
|
||||
writtenMark.m_punctuation = markText
|
||||
written_sentence.add_accompany(writtenMark)
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_sentence:\
|
||||
invalid line: %s, matched pattern is unrecognized',
|
||||
line_num,
|
||||
)
|
||||
return None
|
||||
|
||||
if not lastBreak:
|
||||
prevSpokenWord.m_breakText = '4'
|
||||
spoken_sentence.add_host(prevSpokenWord)
|
||||
|
||||
spoken_word_cnt = len(spoken_sentence.m_spoken_word_list)
|
||||
spoken_mark_cnt = len(spoken_sentence.m_spoken_mark_list)
|
||||
if (spoken_word_cnt > 0
|
||||
and spoken_sentence.m_align_list[spoken_word_cnt - 1]
|
||||
== spoken_mark_cnt):
|
||||
spokenMark = SpokenMark()
|
||||
spokenMark.m_breakLevel = BreakLevel.L4
|
||||
spoken_sentence.add_accompany(spokenMark)
|
||||
|
||||
written_sentence.build_sequence()
|
||||
spoken_sentence.build_sequence()
|
||||
written_sentence.build_text()
|
||||
spoken_sentence.build_text()
|
||||
|
||||
script_sentence.m_text = written_sentence.m_text
|
||||
script_item.m_text = written_sentence.m_text
|
||||
|
||||
return script_item
|
||||
|
||||
def format_syllable(self, pron, syllable_list):
|
||||
isForeign = RegexForeignLang.search(pron) is not None
|
||||
if self.m_foreign_lang_syllable_formatter is not None and isForeign:
|
||||
return self.m_foreign_lang_syllable_formatter.format(
|
||||
self.m_phoneset, pron, syllable_list)
|
||||
else:
|
||||
return self.m_target_lang_syllable_formatter.format(
|
||||
self.m_phoneset, pron, syllable_list)
|
||||
|
||||
def get_word_prons(self, pronText):
|
||||
prons = pronText.split('/')
|
||||
res = []
|
||||
|
||||
for pron in prons:
|
||||
if re.search(RegexForeignLang, pron):
|
||||
res.append(pron.strip())
|
||||
else:
|
||||
res.extend(pron.strip().split(' '))
|
||||
return res
|
||||
|
||||
def is_erhuayin(self, pron):
|
||||
pron = RegexNeutralTone.sub('5', pron)
|
||||
pron = pron[:-1]
|
||||
|
||||
return pron[-1] == 'r' and pron != 'er'
|
||||
|
||||
def parse_pronunciation(self, script_item, pronunciation, line_num):
|
||||
spoken_sentence = script_item.m_scriptSentence_list[0].m_spokenSentence
|
||||
|
||||
wordProns = self.get_word_prons(pronunciation)
|
||||
|
||||
wordIndex = 0
|
||||
pronIndex = 0
|
||||
succeed = True
|
||||
|
||||
while pronIndex < len(wordProns):
|
||||
language = Language.Neutral
|
||||
syllable_list = []
|
||||
|
||||
pron = wordProns[pronIndex].strip()
|
||||
|
||||
succeed = self.format_syllable(pron, syllable_list)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation:\
|
||||
invalid line: %s, error pronunciation: %s,\
|
||||
syllable format error',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
language = syllable_list[0].m_language
|
||||
|
||||
if wordIndex < len(spoken_sentence.m_spoken_word_list):
|
||||
if language in [Language.EnGB, Language.EnUS]:
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_syllable_list.extend(syllable_list)
|
||||
wordIndex += 1
|
||||
pronIndex += 1
|
||||
elif language in [
|
||||
Language.ZhCN,
|
||||
Language.PinYin,
|
||||
Language.ZhHK,
|
||||
Language.WuuShanghai,
|
||||
Language.Sichuan,
|
||||
]:
|
||||
charCount = len(
|
||||
spoken_sentence.m_spoken_word_list[wordIndex].m_name)
|
||||
if (language in [
|
||||
Language.ZhCN, Language.PinYin, Language.Sichuan
|
||||
] and self.is_erhuayin(pron) and '儿' in spoken_sentence.
|
||||
m_spoken_word_list[wordIndex].m_name):
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name.replace('儿', '')
|
||||
charCount -= 1
|
||||
if charCount == 1:
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_syllable_list.extend(syllable_list)
|
||||
wordIndex += 1
|
||||
pronIndex += 1
|
||||
else:
|
||||
# FIXME(Jin): Just skip the first char, then match the remaining chars.
|
||||
i = 1
|
||||
while i >= 1 and i < charCount:
|
||||
pronIndex += 1
|
||||
if pronIndex < len(wordProns):
|
||||
pron = wordProns[pronIndex].strip()
|
||||
succeed = self.format_syllable(
|
||||
pron, syllable_list)
|
||||
if not succeed:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, syllable format error',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
if (language in [
|
||||
Language.ZhCN,
|
||||
Language.PinYin,
|
||||
Language.Sichuan,
|
||||
] and self.is_erhuayin(pron)
|
||||
and '儿' in spoken_sentence.
|
||||
m_spoken_word_list[wordIndex].m_name):
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name = spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_name.replace('儿', '')
|
||||
charCount -= 1
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, Word count mismatch with Pron count',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
i += 1
|
||||
spoken_sentence.m_spoken_word_list[
|
||||
wordIndex].m_syllable_list.extend(syllable_list)
|
||||
wordIndex += 1
|
||||
pronIndex += 1
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
unsupported language: %s',
|
||||
line_num,
|
||||
language.name,
|
||||
)
|
||||
return False
|
||||
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, word index is out of range',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
if pronIndex != len(wordProns):
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, pron count mismatch with word count',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
if wordIndex != len(spoken_sentence.m_spoken_word_list):
|
||||
logging.error(
|
||||
'TextScriptConvertor.parse_pronunciation: invalid line: %s, \
|
||||
error pronunciation: %s, word count mismatch with word index',
|
||||
line_num,
|
||||
pron,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def load_f2tmap(self, file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
elements = line.split('\t')
|
||||
if len(elements) != 2:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadF2TMap: invalid line: %s',
|
||||
line)
|
||||
continue
|
||||
key = elements[0]
|
||||
value = elements[1]
|
||||
value_list = value.split(' ')
|
||||
if key in self.m_f2p_map:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadF2TMap: duplicate key: %s',
|
||||
key)
|
||||
self.m_f2p_map[key] = value_list
|
||||
|
||||
def load_s2pmap(self, file_path):
|
||||
with open(file_path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
elements = line.split('\t')
|
||||
if len(elements) != 2:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadS2PMap: invalid line: %s',
|
||||
line)
|
||||
continue
|
||||
key = elements[0]
|
||||
value = elements[1]
|
||||
if key in self.m_s2p_map:
|
||||
logging.error(
|
||||
'TextScriptConvertor.LoadS2PMap: duplicate key: %s',
|
||||
key)
|
||||
self.m_s2p_map[key] = value
|
||||
|
||||
def init_syllable_formatter(self, targetLang):
|
||||
if targetLang == Language.ZhCN:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: ZhCN syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return ZhCNSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.PinYin:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: PinYin syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return PinYinSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.ZhHK:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: ZhHK syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return ZhHKSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.WuuShanghai:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: WuuShanghai syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return WuuShanghaiSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.Sichuan:
|
||||
if len(self.m_s2p_map) == 0:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: Sichuan syllable to phone map is empty'
|
||||
)
|
||||
return None
|
||||
return SichuanSyllableFormatter(self.m_s2p_map)
|
||||
elif targetLang == Language.EnGB:
|
||||
formatter = EnXXSyllableFormatter(Language.EnGB)
|
||||
if len(self.m_f2p_map) != 0:
|
||||
formatter.m_f2t_map = self.m_f2p_map
|
||||
return formatter
|
||||
elif targetLang == Language.EnUS:
|
||||
formatter = EnXXSyllableFormatter(Language.EnUS)
|
||||
if len(self.m_f2p_map) != 0:
|
||||
formatter.m_f2t_map = self.m_f2p_map
|
||||
return formatter
|
||||
else:
|
||||
logging.error(
|
||||
'TextScriptConvertor.InitSyllableFormatter: unsupported language: %s',
|
||||
targetLang,
|
||||
)
|
||||
return None
|
||||
|
||||
def process(self, textScriptPath, outputXMLPath, outputMetafile):
|
||||
script = Script(self.m_phoneset, self.m_posset)
|
||||
formatted_lines = format_prosody(textScriptPath)
|
||||
line_num = 0
|
||||
for line in tqdm(formatted_lines):
|
||||
if line_num % 2 == 0:
|
||||
sentence = line.strip()
|
||||
item = self.parse_sentence(sentence, line_num)
|
||||
else:
|
||||
if item is not None:
|
||||
pronunciation = line.strip()
|
||||
res = self.parse_pronunciation(item, pronunciation,
|
||||
line_num)
|
||||
if res:
|
||||
script.m_items.append(item)
|
||||
|
||||
line_num += 1
|
||||
|
||||
script.save(outputXMLPath)
|
||||
logging.info('TextScriptConvertor.process:\nSave script to: %s',
|
||||
outputXMLPath)
|
||||
|
||||
meta_lines = script.save_meta_file()
|
||||
emo = 'emotion_neutral'
|
||||
speaker = self.m_speaker
|
||||
|
||||
meta_lines_tagged = []
|
||||
for line in meta_lines:
|
||||
line_id, line_text = line.split('\t')
|
||||
syll_items = line_text.split(' ')
|
||||
syll_items_tagged = []
|
||||
for syll_item in syll_items:
|
||||
syll_item_tagged = syll_item[:-1] + '$' + emo + '$' + speaker + '}'
|
||||
syll_items_tagged.append(syll_item_tagged)
|
||||
meta_lines_tagged.append(line_id + '\t'
|
||||
+ ' '.join(syll_items_tagged))
|
||||
with open(outputMetafile, 'w') as f:
|
||||
for line in meta_lines_tagged:
|
||||
f.write(line + '\n')
|
||||
|
||||
logging.info('TextScriptConvertor.process:\nSave metafile to: %s',
|
||||
outputMetafile)
|
||||
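In process() each phone token from save_meta_file() is re-opened just before its closing brace so the emotion tag and speaker name can be appended. A small sketch with a hypothetical token and speaker name ('emotion_neutral' matches the constant used above, 'speaker_A' is made up):

syll_item = '{a$tone4$s_begin$word_begin}'
tagged = syll_item[:-1] + '$' + 'emotion_neutral' + '$' + 'speaker_A' + '}'
assert tagged == '{a$tone4$s_begin$word_begin$emotion_neutral$speaker_A}'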
562
modelscope/models/audio/tts/kantts/train/loss.py
Normal file
562
modelscope/models/audio/tts/kantts/train/loss.py
Normal file
@@ -0,0 +1,562 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.audio.tts.kantts.models.utils import \
|
||||
get_mask_from_lengths
|
||||
from modelscope.models.audio.tts.kantts.utils.audio_torch import (
|
||||
MelSpectrogram, stft)
|
||||
|
||||
|
||||
class MelReconLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='mae'):
|
||||
super(MelReconLoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
if loss_type == 'mae':
|
||||
self.criterion = torch.nn.L1Loss(reduction='none')
|
||||
elif loss_type == 'mse':
|
||||
self.criterion = torch.nn.MSELoss(reduction='none')
|
||||
else:
|
||||
raise ValueError('Unknown loss type: {}'.format(loss_type))
|
||||
|
||||
def forward(self,
|
||||
output_lengths,
|
||||
mel_targets,
|
||||
dec_outputs,
|
||||
postnet_outputs=None):
|
||||
output_masks = get_mask_from_lengths(
|
||||
output_lengths, max_len=mel_targets.size(1))
|
||||
output_masks = ~output_masks
|
||||
valid_outputs = output_masks.sum()
|
||||
|
||||
mel_loss_ = torch.sum(
|
||||
self.criterion(mel_targets, dec_outputs)
|
||||
* output_masks.unsqueeze(-1)) / (
|
||||
valid_outputs * mel_targets.size(-1))
|
||||
|
||||
if postnet_outputs is not None:
|
||||
mel_loss = torch.sum(
|
||||
self.criterion(mel_targets, postnet_outputs)
|
||||
* output_masks.unsqueeze(-1)) / (
|
||||
valid_outputs * mel_targets.size(-1))
|
||||
else:
|
||||
mel_loss = 0.0
|
||||
|
||||
return mel_loss_, mel_loss
|
||||
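A hedged usage sketch for MelReconLoss follows. It assumes get_mask_from_lengths (imported at the top of this file) accepts a CPU tensor of frame lengths and returns a (B, T) boolean mask that is True at padded positions; the shapes are illustrative only.

import torch

criterion = MelReconLoss(loss_type='mae')
mel_targets = torch.randn(2, 100, 80)      # (B, frames, n_mels)
dec_outputs = torch.randn(2, 100, 80)      # decoder output
postnet_outputs = torch.randn(2, 100, 80)  # optional postnet output
lengths = torch.tensor([100, 80])          # valid frames per utterance
loss_decoder, loss_postnet = criterion(lengths, mel_targets, dec_outputs,
                                       postnet_outputs)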
|
||||
|
||||
class ProsodyReconLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='mae'):
|
||||
super(ProsodyReconLoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
if loss_type == 'mae':
|
||||
self.criterion = torch.nn.L1Loss(reduction='none')
|
||||
elif loss_type == 'mse':
|
||||
self.criterion = torch.nn.MSELoss(reduction='none')
|
||||
else:
|
||||
raise ValueError('Unknown loss type: {}'.format(loss_type))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_lengths,
|
||||
duration_targets,
|
||||
pitch_targets,
|
||||
energy_targets,
|
||||
log_duration_predictions,
|
||||
pitch_predictions,
|
||||
energy_predictions,
|
||||
):
|
||||
input_masks = get_mask_from_lengths(
|
||||
input_lengths, max_len=duration_targets.size(1))
|
||||
input_masks = ~input_masks
|
||||
valid_inputs = input_masks.sum()
|
||||
|
||||
dur_loss = (
|
||||
torch.sum(
|
||||
self.criterion(
|
||||
torch.log(duration_targets.float() + 1),
|
||||
log_duration_predictions) * input_masks) / valid_inputs)
|
||||
pitch_loss = (
|
||||
torch.sum(
|
||||
self.criterion(pitch_targets, pitch_predictions) * input_masks)
|
||||
/ valid_inputs)
|
||||
energy_loss = (
|
||||
torch.sum(
|
||||
self.criterion(energy_targets, energy_predictions)
|
||||
* input_masks) / valid_inputs)
|
||||
|
||||
return dur_loss, pitch_loss, energy_loss
|
||||
|
||||
|
||||
class FpCELoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='ce', weight=[1, 4, 4, 8]):
|
||||
super(FpCELoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
weight_ce = torch.FloatTensor(weight).cuda()
|
||||
self.criterion = torch.nn.CrossEntropyLoss(
|
||||
weight=weight_ce, reduction='none')
|
||||
|
||||
def forward(self, input_lengths, fp_pd, fp_label):
|
||||
input_masks = get_mask_from_lengths(
|
||||
input_lengths, max_len=fp_label.size(1))
|
||||
input_masks = ~input_masks
|
||||
valid_inputs = input_masks.sum()
|
||||
|
||||
fp_loss = (
|
||||
torch.sum(
|
||||
self.criterion(fp_pd.transpose(2, 1), fp_label) * input_masks)
|
||||
/ valid_inputs)
|
||||
|
||||
return fp_loss
|
||||
|
||||
|
||||
class GeneratorAdversarialLoss(torch.nn.Module):
|
||||
"""Generator adversarial loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_discriminators=True,
|
||||
loss_type='mse',
|
||||
):
|
||||
"""Initialize GeneratorAversarialLoss module."""
|
||||
super().__init__()
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
|
||||
if loss_type == 'mse':
|
||||
self.criterion = self._mse_loss
|
||||
else:
|
||||
self.criterion = self._hinge_loss
|
||||
|
||||
def forward(self, outputs):
|
||||
"""Calcualate generator adversarial loss.
|
||||
|
||||
Args:
|
||||
outputs (Tensor or list): Discriminator outputs or list of
|
||||
discriminator outputs.
|
||||
|
||||
Returns:
|
||||
Tensor: Generator adversarial loss value.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
adv_loss = 0.0
|
||||
for i, outputs_ in enumerate(outputs):
|
||||
adv_loss += self.criterion(outputs_)
|
||||
if self.average_by_discriminators:
|
||||
adv_loss /= i + 1
|
||||
else:
|
||||
adv_loss = self.criterion(outputs)
|
||||
|
||||
return adv_loss
|
||||
|
||||
def _mse_loss(self, x):
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
|
||||
def _hinge_loss(self, x):
|
||||
return -x.mean()
|
||||
|
||||
|
||||
class DiscriminatorAdversarialLoss(torch.nn.Module):
|
||||
"""Discriminator adversarial loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_discriminators=True,
|
||||
loss_type='mse',
|
||||
):
|
||||
"""Initialize DiscriminatorAversarialLoss module."""
|
||||
super().__init__()
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
assert loss_type in ['mse', 'hinge'], f'{loss_type} is not supported.'
|
||||
if loss_type == 'mse':
|
||||
self.fake_criterion = self._mse_fake_loss
|
||||
self.real_criterion = self._mse_real_loss
|
||||
else:
|
||||
self.fake_criterion = self._hinge_fake_loss
|
||||
self.real_criterion = self._hinge_real_loss
|
||||
|
||||
def forward(self, outputs_hat, outputs):
|
||||
"""Calcualate discriminator adversarial loss.
|
||||
|
||||
Args:
|
||||
outputs_hat (Tensor or list): Discriminator outputs or list of
|
||||
discriminator outputs calculated from generator outputs.
|
||||
outputs (Tensor or list): Discriminator outputs or list of
|
||||
discriminator outputs calculated from groundtruth.
|
||||
|
||||
Returns:
|
||||
Tensor: Discriminator real loss value.
|
||||
Tensor: Discriminator fake loss value.
|
||||
|
||||
"""
|
||||
if isinstance(outputs, (tuple, list)):
|
||||
real_loss = 0.0
|
||||
fake_loss = 0.0
|
||||
for i, (outputs_hat_,
|
||||
outputs_) in enumerate(zip(outputs_hat, outputs)):
|
||||
if isinstance(outputs_hat_, (tuple, list)):
|
||||
# NOTE(kan-bayashi): case including feature maps
|
||||
outputs_hat_ = outputs_hat_[-1]
|
||||
outputs_ = outputs_[-1]
|
||||
real_loss += self.real_criterion(outputs_)
|
||||
fake_loss += self.fake_criterion(outputs_hat_)
|
||||
if self.average_by_discriminators:
|
||||
fake_loss /= i + 1
|
||||
real_loss /= i + 1
|
||||
else:
|
||||
real_loss = self.real_criterion(outputs)
|
||||
fake_loss = self.fake_criterion(outputs_hat)
|
||||
|
||||
return real_loss, fake_loss
|
||||
|
||||
def _mse_real_loss(self, x):
|
||||
return F.mse_loss(x, x.new_ones(x.size()))
|
||||
|
||||
def _mse_fake_loss(self, x):
|
||||
return F.mse_loss(x, x.new_zeros(x.size()))
|
||||
|
||||
def _hinge_real_loss(self, x):
|
||||
return -torch.mean(torch.min(x - 1, x.new_zeros(x.size())))
|
||||
|
||||
def _hinge_fake_loss(self, x):
|
||||
return -torch.mean(torch.min(-x - 1, x.new_zeros(x.size())))
|
||||
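A minimal sketch of how the two adversarial losses pair up, assuming a discriminator that returns a single score tensor (the list/tuple branch above handles multi-discriminator setups); d_fake and d_real are placeholders for real discriminator outputs.

import torch

gen_adv = GeneratorAdversarialLoss(loss_type='mse')
dis_adv = DiscriminatorAdversarialLoss(loss_type='mse')

d_fake = torch.randn(4, 1, 100)   # discriminator scores on generated audio
d_real = torch.randn(4, 1, 100)   # discriminator scores on real audio

g_loss = gen_adv(d_fake)                        # drives fake scores toward 1
real_loss, fake_loss = dis_adv(d_fake, d_real)  # real toward 1, fake toward 0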
|
||||
|
||||
class FeatureMatchLoss(torch.nn.Module):
|
||||
"""Feature matching loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
average_by_layers=True,
|
||||
average_by_discriminators=True,
|
||||
):
|
||||
"""Initialize FeatureMatchLoss module."""
|
||||
super().__init__()
|
||||
self.average_by_layers = average_by_layers
|
||||
self.average_by_discriminators = average_by_discriminators
|
||||
|
||||
def forward(self, feats_hat, feats):
|
||||
"""Calcualate feature matching loss.
|
||||
|
||||
Args:
|
||||
feats_hat (list): List of list of discriminator outputs
|
||||
calculated from generator outputs.
|
||||
feats (list): List of list of discriminator outputs
|
||||
calculated from groundtruth.
|
||||
|
||||
Returns:
|
||||
Tensor: Feature matching loss value.
|
||||
|
||||
"""
|
||||
feat_match_loss = 0.0
|
||||
for i, (feats_hat_, feats_) in enumerate(zip(feats_hat, feats)):
|
||||
feat_match_loss_ = 0.0
|
||||
for j, (feat_hat_, feat_) in enumerate(zip(feats_hat_, feats_)):
|
||||
feat_match_loss_ += F.l1_loss(feat_hat_, feat_.detach())
|
||||
if self.average_by_layers:
|
||||
feat_match_loss_ /= j + 1
|
||||
feat_match_loss += feat_match_loss_
|
||||
if self.average_by_discriminators:
|
||||
feat_match_loss /= i + 1
|
||||
|
||||
return feat_match_loss
|
||||
|
||||
|
||||
class MelSpectrogramLoss(torch.nn.Module):
|
||||
"""Mel-spectrogram loss."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fs=22050,
|
||||
fft_size=1024,
|
||||
hop_size=256,
|
||||
win_length=None,
|
||||
window='hann',
|
||||
num_mels=80,
|
||||
fmin=80,
|
||||
fmax=7600,
|
||||
center=True,
|
||||
normalized=False,
|
||||
onesided=True,
|
||||
eps=1e-10,
|
||||
log_base=10.0,
|
||||
):
|
||||
"""Initialize Mel-spectrogram loss."""
|
||||
super().__init__()
|
||||
self.mel_spectrogram = MelSpectrogram(
|
||||
fs=fs,
|
||||
fft_size=fft_size,
|
||||
hop_size=hop_size,
|
||||
win_length=win_length,
|
||||
window=window,
|
||||
num_mels=num_mels,
|
||||
fmin=fmin,
|
||||
fmax=fmax,
|
||||
center=center,
|
||||
normalized=normalized,
|
||||
onesided=onesided,
|
||||
eps=eps,
|
||||
log_base=log_base,
|
||||
)
|
||||
|
||||
def forward(self, y_hat, y):
|
||||
"""Calculate Mel-spectrogram loss.
|
||||
|
||||
Args:
|
||||
y_hat (Tensor): Generated single tensor (B, 1, T).
|
||||
y (Tensor): Groundtruth single tensor (B, 1, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Mel-spectrogram loss value.
|
||||
|
||||
"""
|
||||
mel_hat = self.mel_spectrogram(y_hat)
|
||||
mel = self.mel_spectrogram(y)
|
||||
mel_loss = F.l1_loss(mel_hat, mel)
|
||||
|
||||
return mel_loss
|
||||
|
||||
|
||||
class SpectralConvergenceLoss(torch.nn.Module):
|
||||
"""Spectral convergence loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize spectral convergence loss module."""
|
||||
super(SpectralConvergenceLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
|
||||
"""
|
||||
return torch.norm(y_mag - x_mag, p='fro') / torch.norm(y_mag, p='fro')
|
||||
|
||||
|
||||
class LogSTFTMagnitudeLoss(torch.nn.Module):
|
||||
"""Log STFT magnitude loss module."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initilize los STFT magnitude loss module."""
|
||||
super(LogSTFTMagnitudeLoss, self).__init__()
|
||||
|
||||
def forward(self, x_mag, y_mag):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x_mag (Tensor): Magnitude spectrogram of predicted signal (B, #frames, #freq_bins).
|
||||
y_mag (Tensor): Magnitude spectrogram of groundtruth signal (B, #frames, #freq_bins).
|
||||
|
||||
Returns:
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
|
||||
"""
|
||||
return F.l1_loss(torch.log(y_mag), torch.log(x_mag))
|
||||
|
||||
|
||||
class STFTLoss(torch.nn.Module):
|
||||
"""STFT loss module."""
|
||||
|
||||
def __init__(self,
|
||||
fft_size=1024,
|
||||
shift_size=120,
|
||||
win_length=600,
|
||||
window='hann_window'):
|
||||
"""Initialize STFT loss module."""
|
||||
super(STFTLoss, self).__init__()
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
self.spectral_convergence_loss = SpectralConvergenceLoss()
|
||||
self.log_stft_magnitude_loss = LogSTFTMagnitudeLoss()
|
||||
# NOTE(kan-bayashi): Use register_buffer to fix #223
|
||||
self.register_buffer('window', getattr(torch, window)(win_length))
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T).
|
||||
y (Tensor): Groundtruth signal (B, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Spectral convergence loss value.
|
||||
Tensor: Log STFT magnitude loss value.
|
||||
|
||||
"""
|
||||
x_mag = stft(x, self.fft_size, self.shift_size, self.win_length,
|
||||
self.window)
|
||||
y_mag = stft(y, self.fft_size, self.shift_size, self.win_length,
|
||||
self.window)
|
||||
sc_loss = self.spectral_convergence_loss(x_mag, y_mag)
|
||||
mag_loss = self.log_stft_magnitude_loss(x_mag, y_mag)
|
||||
|
||||
return sc_loss, mag_loss
|
||||
|
||||
|
||||
class MultiResolutionSTFTLoss(torch.nn.Module):
|
||||
"""Multi resolution STFT loss module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
window='hann_window',
|
||||
):
|
||||
"""Initialize Multi resolution STFT loss module.
|
||||
|
||||
Args:
|
||||
fft_sizes (list): List of FFT sizes.
|
||||
hop_sizes (list): List of hop sizes.
|
||||
win_lengths (list): List of window lengths.
|
||||
window (str): Window function type.
|
||||
|
||||
"""
|
||||
super(MultiResolutionSTFTLoss, self).__init__()
|
||||
assert len(fft_sizes) == len(hop_sizes) == len(win_lengths)
|
||||
self.stft_losses = torch.nn.ModuleList()
|
||||
for fs, ss, wl in zip(fft_sizes, hop_sizes, win_lengths):
|
||||
self.stft_losses += [STFTLoss(fs, ss, wl, window)]
|
||||
|
||||
def forward(self, x, y):
|
||||
"""Calculate forward propagation.
|
||||
|
||||
Args:
|
||||
x (Tensor): Predicted signal (B, T) or (B, #subband, T).
|
||||
y (Tensor): Groundtruth signal (B, T) or (B, #subband, T).
|
||||
|
||||
Returns:
|
||||
Tensor: Multi resolution spectral convergence loss value.
|
||||
Tensor: Multi resolution log STFT magnitude loss value.
|
||||
|
||||
"""
|
||||
if len(x.shape) == 3:
|
||||
x = x.view(-1, x.size(2)) # (B, C, T) -> (B x C, T)
|
||||
y = y.view(-1, y.size(2)) # (B, C, T) -> (B x C, T)
|
||||
sc_loss = 0.0
|
||||
mag_loss = 0.0
|
||||
for f in self.stft_losses:
|
||||
sc_l, mag_l = f(x, y)
|
||||
sc_loss += sc_l
|
||||
mag_loss += mag_l
|
||||
sc_loss /= len(self.stft_losses)
|
||||
mag_loss /= len(self.stft_losses)
|
||||
|
||||
return sc_loss, mag_loss
|
||||
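Usage sketch for the multi-resolution STFT loss with the default three resolutions; the waveforms are random placeholders of shape (B, T).

import torch

criterion = MultiResolutionSTFTLoss()  # default fft/hop/window sizes
y_hat = torch.randn(4, 8192)           # generated waveform (B, T)
y = torch.randn(4, 8192)               # ground-truth waveform (B, T)
sc_loss, mag_loss = criterion(y_hat, y)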
|
||||
|
||||
class SeqCELoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, loss_type='ce'):
|
||||
super(SeqCELoss, self).__init__()
|
||||
self.loss_type = loss_type
|
||||
self.criterion = torch.nn.CrossEntropyLoss(reduction='none')
|
||||
|
||||
def forward(self, logits, targets, masks):
|
||||
loss = self.criterion(logits.contiguous().view(-1, logits.size(-1)),
|
||||
targets.contiguous().view(-1))
|
||||
preds = torch.argmax(logits, dim=-1).contiguous().view(-1)
|
||||
masks = masks.contiguous().view(-1)
|
||||
|
||||
loss = (loss * masks).sum() / masks.sum()
|
||||
err = torch.sum((preds != targets.view(-1)) * masks) / masks.sum()
|
||||
|
||||
return loss, err
|
||||
|
||||
|
||||
class AttentionBinarizationLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, start_epoch=0, warmup_epoch=100):
|
||||
super(AttentionBinarizationLoss, self).__init__()
|
||||
self.start_epoch = start_epoch
|
||||
self.warmup_epoch = warmup_epoch
|
||||
|
||||
def forward(self, epoch, hard_attention, soft_attention, eps=1e-12):
|
||||
log_sum = torch.log(
|
||||
torch.clamp(soft_attention[hard_attention == 1], min=eps)).sum()
|
||||
kl_loss = -log_sum / hard_attention.sum()
|
||||
if epoch < self.start_epoch:
|
||||
warmup_ratio = 0
|
||||
else:
|
||||
warmup_ratio = min(1.0,
|
||||
(epoch - self.start_epoch) / self.warmup_epoch)
|
||||
return kl_loss * warmup_ratio
|
||||
|
||||
|
||||
class AttentionCTCLoss(torch.nn.Module):
|
||||
|
||||
def __init__(self, blank_logprob=-1):
|
||||
super(AttentionCTCLoss, self).__init__()
|
||||
self.log_softmax = torch.nn.LogSoftmax(dim=3)
|
||||
self.blank_logprob = blank_logprob
|
||||
self.CTCLoss = torch.nn.CTCLoss(zero_infinity=True)
|
||||
|
||||
def forward(self, attn_logprob, in_lens, out_lens):
|
||||
key_lens = in_lens
|
||||
query_lens = out_lens
|
||||
attn_logprob_padded = F.pad(
|
||||
input=attn_logprob,
|
||||
pad=(1, 0, 0, 0, 0, 0, 0, 0),
|
||||
value=self.blank_logprob)
|
||||
cost_total = 0.0
|
||||
for bid in range(attn_logprob.shape[0]):
|
||||
target_seq = torch.arange(1, key_lens[bid] + 1).unsqueeze(0)
|
||||
curr_logprob = attn_logprob_padded[bid].permute(1, 0, 2)
|
||||
curr_logprob = curr_logprob[:query_lens[bid], :, :key_lens[bid]
|
||||
+ 1]
|
||||
curr_logprob = self.log_softmax(curr_logprob[None])[0]
|
||||
ctc_cost = self.CTCLoss(
|
||||
curr_logprob,
|
||||
target_seq,
|
||||
input_lengths=query_lens[bid:bid + 1],
|
||||
target_lengths=key_lens[bid:bid + 1],
|
||||
)
|
||||
cost_total += ctc_cost
|
||||
cost = cost_total / attn_logprob.shape[0]
|
||||
return cost
|
||||
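A shape-level sketch of AttentionCTCLoss: from the indexing above, attn_logprob is assumed to hold the aligner's unnormalized attention scores with shape (B, 1, max_output_len, max_input_len); in_lens/out_lens are the valid text and frame lengths.

import torch

criterion = AttentionCTCLoss()
B, T_out, T_in = 2, 50, 12
attn_logprob = torch.randn(B, 1, T_out, T_in)  # (B, 1, frames, tokens), assumed layout
in_lens = torch.tensor([12, 9])
out_lens = torch.tensor([50, 40])
ctc_loss = criterion(attn_logprob, in_lens, out_lens)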
|
||||
|
||||
loss_dict = {
|
||||
'generator_adv_loss': GeneratorAdversarialLoss,
|
||||
'discriminator_adv_loss': DiscriminatorAdversarialLoss,
|
||||
'stft_loss': MultiResolutionSTFTLoss,
|
||||
'mel_loss': MelSpectrogramLoss,
|
||||
'subband_stft_loss': MultiResolutionSTFTLoss,
|
||||
'feat_match_loss': FeatureMatchLoss,
|
||||
'MelReconLoss': MelReconLoss,
|
||||
'ProsodyReconLoss': ProsodyReconLoss,
|
||||
'SeqCELoss': SeqCELoss,
|
||||
'AttentionBinarizationLoss': AttentionBinarizationLoss,
|
||||
'AttentionCTCLoss': AttentionCTCLoss,
|
||||
'FpCELoss': FpCELoss,
|
||||
}
|
||||
|
||||
|
||||
def criterion_builder(config, device='cpu'):
|
||||
"""Criterion builder.
|
||||
Args:
|
||||
config (dict): Config dictionary.
|
||||
Returns:
|
||||
criterion (dict): Loss dictionary
|
||||
"""
|
||||
criterion = {}
|
||||
for key, value in config['Loss'].items():
|
||||
if key in loss_dict:
|
||||
if value['enable']:
|
||||
criterion[key] = loss_dict[key](
|
||||
**value.get('params', {})).to(device)
|
||||
setattr(criterion[key], 'weights', value.get('weights', 1.0))
|
||||
else:
|
||||
raise NotImplementedError('{} is not implemented'.format(key))
|
||||
|
||||
return criterion
|
||||
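A hypothetical config snippet for criterion_builder: only entries under config['Loss'] whose key appears in loss_dict and whose 'enable' flag is true are instantiated; 'params' and 'weights' are optional.

config = {
    'Loss': {
        'MelReconLoss': {'enable': True, 'params': {'loss_type': 'mae'}, 'weights': 1.0},
        'ProsodyReconLoss': {'enable': True, 'weights': 0.1},
        'stft_loss': {'enable': False},   # disabled entries are skipped
    }
}
criterion = criterion_builder(config, device='cpu')
# criterion holds the two enabled losses, each with a .weights attribute attached.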
44
modelscope/models/audio/tts/kantts/train/scheduler.py
Normal file
44
modelscope/models/audio/tts/kantts/train/scheduler.py
Normal file
@@ -0,0 +1,44 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from torch.optim.lr_scheduler import MultiStepLR, _LRScheduler
|
||||
|
||||
|
||||
class FindLR(_LRScheduler):
|
||||
"""
|
||||
inspired by fast.ai @https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, max_steps, max_lr=10):
|
||||
self.max_steps = max_steps
|
||||
self.max_lr = max_lr
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
return [
|
||||
base_lr * ((self.max_lr / base_lr)**(
|
||||
self.last_epoch / # noqa W504
|
||||
(self.max_steps - 1))) for base_lr in self.base_lrs
|
||||
]
|
||||
|
||||
|
||||
class NoamLR(_LRScheduler):
|
||||
"""
|
||||
Implements the Noam Learning rate schedule. This corresponds to increasing the learning rate
|
||||
linearly for the first ``warmup_steps`` training steps, and decreasing it thereafter proportionally
|
||||
to the inverse square root of the step number, scaled by the inverse square root of the
|
||||
dimensionality of the model. Time will tell if this is just madness or it's actually important.
|
||||
Parameters
|
||||
----------
|
||||
warmup_steps: ``int``, required.
|
||||
The number of steps to linearly increase the learning rate.
|
||||
"""
|
||||
|
||||
def __init__(self, optimizer, warmup_steps):
|
||||
self.warmup_steps = warmup_steps
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
last_epoch = max(1, self.last_epoch)
|
||||
scale = self.warmup_steps**0.5 * min(
|
||||
last_epoch**(-0.5), last_epoch * self.warmup_steps**(-1.5))
|
||||
return [base_lr * scale for base_lr in self.base_lrs]
|
||||
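The Noam scale factor above works out to warmup_steps**0.5 * min(step**-0.5, step * warmup_steps**-1.5): the learning rate rises linearly to the optimizer's base lr at step == warmup_steps and then decays as 1/sqrt(step). A small usage sketch with a throwaway model:

import torch

model = torch.nn.Linear(80, 80)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = NoamLR(optimizer, warmup_steps=4000)

for step in range(10):
    optimizer.step()
    scheduler.step()   # lr is roughly base_lr * step / warmup_steps during warmup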
1201
modelscope/models/audio/tts/kantts/train/trainer.py
Normal file
1201
modelscope/models/audio/tts/kantts/train/trainer.py
Normal file
File diff suppressed because it is too large
188
modelscope/models/audio/tts/kantts/utils/audio_torch.py
Normal file
188
modelscope/models/audio/tts/kantts/utils/audio_torch.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
|
||||
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
|
||||
"""
|
||||
if is_pytorch_17plus:
|
||||
x_stft = torch.stft(
|
||||
x, fft_size, hop_size, win_length, window, return_complex=False)
|
||||
else:
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
||||
real = x_stft[..., 0]
|
||||
imag = x_stft[..., 1]
|
||||
|
||||
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
|
||||
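Usage sketch for the stft helper above with one batch of 1 s, 22.05 kHz placeholder audio; the window tensor must have length win_length.

import torch

x = torch.randn(2, 22050)                 # (B, T) waveform
window = torch.hann_window(600)
mag = stft(x, fft_size=1024, hop_size=256, win_length=600, window=window)
# mag: (B, #frames, fft_size // 2 + 1)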
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return 20 * torch.log10(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
|
||||
def dynamic_range_decompression_torch(x, C=1):
|
||||
return torch.pow(10.0, x * 0.05) / C
|
||||
|
||||
|
||||
def spectral_normalize_torch(
|
||||
magnitudes,
|
||||
min_level_db=-100.0,
|
||||
ref_level_db=20.0,
|
||||
norm_abs_value=4.0,
|
||||
symmetric=True,
|
||||
):
|
||||
output = dynamic_range_compression_torch(magnitudes) - ref_level_db
|
||||
|
||||
if symmetric:
|
||||
return torch.clamp(
|
||||
2 * norm_abs_value * ((output - min_level_db) / # noqa W504
|
||||
(-min_level_db)) - norm_abs_value,
|
||||
min=-norm_abs_value,
|
||||
max=norm_abs_value)
|
||||
else:
|
||||
return torch.clamp(
|
||||
norm_abs_value * ((output - min_level_db) / (-min_level_db)),
|
||||
min=0.0,
|
||||
max=norm_abs_value)
|
||||
|
||||
|
||||
def spectral_de_normalize_torch(
|
||||
magnitudes,
|
||||
min_level_db=-100.0,
|
||||
ref_level_db=20.0,
|
||||
norm_abs_value=4.0,
|
||||
symmetric=True,
|
||||
):
|
||||
if symmetric:
|
||||
magnitudes = torch.clamp(
|
||||
magnitudes, min=-norm_abs_value, max=norm_abs_value)
|
||||
magnitudes = (magnitudes + norm_abs_value) * (-min_level_db) / (
|
||||
2 * norm_abs_value) + min_level_db
|
||||
else:
|
||||
magnitudes = torch.clamp(magnitudes, min=0.0, max=norm_abs_value)
|
||||
magnitudes = (magnitudes) * (-min_level_db) / (
|
||||
norm_abs_value) + min_level_db
|
||||
|
||||
output = dynamic_range_decompression_torch(magnitudes + ref_level_db)
|
||||
return output
|
||||
|
||||
|
||||
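

# Not part of the diff: round-trip sketch for the (de)normalization helpers above.
# spectral_normalize_torch maps dB-compressed magnitudes into [-4, 4]; the de-normalize
# function inverts it exactly for values that were not clamped.
#
#   amp = torch.rand(2, 100, 80) + 1e-4               # linear-scale magnitudes above clip_val
#   norm = spectral_normalize_torch(amp)              # values in [-4, 4]
#   recon = spectral_de_normalize_torch(norm)         # back to linear scale
#   print(torch.allclose(amp, recon, rtol=1e-3, atol=1e-4))   # True for un-clamped values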
class MelSpectrogram(torch.nn.Module):
    """Calculate Mel-spectrogram."""

    def __init__(
        self,
        fs=22050,
        fft_size=1024,
        hop_size=256,
        win_length=None,
        window='hann',
        num_mels=80,
        fmin=80,
        fmax=7600,
        center=True,
        normalized=False,
        onesided=True,
        eps=1e-10,
        log_base=10.0,
        pad_mode='constant',
    ):
        """Initialize MelSpectrogram module."""
        super().__init__()
        self.fft_size = fft_size
        if win_length is None:
            self.win_length = fft_size
        else:
            self.win_length = win_length
        self.hop_size = hop_size
        self.center = center
        self.normalized = normalized
        self.onesided = onesided
        if window is not None and not hasattr(torch, f'{window}_window'):
            raise ValueError(f'{window} window is not implemented')
        self.window = window
        self.eps = eps
        self.pad_mode = pad_mode

        fmin = 0 if fmin is None else fmin
        fmax = fs / 2 if fmax is None else fmax
        melmat = librosa.filters.mel(
            sr=fs,
            n_fft=fft_size,
            n_mels=num_mels,
            fmin=fmin,
            fmax=fmax,
        )
        self.register_buffer('melmat', torch.from_numpy(melmat.T).float())
        self.stft_params = {
            'n_fft': self.fft_size,
            'win_length': self.win_length,
            'hop_length': self.hop_size,
            'center': self.center,
            'normalized': self.normalized,
            'onesided': self.onesided,
            'pad_mode': self.pad_mode,
        }
        if is_pytorch_17plus:
            self.stft_params['return_complex'] = False

        self.log_base = log_base
        if self.log_base is None:
            self.log = torch.log
        elif self.log_base == 2.0:
            self.log = torch.log2
        elif self.log_base == 10.0:
            self.log = torch.log10
        else:
            raise ValueError(f'log_base: {log_base} is not supported.')

    def forward(self, x):
        """Calculate Mel-spectrogram.

        Args:
            x (Tensor): Input waveform tensor (B, T) or (B, 1, T).

        Returns:
            Tensor: Mel-spectrogram (B, #mels, #frames).

        """
        if x.dim() == 3:
            # (B, C, T) -> (B*C, T)
            x = x.reshape(-1, x.size(2))

        if self.window is not None:
            window_func = getattr(torch, f'{self.window}_window')
            window = window_func(
                self.win_length, dtype=x.dtype, device=x.device)
        else:
            window = None

        x_stft = torch.stft(x, window=window, **self.stft_params)
        # (B, #freqs, #frames, 2) -> (B, #frames, #freqs, 2)
        x_stft = x_stft.transpose(1, 2)
        x_power = x_stft[..., 0]**2 + x_stft[..., 1]**2
        x_amp = torch.sqrt(torch.clamp(x_power, min=self.eps))

        x_mel = torch.matmul(x_amp, self.melmat)
        x_mel = torch.clamp(x_mel, min=self.eps)
        x_mel = spectral_normalize_torch(x_mel)

        # return self.log(x_mel).transpose(1, 2)
        return x_mel.transpose(1, 2)
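

# Not part of the diff: usage sketch for the MelSpectrogram module above, with its default hyper-parameters.
#
#   mel_fn = MelSpectrogram(fs=22050, fft_size=1024, hop_size=256, num_mels=80)
#   wav = torch.randn(1, 22050)        # one second of audio at 22.05 kHz
#   mel = mel_fn(wav)                  # (B, num_mels, #frames), normalized into [-4, 4]
#   print(mel.shape)                   # e.g. torch.Size([1, 80, 87]) with center=True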
@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import ttsfrd


def text_to_mit_symbols(texts, resources_dir, speaker):
    fe = ttsfrd.TtsFrontendEngine()
    fe.initialize(resources_dir)
    fe.set_lang_type('Zh-CN')

    symbols_lst = []
    for idx, text in enumerate(texts):
        text = text.strip()
        res = fe.gen_tacotron_symbols(text)
        res = res.replace('F7', speaker)
        sentences = res.split('\n')
        for sentence in sentences:
            arr = sentence.split('\t')
            # skip the empty line
            if len(arr) != 2:
                continue
            sub_index, symbols = sentence.split('\t')
            symbol_str = '{}_{}\t{}\n'.format(idx, sub_index, symbols)
            symbols_lst.append(symbol_str)

    return symbols_lst
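A call sketch for text_to_mit_symbols above. It only uses the ttsfrd calls visible in this diff; the resource directory and input texts are placeholders, and a local ttsfrd resource pack is assumed to be available.

# Not part of the diff: call sketch with placeholder paths.
texts = ['你好，世界。', '欢迎使用语音合成。']
symbols = text_to_mit_symbols(texts, resources_dir='/path/to/ttsfrd/resource', speaker='F7')
for line in symbols:
    print(line.strip())   # '<text_idx>_<sentence_idx>\t<tacotron symbol sequence>'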
@@ -21,24 +21,21 @@ _whitespace_re = re.compile(r'\s+')
|
||||
# List of (regular expression, replacement) pairs for abbreviations:
|
||||
_abbreviations = [
|
||||
(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1])
|
||||
for x in [('mrs', 'misess'),
|
||||
('mr', 'mister'),
|
||||
('dr', 'doctor'),
|
||||
('st', 'saint'),
|
||||
('co', 'company'),
|
||||
('jr', 'junior'),
|
||||
('maj', 'major'),
|
||||
('gen', 'general'),
|
||||
('drs', 'doctors'),
|
||||
('rev', 'reverend'),
|
||||
('lt', 'lieutenant'),
|
||||
('hon', 'honorable'),
|
||||
('sgt', 'sergeant'),
|
||||
('capt', 'captain'),
|
||||
('esq', 'esquire'),
|
||||
('ltd', 'limited'),
|
||||
('col', 'colonel'),
|
||||
('ft', 'fort'), ]] # yapf:disable
|
||||
for x in [('mrs', 'misess'), ('mr', 'mister'), (
|
||||
'dr', 'doctor'), ('st', 'saint'), ('co', 'company'), (
|
||||
'jr',
|
||||
'junior'), ('maj', 'major'), ('gen', 'general'), (
|
||||
'drs', 'doctors'), ('rev', 'reverend'), (
|
||||
'lt',
|
||||
'lieutenant'), ('hon', 'honorable'), (
|
||||
'sgt',
|
||||
'sergeant'), ('capt', 'captain'), (
|
||||
'esq',
|
||||
'esquire'), ('ltd',
|
||||
'limited'), ('col',
|
||||
'colonel'), ('ft',
|
||||
'fort')]
|
||||
]
|
||||
|
||||
|
||||
def expand_abbreviations(text):
|
||||
@@ -64,14 +61,14 @@ def convert_to_ascii(text):
|
||||
|
||||
|
||||
def basic_cleaners(text):
|
||||
'''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
|
||||
"""Basic pipeline that lowercases and collapses whitespace without transliteration."""
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
return text
|
||||
|
||||
|
||||
def transliteration_cleaners(text):
|
||||
'''Pipeline for non-English text that transliterates to ASCII.'''
|
||||
"""Pipeline for non-English text that transliterates to ASCII."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = collapse_whitespace(text)
|
||||
@@ -79,7 +76,7 @@ def transliteration_cleaners(text):
|
||||
|
||||
|
||||
def english_cleaners(text):
|
||||
'''Pipeline for English text, including number and abbreviation expansion.'''
|
||||
"""Pipeline for English text, including number and abbreviation expansion."""
|
||||
text = convert_to_ascii(text)
|
||||
text = lowercase(text)
|
||||
text = expand_numbers(text)
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
emotion_types = [
|
||||
'emotion_none',
|
||||
'emotion_neutral',
|
||||
'emotion_angry',
|
||||
'emotion_disgust',
|
||||
'emotion_fear',
|
||||
'emotion_happy',
|
||||
'emotion_sad',
|
||||
'emotion_surprise',
|
||||
'emotion_calm',
|
||||
'emotion_gentle',
|
||||
'emotion_relax',
|
||||
'emotion_lyrical',
|
||||
'emotion_serious',
|
||||
'emotion_disgruntled',
|
||||
'emotion_satisfied',
|
||||
'emotion_disappointed',
|
||||
'emotion_excited',
|
||||
'emotion_anxiety',
|
||||
'emotion_jealousy',
|
||||
'emotion_hate',
|
||||
'emotion_pity',
|
||||
'emotion_pleasure',
|
||||
'emotion_arousal',
|
||||
'emotion_dominance',
|
||||
'emotion_placeholder1',
|
||||
'emotion_placeholder2',
|
||||
'emotion_placeholder3',
|
||||
'emotion_placeholder4',
|
||||
'emotion_placeholder5',
|
||||
'emotion_placeholder6',
|
||||
'emotion_placeholder7',
|
||||
'emotion_placeholder8',
|
||||
'emotion_placeholder9',
|
||||
]
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
import xml.etree.ElementTree as ET
|
||||
|
||||
from modelscope.models.audio.tts.kantts.preprocess.languages import languages
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logging = get_logger()
|
||||
|
||||
syllable_flags = [
|
||||
's_begin',
|
||||
's_end',
|
||||
's_none',
|
||||
's_both',
|
||||
's_middle',
|
||||
]
|
||||
|
||||
word_segments = [
|
||||
'word_begin',
|
||||
'word_end',
|
||||
'word_middle',
|
||||
'word_both',
|
||||
'word_none',
|
||||
]
|
||||
|
||||
|
||||
def parse_phoneset(phoneset_file):
|
||||
"""Parse a phoneset file and return a list of symbols.
|
||||
Args:
|
||||
phoneset_file (str): Path to the phoneset file.
|
||||
|
||||
Returns:
|
||||
list: A list of phones.
|
||||
"""
|
||||
ns = '{http://schemas.alibaba-inc.com/tts}'
|
||||
|
||||
phone_lst = []
|
||||
phoneset_root = ET.parse(phoneset_file).getroot()
|
||||
for phone_node in phoneset_root.findall(ns + 'phone'):
|
||||
phone_lst.append(phone_node.find(ns + 'name').text)
|
||||
|
||||
for i in range(1, 5):
|
||||
phone_lst.append('#{}'.format(i))
|
||||
|
||||
return phone_lst
|
||||
|
||||
|
||||
def parse_tonelist(tonelist_file):
|
||||
"""Parse a tonelist file and return a list of tones.
|
||||
Args:
|
||||
tonelist_file (str): Path to the tonelist file.
|
||||
|
||||
Returns:
|
||||
list: A list of tones.
|
||||
"""
|
||||
tone_lst = []
|
||||
with open(tonelist_file, 'r') as f:
|
||||
lines = f.readlines()
|
||||
for line in lines:
|
||||
tone = line.strip()
|
||||
if tone != '':
|
||||
tone_lst.append('tone{}'.format(tone))
|
||||
else:
|
||||
tone_lst.append('tone_none')
|
||||
|
||||
return tone_lst
|
||||
|
||||
|
||||
def get_language_symbols(language, language_dir):
|
||||
"""Get symbols of a language.
|
||||
Args:
|
||||
language (str): Language name.
|
||||
"""
|
||||
language_dict = languages.get(language, None)
|
||||
if language_dict is None:
|
||||
logging.error('Language %s not supported. Using PinYin as default',
|
||||
language)
|
||||
language_dict = languages['PinYin']
|
||||
language = 'PinYin'
|
||||
|
||||
language_dir = os.path.join(language_dir, language)
|
||||
phoneset_file = os.path.join(language_dir, language_dict['phoneset_path'])
|
||||
tonelist_file = os.path.join(language_dir, language_dict['tonelist_path'])
|
||||
phones = parse_phoneset(phoneset_file)
|
||||
tones = parse_tonelist(tonelist_file)
|
||||
|
||||
return phones, tones, syllable_flags, word_segments
|
||||
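A small sketch of what parse_tonelist above produces; the file path and contents are hypothetical (one tone id per line, with an empty line meaning "no tone").

# Not part of the diff: parse_tonelist sketch with a throwaway file.
with open('/tmp/tonelist.txt', 'w') as f:
    f.write('1\n2\n\n')
print(parse_tonelist('/tmp/tonelist.txt'))   # ['tone1', 'tone2', 'tone_none']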
@@ -1,15 +1,15 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import abc
|
||||
import codecs
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
|
||||
from . import cleaners as cleaners
|
||||
from .emotion_types import emotion_types
|
||||
from .lang_symbols import get_language_symbols
|
||||
|
||||
# Regular expression matching text enclosed in curly braces:
|
||||
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
||||
@@ -19,18 +19,37 @@ def _clean_text(text, cleaner_names):
|
||||
for name in cleaner_names:
|
||||
cleaner = getattr(cleaners, name)
|
||||
if not cleaner:
|
||||
raise Exception(
|
||||
'modelscope error: configuration cleaner unknown: %s' % name)
|
||||
raise Exception('Unknown cleaner: %s' % name)
|
||||
text = cleaner(text)
|
||||
return text
|
||||
|
||||
|
||||
def get_fpdict(config):
|
||||
# emotion_neutral (F7) can be replaced by other emotion (speaker) types from the corresponding list in the config file.
|
||||
en_sy = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {en_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
|
||||
a_sy = '{ga$tone5$s_begin$word_begin$emotion_neutral$F7} {a_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
|
||||
e_sy = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {e_c$tone5$s_end$word_end$emotion_neutral$F7} {#3$tone_none$s_none$word_none$emotion_neutral$F7}' # NOQA: E501
|
||||
ling_unit = KanTtsLinguisticUnit(config)
|
||||
|
||||
en_lings = ling_unit.encode_symbol_sequence(en_sy)
|
||||
a_lings = ling_unit.encode_symbol_sequence(a_sy)
|
||||
e_lings = ling_unit.encode_symbol_sequence(e_sy)
|
||||
|
||||
en_ling = np.stack(en_lings, axis=1)[:3, :4]
|
||||
a_ling = np.stack(a_lings, axis=1)[:3, :4]
|
||||
e_ling = np.stack(e_lings, axis=1)[:3, :4]
|
||||
|
||||
fp_dict = {1: a_ling, 2: en_ling, 3: e_ling}
|
||||
return fp_dict
|
||||
|
||||
|
||||
class LinguisticBaseUnit(abc.ABC):
|
||||
|
||||
def set_config_params(self, config_params):
|
||||
self.config_params = config_params
|
||||
|
||||
def save(self, config, config_name, path):
|
||||
"""Save config to file"""
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
@@ -39,22 +58,37 @@ class LinguisticBaseUnit(abc.ABC):
|
||||
|
||||
class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
|
||||
def __init__(self, config, path, has_mask=True):
|
||||
def __init__(self, config, lang_dir=None):
|
||||
super(KanTtsLinguisticUnit, self).__init__()
|
||||
|
||||
# special symbol
|
||||
self._pad = '_'
|
||||
self._eos = '~'
|
||||
self._mask = '@[MASK]'
|
||||
self._has_mask = has_mask
|
||||
self._unit_config = config
|
||||
self._path = path
|
||||
|
||||
self.unit_config = config['linguistic_unit']
|
||||
self.has_mask = self.unit_config.get('has_mask', True)
|
||||
self.lang_type = self.unit_config.get('language', 'PinYin')
|
||||
(
|
||||
self.lang_phones,
|
||||
self.lang_tones,
|
||||
self.lang_syllable_flags,
|
||||
self.lang_word_segments,
|
||||
) = get_language_symbols(self.lang_type, lang_dir)
|
||||
|
||||
self._cleaner_names = [
|
||||
x.strip() for x in self._unit_config['cleaners'].split(',')
|
||||
x.strip() for x in self.unit_config['cleaners'].split(',')
|
||||
]
|
||||
self._lfeat_type_list = self._unit_config['lfeat_type_list'].strip(
|
||||
).split(',')
|
||||
_lfeat_type_list = self.unit_config['lfeat_type_list'].strip().split(
|
||||
',')
|
||||
self._lfeat_type_list = _lfeat_type_list
|
||||
|
||||
self.fp_enable = config['Model']['KanTtsSAMBERT']['params'].get(
|
||||
'FP', False)
|
||||
if self.fp_enable:
|
||||
self._fpadd_lfeat_type_list = [
|
||||
_lfeat_type_list[0], _lfeat_type_list[4]
|
||||
]
|
||||
|
||||
self.build()
|
||||
|
||||
@@ -79,19 +113,13 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
# sy sub-unit
|
||||
_characters = ''
|
||||
|
||||
_ch_symbols = []
|
||||
|
||||
sy_path = os.path.join(self._path, self._unit_config['sy'])
|
||||
f = codecs.open(sy_path, 'r')
|
||||
for line in f:
|
||||
line = line.strip('\r\n')
|
||||
_ch_symbols.append(line)
|
||||
|
||||
_arpabet = ['@' + s for s in _ch_symbols]
|
||||
# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters):
|
||||
# _arpabet = ['@' + s for s in cmudict.valid_symbols]
|
||||
_arpabet = ['@' + s for s in self.lang_phones]
|
||||
|
||||
# Export all symbols:
|
||||
self.sy = list(_characters) + _arpabet + [self._pad, self._eos]
|
||||
if self._has_mask:
|
||||
if self.has_mask:
|
||||
self.sy.append(self._mask)
|
||||
self._sy_to_id = {s: i for i, s in enumerate(self.sy)}
|
||||
self._id_to_sy = {i: s for i, s in enumerate(self.sy)}
|
||||
@@ -101,17 +129,10 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
# tone sub-unit
|
||||
_characters = ''
|
||||
|
||||
_ch_tones = []
|
||||
|
||||
tone_path = os.path.join(self._path, self._unit_config['tone'])
|
||||
f = codecs.open(tone_path, 'r')
|
||||
for line in f:
|
||||
line = line.strip('\r\n')
|
||||
_ch_tones.append(line)
|
||||
|
||||
# Export all tones:
|
||||
self.tone = list(_characters) + _ch_tones + [self._pad, self._eos]
|
||||
if self._has_mask:
|
||||
self.tone = (
|
||||
list(_characters) + self.lang_tones + [self._pad, self._eos])
|
||||
if self.has_mask:
|
||||
self.tone.append(self._mask)
|
||||
self._tone_to_id = {s: i for i, s in enumerate(self.tone)}
|
||||
self._id_to_tone = {i: s for i, s in enumerate(self.tone)}
|
||||
@@ -121,20 +142,11 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
# syllable flag sub-unit
|
||||
_characters = ''
|
||||
|
||||
_ch_syllable_flags = []
|
||||
|
||||
sy_flag_path = os.path.join(self._path,
|
||||
self._unit_config['syllable_flag'])
|
||||
f = codecs.open(sy_flag_path, 'r')
|
||||
for line in f:
|
||||
line = line.strip('\r\n')
|
||||
_ch_syllable_flags.append(line)
|
||||
|
||||
# Export all syllable_flags:
|
||||
self.syllable_flag = list(_characters) + _ch_syllable_flags + [
|
||||
self._pad, self._eos
|
||||
]
|
||||
if self._has_mask:
|
||||
self.syllable_flag = (
|
||||
list(_characters) + self.lang_syllable_flags
|
||||
+ [self._pad, self._eos])
|
||||
if self.has_mask:
|
||||
self.syllable_flag.append(self._mask)
|
||||
self._syllable_flag_to_id = {
|
||||
s: i
|
||||
@@ -150,19 +162,11 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
# word segment sub-unit
|
||||
_characters = ''
|
||||
|
||||
_ch_word_segments = []
|
||||
|
||||
ws_path = os.path.join(self._path, self._unit_config['word_segment'])
|
||||
f = codecs.open(ws_path, 'r')
|
||||
for line in f:
|
||||
line = line.strip('\r\n')
|
||||
_ch_word_segments.append(line)
|
||||
|
||||
# Export all word segments:
|
||||
self.word_segment = list(_characters) + _ch_word_segments + [
|
||||
self._pad, self._eos
|
||||
]
|
||||
if self._has_mask:
|
||||
self.word_segment = (
|
||||
list(_characters) + self.lang_word_segments
|
||||
+ [self._pad, self._eos])
|
||||
if self.has_mask:
|
||||
self.word_segment.append(self._mask)
|
||||
self._word_segment_to_id = {
|
||||
s: i
|
||||
@@ -179,19 +183,9 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
# emotion category sub-unit
|
||||
_characters = ''
|
||||
|
||||
_ch_emo_types = []
|
||||
|
||||
emo_path = os.path.join(self._path,
|
||||
self._unit_config['emo_category'])
|
||||
f = codecs.open(emo_path, 'r')
|
||||
for line in f:
|
||||
line = line.strip('\r\n')
|
||||
_ch_emo_types.append(line)
|
||||
|
||||
self.emo_category = list(_characters) + _ch_emo_types + [
|
||||
self._pad, self._eos
|
||||
]
|
||||
if self._has_mask:
|
||||
self.emo_category = (
|
||||
list(_characters) + emotion_types + [self._pad, self._eos])
|
||||
if self.has_mask:
|
||||
self.emo_category.append(self._mask)
|
||||
self._emo_category_to_id = {
|
||||
s: i
|
||||
@@ -208,20 +202,12 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
# speaker category sub-unit
|
||||
_characters = ''
|
||||
|
||||
_ch_speakers = []
|
||||
|
||||
speaker_path = os.path.join(self._path,
|
||||
self._unit_config['speaker_category'])
|
||||
f = codecs.open(speaker_path, 'r')
|
||||
for line in f:
|
||||
line = line.strip('\r\n')
|
||||
_ch_speakers.append(line)
|
||||
_ch_speakers = self.unit_config['speaker_list'].strip().split(',')
|
||||
|
||||
# Export all speakers:
|
||||
self.speaker = list(_characters) + _ch_speakers + [
|
||||
self._pad, self._eos
|
||||
]
|
||||
if self._has_mask:
|
||||
self.speaker = (
|
||||
list(_characters) + _ch_speakers + [self._pad, self._eos])
|
||||
if self.has_mask:
|
||||
self.speaker.append(self._mask)
|
||||
self._speaker_to_id = {s: i for i, s in enumerate(self.speaker)}
|
||||
self._id_to_speaker = {i: s for i, s in enumerate(self.speaker)}
|
||||
@@ -237,8 +223,9 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
'$')
|
||||
index = 0
|
||||
while index < len(lfeat_symbol_separate):
|
||||
lfeat_symbol_separate[index] = lfeat_symbol_separate[
|
||||
index] + this_lfeat_symbol[index] + ' '
|
||||
lfeat_symbol_separate[index] = (
|
||||
lfeat_symbol_separate[index] + this_lfeat_symbol[index]
|
||||
+ ' ')
|
||||
index = index + 1
|
||||
|
||||
input_and_label_data = []
|
||||
@@ -271,9 +258,7 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
elif lfeat_type == 'speaker_category':
|
||||
s = self.decode_speaker_category(sequence_item)
|
||||
else:
|
||||
raise Exception(
|
||||
'modelscope error: configuration lfeat type(%s) unknown.'
|
||||
% lfeat_type)
|
||||
raise Exception('Unknown lfeat type: %s' % lfeat_type)
|
||||
result.append('%s:%s' % (lfeat_type, s))
|
||||
|
||||
return result
|
||||
@@ -285,8 +270,9 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
this_lfeat_symbol_format = ''
|
||||
index = 0
|
||||
while index < len(this_lfeat_symbol):
|
||||
this_lfeat_symbol_format = this_lfeat_symbol_format + '{' + this_lfeat_symbol[
|
||||
index] + '}' + ' '
|
||||
this_lfeat_symbol_format = (
|
||||
this_lfeat_symbol_format + '{' + this_lfeat_symbol[index]
|
||||
+ '}' + ' ')
|
||||
index = index + 1
|
||||
sequence = self.encode_text(this_lfeat_symbol_format,
|
||||
self._cleaner_names)
|
||||
@@ -301,9 +287,7 @@ class KanTtsLinguisticUnit(LinguisticBaseUnit):
|
||||
elif lfeat_type == 'speaker_category':
|
||||
sequence = self.encode_speaker_category(this_lfeat_symbol)
|
||||
else:
|
||||
raise Exception(
|
||||
'modelscope error: configuration lfeat type(%s) unknown.'
|
||||
% lfeat_type)
|
||||
raise Exception('Unknown lfeat type: %s' % lfeat_type)
|
||||
|
||||
return sequence
|
||||
|
||||
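A sketch of constructing the linguistic unit and encoding a tacotron-style symbol string, based only on the calls visible in this diff. The config dict below is a hypothetical minimal stand-in, not a real modelscope configuration, and the language directory path is a placeholder.

# Not part of the diff: KanTtsLinguisticUnit usage sketch with a hypothetical config.
config = {
    'linguistic_unit': {
        'cleaners': 'english_cleaners',
        'lfeat_type_list': 'sy,tone,syllable_flag,word_segment,emo_category,speaker_category',
        'speaker_list': 'F7',
    },
    'Model': {'KanTtsSAMBERT': {'params': {}}},
}
ling_unit = KanTtsLinguisticUnit(config, lang_dir='/path/to/languages')
symbols = '{ge$tone5$s_begin$word_begin$emotion_neutral$F7} {e_c$tone5$s_end$word_end$emotion_neutral$F7}'
sequences = ling_unit.encode_symbol_sequence(symbols)  # one integer sequence per lfeat type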
@@ -1,5 +1,4 @@
|
||||
# The implementation is adopted from tacotron,
|
||||
# made publicly available under the MIT License at https://github.com/keithito/tacotron
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import re
|
||||
|
||||
modelscope/models/audio/tts/kantts/utils/log.py (new file, 26 lines)
@@ -0,0 +1,26 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import logging
import subprocess


def logging_to_file(log_file):
    logger = logging.getLogger()
    handler = logging.FileHandler(log_file)
    formatter = logging.Formatter(
        '%(asctime)s %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s',
        datefmt='%Y-%m-%d:%H:%M:%S',
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)


def get_git_revision_short_hash():
    return (subprocess.check_output(['git', 'rev-parse', '--short',
                                     'HEAD']).decode('ascii').strip())


def get_git_revision_hash():
    return subprocess.check_output(['git', 'rev-parse',
                                    'HEAD']).decode('ascii').strip()
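A usage sketch for the logging helpers above; the log path is a placeholder, and the git helpers assume the process runs inside a git checkout.

# Not part of the diff: logging_to_file usage sketch.
import logging

logging_to_file('/tmp/kantts_train.log')
logging.info('training started, git rev: %s', get_git_revision_short_hash())
# Both records go to /tmp/kantts_train.log via the root logger's FileHandler.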
modelscope/models/audio/tts/kantts/utils/plot.py (new file, 39 lines)
@@ -0,0 +1,39 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import matplotlib

matplotlib.use('Agg')
try:
    import matplotlib.pyplot as plt
except ImportError:
    raise ImportError('Please install matplotlib.')

plt.set_loglevel('info')


def plot_spectrogram(spectrogram):
    fig, ax = plt.subplots(figsize=(12, 8))
    im = ax.imshow(
        spectrogram, aspect='auto', origin='lower', interpolation='none')
    plt.colorbar(im, ax=ax)

    fig.canvas.draw()
    plt.close()

    return fig


def plot_alignment(alignment, info=None):
    fig, ax = plt.subplots()
    im = ax.imshow(
        alignment, aspect='auto', origin='lower', interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Input timestep'
    if info is not None:
        xlabel += '\t' + info
    plt.xlabel(xlabel)
    plt.ylabel('Output timestep')
    fig.canvas.draw()
    plt.close()

    return fig
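A usage sketch for the plotting helpers above; the array and output path are placeholders. The figure object stays usable after plt.close(), so saving works headless with the Agg backend.

# Not part of the diff: plot_spectrogram usage sketch.
import numpy as np

mel = np.random.rand(80, 200)            # (num_mels, frames)
fig = plot_spectrogram(mel)
fig.savefig('/tmp/mel.png')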
@@ -1,238 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .units import KanTtsLinguisticUnit
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
class KanTtsText2MelDataset(Dataset):
|
||||
|
||||
def __init__(self, metadata_filename, config_filename, cache=False):
|
||||
super(KanTtsText2MelDataset, self).__init__()
|
||||
|
||||
self.cache = cache
|
||||
|
||||
with open(config_filename, encoding='utf-8') as f:
|
||||
self._config = json.loads(f.read())
|
||||
|
||||
# Load metadata:
|
||||
self._datadir = os.path.dirname(metadata_filename)
|
||||
with open(metadata_filename, encoding='utf-8') as f:
|
||||
self._metadata = [line.strip().split('|') for line in f]
|
||||
self._length_lst = [int(x[2]) for x in self._metadata]
|
||||
hours = sum(
|
||||
self._length_lst) * self._config['audio']['frame_shift_ms'] / (
|
||||
3600 * 1000)
|
||||
|
||||
logger.info('Loaded metadata for %d examples (%.2f hours)' %
|
||||
(len(self._metadata), hours))
|
||||
logger.info('Minimum length: %d, Maximum length: %d' %
|
||||
(min(self._length_lst), max(self._length_lst)))
|
||||
|
||||
self.ling_unit = KanTtsLinguisticUnit(config_filename)
|
||||
self.pad_executor = KanTtsText2MelPad()
|
||||
|
||||
self.r = self._config['am']['outputs_per_step']
|
||||
self.num_mels = self._config['am']['num_mels']
|
||||
|
||||
if 'adv' in self._config:
|
||||
self.feat_window = self._config['adv']['random_window']
|
||||
else:
|
||||
self.feat_window = None
|
||||
logger.info(self.feat_window)
|
||||
|
||||
self.data_cache = [
|
||||
self.cache_load(i) for i in tqdm(range(self.__len__()))
|
||||
] if self.cache else []
|
||||
|
||||
def get_frames_lst(self):
|
||||
return self._length_lst
|
||||
|
||||
def __getitem__(self, index):
|
||||
if self.cache:
|
||||
sample = self.data_cache[index]
|
||||
return sample
|
||||
|
||||
return self.cache_load(index)
|
||||
|
||||
def cache_load(self, index):
|
||||
sample = {}
|
||||
|
||||
meta = self._metadata[index]
|
||||
|
||||
sample['utt_id'] = meta[0]
|
||||
|
||||
sample['mel_target'] = np.load(os.path.join(
|
||||
self._datadir, meta[1]))[:, :self.num_mels]
|
||||
sample['output_length'] = len(sample['mel_target'])
|
||||
|
||||
lfeat_symbol = meta[3]
|
||||
sample['ling'] = self.ling_unit.encode_symbol_sequence(lfeat_symbol)
|
||||
|
||||
sample['duration'] = np.load(os.path.join(self._datadir, meta[4]))
|
||||
|
||||
sample['pitch_contour'] = np.load(os.path.join(self._datadir, meta[5]))
|
||||
|
||||
sample['energy_contour'] = np.load(
|
||||
os.path.join(self._datadir, meta[6]))
|
||||
|
||||
return sample
|
||||
|
||||
def __len__(self):
|
||||
return len(self._metadata)
|
||||
|
||||
def collate_fn(self, batch):
|
||||
data_dict = {}
|
||||
|
||||
max_input_length = max((len(x['ling'][0]) for x in batch))
|
||||
|
||||
# pure linguistic info: sy|tone|syllable_flag|word_segment
|
||||
|
||||
# sy
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[0]
|
||||
inputs_sy = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['ling'][0] for x in batch], max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type]).long()
|
||||
# tone
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[1]
|
||||
inputs_tone = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['ling'][1] for x in batch], max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type]).long()
|
||||
|
||||
# syllable_flag
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[2]
|
||||
inputs_syllable_flag = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['ling'][2] for x in batch], max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type]).long()
|
||||
|
||||
# word_segment
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[3]
|
||||
inputs_ws = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['ling'][3] for x in batch], max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type]).long()
|
||||
|
||||
# emotion category
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[4]
|
||||
data_dict['input_emotions'] = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['ling'][4] for x in batch], max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type]).long()
|
||||
|
||||
# speaker category
|
||||
lfeat_type = self.ling_unit._lfeat_type_list[5]
|
||||
data_dict['input_speakers'] = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['ling'][5] for x in batch], max_input_length,
|
||||
self.ling_unit._sub_unit_pad[lfeat_type]).long()
|
||||
|
||||
data_dict['input_lings'] = torch.stack(
|
||||
[inputs_sy, inputs_tone, inputs_syllable_flag, inputs_ws], dim=2)
|
||||
|
||||
data_dict['valid_input_lengths'] = torch.as_tensor(
|
||||
[len(x['ling'][0]) - 1 for x in batch], dtype=torch.long
|
||||
) # There is one '~' at the end of the symbol sequence, so we use length - 1 here.
|
||||
|
||||
data_dict['valid_output_lengths'] = torch.as_tensor(
|
||||
[x['output_length'] for x in batch], dtype=torch.long)
|
||||
max_output_length = torch.max(data_dict['valid_output_lengths']).item()
|
||||
max_output_round_length = self.pad_executor._round_up(
|
||||
max_output_length, self.r)
|
||||
|
||||
if self.feat_window is not None:
|
||||
active_feat_len = np.minimum(max_output_round_length,
|
||||
self.feat_window)
|
||||
if active_feat_len < self.feat_window:
|
||||
max_output_round_length = self.pad_executor._round_up(
|
||||
self.feat_window, self.r)
|
||||
active_feat_len = self.feat_window
|
||||
|
||||
max_offsets = [x['output_length'] - active_feat_len for x in batch]
|
||||
feat_offsets = [
|
||||
np.random.randint(0, np.maximum(1, offset))
|
||||
for offset in max_offsets
|
||||
]
|
||||
feat_offsets = torch.from_numpy(
|
||||
np.asarray(feat_offsets, dtype=np.int32)).long()
|
||||
data_dict['feat_offsets'] = feat_offsets
|
||||
|
||||
data_dict['mel_targets'] = self.pad_executor._prepare_targets(
|
||||
[x['mel_target'] for x in batch], max_output_round_length, 0.0)
|
||||
data_dict['durations'] = self.pad_executor._prepare_durations(
|
||||
[x['duration'] for x in batch], max_input_length,
|
||||
max_output_round_length)
|
||||
|
||||
data_dict['pitch_contours'] = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['pitch_contour'] for x in batch], max_input_length,
|
||||
0.0).float()
|
||||
data_dict[
|
||||
'energy_contours'] = self.pad_executor._prepare_scalar_inputs(
|
||||
[x['energy_contour'] for x in batch], max_input_length,
|
||||
0.0).float()
|
||||
|
||||
data_dict['utt_ids'] = [x['utt_id'] for x in batch]
|
||||
|
||||
return data_dict
|
||||
|
||||
|
||||
class KanTtsText2MelPad(object):
|
||||
|
||||
def __init__(self):
|
||||
super(KanTtsText2MelPad, self).__init__()
|
||||
pass
|
||||
|
||||
def _pad1D(self, x, length, pad):
|
||||
return np.pad(
|
||||
x, (0, length - x.shape[0]), mode='constant', constant_values=pad)
|
||||
|
||||
def _pad2D(self, x, length, pad):
|
||||
return np.pad(
|
||||
x, [(0, length - x.shape[0]), (0, 0)],
|
||||
mode='constant',
|
||||
constant_values=pad)
|
||||
|
||||
def _pad_durations(self, duration, max_in_len, max_out_len):
|
||||
framenum = np.sum(duration)
|
||||
symbolnum = duration.shape[0]
|
||||
if framenum < max_out_len:
|
||||
padframenum = max_out_len - framenum
|
||||
duration = np.insert(
|
||||
duration, symbolnum, values=padframenum, axis=0)
|
||||
duration = np.insert(
|
||||
duration,
|
||||
symbolnum + 1,
|
||||
values=[0] * (max_in_len - symbolnum - 1),
|
||||
axis=0)
|
||||
else:
|
||||
if symbolnum < max_in_len:
|
||||
duration = np.insert(
|
||||
duration,
|
||||
symbolnum,
|
||||
values=[0] * (max_in_len - symbolnum),
|
||||
axis=0)
|
||||
return duration
|
||||
|
||||
def _round_up(self, x, multiple):
|
||||
remainder = x % multiple
|
||||
return x if remainder == 0 else x + multiple - remainder
|
||||
|
||||
def _prepare_scalar_inputs(self, inputs, max_len, pad):
|
||||
return torch.from_numpy(
|
||||
np.stack([self._pad1D(x, max_len, pad) for x in inputs]))
|
||||
|
||||
def _prepare_targets(self, targets, max_len, pad):
|
||||
return torch.from_numpy(
|
||||
np.stack([self._pad2D(t, max_len, pad) for t in targets])).float()
|
||||
|
||||
def _prepare_durations(self, durations, max_in_len, max_out_len):
|
||||
return torch.from_numpy(
|
||||
np.stack([
|
||||
self._pad_durations(t, max_in_len, max_out_len)
|
||||
for t in durations
|
||||
])).long()
|
||||
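The removed KanTtsText2MelPad above rounds output lengths up to a multiple of the reduction factor r before padding mel targets. A small standalone sketch of that arithmetic (the helper below mirrors _round_up and is not part of the diff):

# Not part of the diff: rounding/padding sketch mirroring KanTtsText2MelPad._round_up and _pad2D.
import numpy as np


def round_up(x, multiple):
    remainder = x % multiple
    return x if remainder == 0 else x + multiple - remainder


mel = np.random.rand(123, 80)                          # (frames, num_mels)
target_len = round_up(mel.shape[0], 3)                 # r = outputs_per_step, e.g. 3 -> 126
padded = np.pad(mel, [(0, target_len - mel.shape[0]), (0, 0)], mode='constant')
print(padded.shape)                                    # (126, 80)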
@@ -1,131 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import math
|
||||
import random
|
||||
|
||||
import torch
|
||||
from torch import distributed as dist
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
class LenSortGroupPoolSampler(Sampler):
|
||||
|
||||
def __init__(self, data_source, length_lst, group_size):
|
||||
super(LenSortGroupPoolSampler, self).__init__(data_source)
|
||||
|
||||
self.data_source = data_source
|
||||
self.length_lst = length_lst
|
||||
self.group_size = group_size
|
||||
|
||||
self.num = len(self.length_lst)
|
||||
self.buckets = self.num // group_size
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
def getkey(item):
|
||||
return item[1]
|
||||
|
||||
random_lst = torch.randperm(self.num).tolist()
|
||||
random_len_lst = [(i, self.length_lst[i]) for i in random_lst]
|
||||
|
||||
# Bucket examples based on similar output sequence length for efficiency:
|
||||
groups = [
|
||||
random_len_lst[i:i + self.group_size]
|
||||
for i in range(0, self.num, self.group_size)
|
||||
]
|
||||
if (self.num % self.group_size):
|
||||
groups.append(random_len_lst[self.buckets * self.group_size:-1])
|
||||
|
||||
indices = []
|
||||
|
||||
for group in groups:
|
||||
group.sort(key=getkey, reverse=True)
|
||||
for item in group:
|
||||
indices.append(item[0])
|
||||
|
||||
return iter(indices)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_source)
|
||||
|
||||
|
||||
class DistributedLenSortGroupPoolSampler(Sampler):
|
||||
|
||||
def __init__(self,
|
||||
dataset,
|
||||
length_lst,
|
||||
group_size,
|
||||
num_replicas=None,
|
||||
rank=None,
|
||||
shuffle=True):
|
||||
super(DistributedLenSortGroupPoolSampler, self).__init__(dataset)
|
||||
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError(
|
||||
'modelscope error: Requires distributed package to be available'
|
||||
)
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError(
|
||||
'modelscope error: Requires distributed package to be available'
|
||||
)
|
||||
rank = dist.get_rank()
|
||||
self.dataset = dataset
|
||||
self.length_lst = length_lst
|
||||
self.group_size = group_size
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.epoch = 0
|
||||
self.num_samples = int(
|
||||
math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.buckets = self.num_samples // group_size
|
||||
self.shuffle = shuffle
|
||||
|
||||
def __iter__(self):
|
||||
|
||||
def getkey(item):
|
||||
return item[1]
|
||||
|
||||
# deterministically shuffle based on epoch
|
||||
g = torch.Generator()
|
||||
g.manual_seed(self.epoch)
|
||||
if self.shuffle:
|
||||
indices = torch.randperm(len(self.dataset), generator=g).tolist()
|
||||
else:
|
||||
indices = list(range(len(self.dataset)))
|
||||
|
||||
# add extra samples to make it evenly divisible
|
||||
indices += indices[:(self.total_size - len(indices))]
|
||||
assert len(indices) == self.total_size
|
||||
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
|
||||
random_len_lst = [(i, self.length_lst[i]) for i in indices]
|
||||
|
||||
# Bucket examples based on similar output sequence length for efficiency:
|
||||
groups = [
|
||||
random_len_lst[i:i + self.group_size]
|
||||
for i in range(0, self.num_samples, self.group_size)
|
||||
]
|
||||
if (self.num_samples % self.group_size):
|
||||
groups.append(random_len_lst[self.buckets * self.group_size:-1])
|
||||
|
||||
new_indices = []
|
||||
|
||||
for group in groups:
|
||||
group.sort(key=getkey, reverse=True)
|
||||
for item in group:
|
||||
new_indices.append(item[0])
|
||||
|
||||
return iter(new_indices)
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
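The removed samplers above shuffle indices and then sort within fixed-size groups by utterance length, so each batch holds similar-length examples. A standalone sketch of that bucketing idea (all values are placeholders):

# Not part of the diff: shuffle-then-sort-within-group bucketing sketch.
import random

lengths = [random.randint(50, 500) for _ in range(20)]    # frames per utterance
group_size = 8

order = list(range(len(lengths)))
random.shuffle(order)
groups = [order[i:i + group_size] for i in range(0, len(order), group_size)]
indices = []
for group in groups:
    group.sort(key=lambda i: lengths[i], reverse=True)     # longest first inside each group
    indices.extend(group)
print(indices)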
@@ -1,3 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .ling_unit import * # noqa F403
|
||||
@@ -1,3 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .hifigan import * # noqa F403
|
||||
@@ -1,238 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
# Part of the implementation is borrowed from https://github.com/jik876/hifi-gan
|
||||
|
||||
from distutils.version import LooseVersion
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import AvgPool1d, Conv1d, Conv2d, ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm, spectral_norm, weight_norm
|
||||
|
||||
from modelscope.models.audio.tts.models.utils import get_padding, init_weights
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
is_pytorch_17plus = LooseVersion(torch.__version__) >= LooseVersion('1.7')
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
|
||||
"""
|
||||
if is_pytorch_17plus:
|
||||
x_stft = torch.stft(
|
||||
x, fft_size, hop_size, win_length, window, return_complex=False)
|
||||
else:
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window)
|
||||
real = x_stft[..., 0]
|
||||
imag = x_stft[..., 1]
|
||||
|
||||
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
||||
return torch.sqrt(torch.clamp(real**2 + imag**2, min=1e-7)).transpose(2, 1)
|
||||
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
def get_padding_casual(kernel_size, dilation=1):
|
||||
return int(kernel_size * dilation - dilation)
|
||||
|
||||
|
||||
class Conv1dCasual(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=0,
|
||||
dilation=1,
|
||||
groups=1,
|
||||
bias=True,
|
||||
padding_mode='zeros'):
|
||||
super(Conv1dCasual, self).__init__()
|
||||
self.pad = padding
|
||||
self.conv1d = weight_norm(
|
||||
Conv1d(
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0,
|
||||
dilation=dilation,
|
||||
groups=groups,
|
||||
bias=bias,
|
||||
padding_mode=padding_mode))
|
||||
self.conv1d.apply(init_weights)
|
||||
|
||||
def forward(self, x): # bdt
|
||||
# F.pad pad widths are specified starting from the last dimension and moving forward.
|
||||
x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), 'constant')
|
||||
x = self.conv1d(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.conv1d)
|
||||
|
||||
|
||||
class ConvTranspose1dCausal(torch.nn.Module):
|
||||
"""CausalConvTranspose1d module with customized initialization."""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
padding=0):
|
||||
"""Initialize CausalConvTranspose1d module."""
|
||||
super(ConvTranspose1dCausal, self).__init__()
|
||||
self.deconv = weight_norm(
|
||||
ConvTranspose1d(in_channels, out_channels, kernel_size, stride))
|
||||
self.stride = stride
|
||||
self.deconv.apply(init_weights)
|
||||
self.pad = kernel_size - stride
|
||||
|
||||
def forward(self, x):
|
||||
"""Calculate forward propagation.
|
||||
Args:
|
||||
x (Tensor): Input tensor (B, in_channels, T_in).
|
||||
Returns:
|
||||
Tensor: Output tensor (B, out_channels, T_out).
|
||||
"""
|
||||
# x = F.pad(x, (self.pad, 0, 0, 0, 0, 0), "constant")
|
||||
return self.deconv(x)[:, :, :-self.pad]
|
||||
|
||||
def remove_weight_norm(self):
|
||||
remove_weight_norm(self.deconv)
|
||||
|
||||
|
||||
class ResBlock1(torch.nn.Module):
|
||||
|
||||
def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
|
||||
super(ResBlock1, self).__init__()
|
||||
self.h = h
|
||||
self.convs1 = nn.ModuleList([
|
||||
Conv1dCasual(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation[i],
|
||||
padding=get_padding_casual(kernel_size, dilation[i]))
|
||||
for i in range(len(dilation))
|
||||
])
|
||||
|
||||
self.convs2 = nn.ModuleList([
|
||||
Conv1dCasual(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding_casual(kernel_size, 1))
|
||||
for i in range(len(dilation))
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
for c1, c2 in zip(self.convs1, self.convs2):
|
||||
xt = F.leaky_relu(x, LRELU_SLOPE)
|
||||
xt = c1(xt)
|
||||
xt = F.leaky_relu(xt, LRELU_SLOPE)
|
||||
xt = c2(xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for layer in self.convs1:
|
||||
layer.remove_weight_norm()
|
||||
for layer in self.convs2:
|
||||
layer.remove_weight_norm()
|
||||
|
||||
|
||||
class Generator(torch.nn.Module):
|
||||
|
||||
def __init__(self, h):
|
||||
super(Generator, self).__init__()
|
||||
self.h = h
|
||||
self.num_kernels = len(h.resblock_kernel_sizes)
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
logger.info('num_kernels={}, num_upsamples={}'.format(
|
||||
self.num_kernels, self.num_upsamples))
|
||||
self.conv_pre = Conv1dCasual(
|
||||
80, h.upsample_initial_channel, 7, 1, padding=7 - 1)
|
||||
resblock = ResBlock1 if h.resblock == '1' else ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
self.repeat_ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(
|
||||
zip(h.upsample_rates, h.upsample_kernel_sizes)):
|
||||
upsample = nn.Sequential(
|
||||
nn.Upsample(mode='nearest', scale_factor=u),
|
||||
nn.LeakyReLU(LRELU_SLOPE),
|
||||
Conv1dCasual(
|
||||
h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
kernel_size=7,
|
||||
stride=1,
|
||||
padding=7 - 1))
|
||||
self.repeat_ups.append(upsample)
|
||||
self.ups.append(
|
||||
ConvTranspose1dCausal(
|
||||
h.upsample_initial_channel // (2**i),
|
||||
h.upsample_initial_channel // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2))
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2**(i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(h, ch, k, d))
|
||||
|
||||
self.conv_post = Conv1dCasual(ch, 1, 7, 1, padding=7 - 1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = torch.sin(x) + x
|
||||
# transconv
|
||||
x1 = F.leaky_relu(x, LRELU_SLOPE)
|
||||
x1 = self.ups[i](x1)
|
||||
# repeat
|
||||
x2 = self.repeat_ups[i](x)
|
||||
x = x1 + x2
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
x = torch.tanh(x)
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
logger.info('Removing weight norm...')
|
||||
for layer in self.ups:
|
||||
layer.remove_weight_norm()
|
||||
for layer in self.repeat_ups:
|
||||
layer[-1].remove_weight_norm()
|
||||
for layer in self.resblocks:
|
||||
layer.remove_weight_norm()
|
||||
self.conv_pre.remove_weight_norm()
|
||||
self.conv_post.remove_weight_norm()
|
||||
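Conv1dCasual in the removed generator above pads only on the left of the time axis by kernel_size * dilation - dilation samples, which preserves length without looking ahead. A standalone sketch of that padding arithmetic (plain torch modules, not the classes from the diff):

# Not part of the diff: left-only ("causal") padding sketch.
import torch
import torch.nn.functional as F

kernel_size, dilation = 3, 2
pad = kernel_size * dilation - dilation        # get_padding_casual -> 4
conv = torch.nn.Conv1d(1, 1, kernel_size, dilation=dilation, padding=0)

x = torch.randn(1, 1, 100)
y = conv(F.pad(x, (pad, 0)))                   # pad only on the left along time
print(y.shape)                                 # torch.Size([1, 1, 100]): length preserved, no lookahead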
@@ -1,3 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .kantts_sambert import * # noqa F403
|
||||
@@ -1,718 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from modelscope.models.audio.tts.models.utils import get_mask_from_lengths
|
||||
from .adaptors import (LengthRegulator, VarFsmnRnnNARPredictor,
|
||||
VarRnnARPredictor)
|
||||
from .base import FFTBlock, PNCABlock, Prenet
|
||||
from .fsmn import FsmnEncoderV2
|
||||
from .positions import DurSinusoidalPositionEncoder, SinusoidalPositionEncoder
|
||||
|
||||
|
||||
class SelfAttentionEncoder(nn.Module):
|
||||
|
||||
def __init__(self, n_layer, d_in, d_model, n_head, d_head, d_inner,
|
||||
dropout, dropout_att, dropout_relu, position_encoder):
|
||||
super(SelfAttentionEncoder, self).__init__()
|
||||
|
||||
self.d_in = d_in
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
d_in_lst = [d_in] + [d_model] * (n_layer - 1)
|
||||
self.fft = nn.ModuleList([
|
||||
FFTBlock(d, d_model, n_head, d_head, d_inner, (3, 1), dropout,
|
||||
dropout_att, dropout_relu) for d in d_in_lst
|
||||
])
|
||||
self.ln = nn.LayerNorm(d_model, eps=1e-6)
|
||||
self.position_enc = position_encoder
|
||||
|
||||
def forward(self, input, mask=None, return_attns=False):
|
||||
input *= self.d_model**0.5
|
||||
if (isinstance(self.position_enc, SinusoidalPositionEncoder)):
|
||||
input = self.position_enc(input)
|
||||
else:
|
||||
raise NotImplementedError('modelscope error: position_enc invalid')
|
||||
|
||||
input = F.dropout(input, p=self.dropout, training=self.training)
|
||||
|
||||
enc_slf_attn_list = []
|
||||
max_len = input.size(1)
|
||||
if mask is not None:
|
||||
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
|
||||
else:
|
||||
slf_attn_mask = None
|
||||
|
||||
enc_output = input
|
||||
for id, layer in enumerate(self.fft):
|
||||
enc_output, enc_slf_attn = layer(
|
||||
enc_output, mask=mask, slf_attn_mask=slf_attn_mask)
|
||||
if return_attns:
|
||||
enc_slf_attn_list += [enc_slf_attn]
|
||||
|
||||
enc_output = self.ln(enc_output)
|
||||
|
||||
return enc_output, enc_slf_attn_list
|
||||
|
||||
|
||||
class HybridAttentionDecoder(nn.Module):
|
||||
|
||||
def __init__(self, d_in, prenet_units, n_layer, d_model, d_mem, n_head,
|
||||
d_head, d_inner, dropout, dropout_att, dropout_relu, d_out):
|
||||
super(HybridAttentionDecoder, self).__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
self.dropout = dropout
|
||||
self.prenet = Prenet(d_in, prenet_units, d_model)
|
||||
self.dec_in_proj = nn.Linear(d_model + d_mem, d_model)
|
||||
self.pnca = nn.ModuleList([
|
||||
PNCABlock(d_model, d_mem, n_head, d_head, d_inner, (1, 1), dropout,
|
||||
dropout_att, dropout_relu) for _ in range(n_layer)
|
||||
])
|
||||
self.ln = nn.LayerNorm(d_model, eps=1e-6)
|
||||
self.dec_out_proj = nn.Linear(d_model, d_out)
|
||||
|
||||
def reset_state(self):
|
||||
for layer in self.pnca:
|
||||
layer.reset_state()
|
||||
|
||||
def get_pnca_attn_mask(self,
|
||||
device,
|
||||
max_len,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
mask=None):
|
||||
if mask is not None:
|
||||
pnca_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1)
|
||||
else:
|
||||
pnca_attn_mask = None
|
||||
|
||||
range_ = torch.arange(max_len).to(device)
|
||||
x_start = torch.clamp_min(range_ - x_band_width, 0)[None, None, :]
|
||||
x_end = (range_ + 1)[None, None, :]
|
||||
h_start = range_[None, None, :]
|
||||
h_end = torch.clamp_max(range_ + h_band_width + 1,
|
||||
max_len + 1)[None, None, :]
|
||||
|
||||
pnca_x_attn_mask = ~((x_start <= range_[None, :, None])
|
||||
& (x_end > range_[None, :, None])).transpose(1, 2) # yapf:disable
|
||||
pnca_h_attn_mask = ~((h_start <= range_[None, :, None])
|
||||
& (h_end > range_[None, :, None])).transpose(1, 2) # yapf:disable
|
||||
|
||||
if pnca_attn_mask is not None:
|
||||
pnca_x_attn_mask = (pnca_x_attn_mask | pnca_attn_mask)
|
||||
pnca_h_attn_mask = (pnca_h_attn_mask | pnca_attn_mask)
|
||||
pnca_x_attn_mask = pnca_x_attn_mask.masked_fill(
|
||||
pnca_attn_mask.transpose(1, 2), False)
|
||||
pnca_h_attn_mask = pnca_h_attn_mask.masked_fill(
|
||||
pnca_attn_mask.transpose(1, 2), False)
|
||||
|
||||
return pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask
|
||||
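# Not part of the diff: standalone sketch of the banded attention masks built by
# get_pnca_attn_mask above. Each decoder step may attend to itself plus x_band_width
# previous steps (x mask) and itself plus h_band_width future memory steps (h mask).
#
#   import torch
#
#   max_len, x_band_width, h_band_width = 6, 2, 2
#   r = torch.arange(max_len)
#
#   x_start = torch.clamp_min(r - x_band_width, 0)[None, None, :]
#   x_end = (r + 1)[None, None, :]
#   h_start = r[None, None, :]
#   h_end = torch.clamp_max(r + h_band_width + 1, max_len + 1)[None, None, :]
#
#   # True marks positions a query step is NOT allowed to attend to.
#   x_mask = ~((x_start <= r[None, :, None]) & (x_end > r[None, :, None])).transpose(1, 2)
#   h_mask = ~((h_start <= r[None, :, None]) & (h_end > r[None, :, None])).transpose(1, 2)
#   print(x_mask[0].int())   # row q attends to steps [q - x_band_width, q]
#   print(h_mask[0].int())   # row q attends to memory steps [q, q + h_band_width]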
|
||||
# must call reset_state before
|
||||
def forward(self,
|
||||
input,
|
||||
memory,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
mask=None,
|
||||
return_attns=False):
|
||||
input = self.prenet(input)
|
||||
input = torch.cat([memory, input], dim=-1)
|
||||
input = self.dec_in_proj(input)
|
||||
|
||||
if mask is not None:
|
||||
input = input.masked_fill(mask.unsqueeze(-1), 0)
|
||||
|
||||
input *= self.d_model**0.5
|
||||
input = F.dropout(input, p=self.dropout, training=self.training)
|
||||
|
||||
max_len = input.size(1)
|
||||
pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
|
||||
input.device, max_len, x_band_width, h_band_width, mask)
|
||||
|
||||
dec_pnca_attn_x_list = []
|
||||
dec_pnca_attn_h_list = []
|
||||
dec_output = input
|
||||
for id, layer in enumerate(self.pnca):
|
||||
dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
|
||||
dec_output,
|
||||
memory,
|
||||
mask=mask,
|
||||
pnca_x_attn_mask=pnca_x_attn_mask,
|
||||
pnca_h_attn_mask=pnca_h_attn_mask)
|
||||
if return_attns:
|
||||
dec_pnca_attn_x_list += [dec_pnca_attn_x]
|
||||
dec_pnca_attn_h_list += [dec_pnca_attn_h]
|
||||
|
||||
dec_output = self.ln(dec_output)
|
||||
dec_output = self.dec_out_proj(dec_output)
|
||||
|
||||
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
|
||||
|
||||
# must call reset_state before when step == 0
|
||||
def infer(self,
|
||||
step,
|
||||
input,
|
||||
memory,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
mask=None,
|
||||
return_attns=False):
|
||||
max_len = memory.size(1)
|
||||
|
||||
input = self.prenet(input)
|
||||
input = torch.cat([memory[:, step:step + 1, :], input], dim=-1)
|
||||
input = self.dec_in_proj(input)
|
||||
|
||||
input *= self.d_model**0.5
|
||||
input = F.dropout(input, p=self.dropout, training=self.training)
|
||||
|
||||
pnca_attn_mask, pnca_x_attn_mask, pnca_h_attn_mask = self.get_pnca_attn_mask(
|
||||
input.device, max_len, x_band_width, h_band_width, mask)
|
||||
|
||||
dec_pnca_attn_x_list = []
|
||||
dec_pnca_attn_h_list = []
|
||||
dec_output = input
|
||||
for id, layer in enumerate(self.pnca):
|
||||
if mask is not None:
|
||||
mask_step = mask[:, step:step + 1]
|
||||
else:
|
||||
mask_step = None
|
||||
dec_output, dec_pnca_attn_x, dec_pnca_attn_h = layer(
|
||||
dec_output,
|
||||
memory,
|
||||
mask=mask_step,
|
||||
pnca_x_attn_mask=pnca_x_attn_mask[:,
|
||||
step:step + 1, :(step + 1)],
|
||||
pnca_h_attn_mask=pnca_h_attn_mask[:, step:step + 1, :])
|
||||
if return_attns:
|
||||
dec_pnca_attn_x_list += [dec_pnca_attn_x]
|
||||
dec_pnca_attn_h_list += [dec_pnca_attn_h]
|
||||
|
||||
dec_output = self.ln(dec_output)
|
||||
dec_output = self.dec_out_proj(dec_output)
|
||||
|
||||
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
|
||||
|
||||
|
||||
class TextFftEncoder(nn.Module):
|
||||
|
||||
def __init__(self, config, ling_unit_size):
|
||||
super(TextFftEncoder, self).__init__()
|
||||
|
||||
# linguistic unit lookup table
|
||||
nb_ling_sy = ling_unit_size['sy']
|
||||
nb_ling_tone = ling_unit_size['tone']
|
||||
nb_ling_syllable_flag = ling_unit_size['syllable_flag']
|
||||
nb_ling_ws = ling_unit_size['word_segment']
|
||||
|
||||
max_len = config['am']['max_len']
|
||||
|
||||
d_emb = config['am']['embedding_dim']
|
||||
nb_layers = config['am']['encoder_num_layers']
|
||||
nb_heads = config['am']['encoder_num_heads']
|
||||
d_model = config['am']['encoder_num_units']
|
||||
d_head = d_model // nb_heads
|
||||
d_inner = config['am']['encoder_ffn_inner_dim']
|
||||
dropout = config['am']['encoder_dropout']
|
||||
dropout_attn = config['am']['encoder_attention_dropout']
|
||||
dropout_relu = config['am']['encoder_relu_dropout']
|
||||
d_proj = config['am']['encoder_projection_units']
|
||||
|
||||
self.d_model = d_model
|
||||
|
||||
self.sy_emb = nn.Embedding(nb_ling_sy, d_emb)
|
||||
self.tone_emb = nn.Embedding(nb_ling_tone, d_emb)
|
||||
self.syllable_flag_emb = nn.Embedding(nb_ling_syllable_flag, d_emb)
|
||||
self.ws_emb = nn.Embedding(nb_ling_ws, d_emb)
|
||||
|
||||
position_enc = SinusoidalPositionEncoder(max_len, d_emb)
|
||||
|
||||
self.ling_enc = SelfAttentionEncoder(nb_layers, d_emb, d_model,
|
||||
nb_heads, d_head, d_inner,
|
||||
dropout, dropout_attn,
|
||||
dropout_relu, position_enc)
|
||||
|
||||
self.ling_proj = nn.Linear(d_model, d_proj, bias=False)
|
||||
|
||||
def forward(self, inputs_ling, masks=None, return_attns=False):
|
||||
# Parse inputs_ling_seq
|
||||
inputs_sy = inputs_ling[:, :, 0]
|
||||
inputs_tone = inputs_ling[:, :, 1]
|
||||
inputs_syllable_flag = inputs_ling[:, :, 2]
|
||||
inputs_ws = inputs_ling[:, :, 3]
|
||||
|
||||
# Lookup table
|
||||
sy_embedding = self.sy_emb(inputs_sy)
|
||||
tone_embedding = self.tone_emb(inputs_tone)
|
||||
syllable_flag_embedding = self.syllable_flag_emb(inputs_syllable_flag)
|
||||
ws_embedding = self.ws_emb(inputs_ws)
|
||||
|
||||
ling_embedding = sy_embedding + tone_embedding + syllable_flag_embedding + ws_embedding
|
||||
|
||||
enc_output, enc_slf_attn_list = self.ling_enc(ling_embedding, masks,
|
||||
return_attns)
|
||||
|
||||
enc_output = self.ling_proj(enc_output)
|
||||
|
||||
return enc_output, enc_slf_attn_list
|
||||
|
||||
|
||||
class VarianceAdaptor(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(VarianceAdaptor, self).__init__()
|
||||
|
||||
input_dim = config['am']['encoder_projection_units'] + config['am'][
|
||||
'emotion_units'] + config['am']['speaker_units']
|
||||
filter_size = config['am']['predictor_filter_size']
|
||||
fsmn_num_layers = config['am']['predictor_fsmn_num_layers']
|
||||
num_memory_units = config['am']['predictor_num_memory_units']
|
||||
ffn_inner_dim = config['am']['predictor_ffn_inner_dim']
|
||||
dropout = config['am']['predictor_dropout']
|
||||
shift = config['am']['predictor_shift']
|
||||
lstm_units = config['am']['predictor_lstm_units']
|
||||
|
||||
dur_pred_prenet_units = config['am']['dur_pred_prenet_units']
|
||||
dur_pred_lstm_units = config['am']['dur_pred_lstm_units']
|
||||
|
||||
self.pitch_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size,
|
||||
fsmn_num_layers,
|
||||
num_memory_units,
|
||||
ffn_inner_dim, dropout,
|
||||
shift, lstm_units)
|
||||
self.energy_predictor = VarFsmnRnnNARPredictor(input_dim, filter_size,
|
||||
fsmn_num_layers,
|
||||
num_memory_units,
|
||||
ffn_inner_dim, dropout,
|
||||
shift, lstm_units)
|
||||
self.duration_predictor = VarRnnARPredictor(input_dim,
|
||||
dur_pred_prenet_units,
|
||||
dur_pred_lstm_units)
|
||||
|
||||
self.length_regulator = LengthRegulator(
|
||||
config['am']['outputs_per_step'])
|
||||
self.dur_position_encoder = DurSinusoidalPositionEncoder(
|
||||
config['am']['encoder_projection_units'],
|
||||
config['am']['outputs_per_step'])
|
||||
|
||||
self.pitch_emb = nn.Conv1d(
|
||||
1,
|
||||
config['am']['encoder_projection_units'],
|
||||
kernel_size=9,
|
||||
padding=4)
|
||||
self.energy_emb = nn.Conv1d(
|
||||
1,
|
||||
config['am']['encoder_projection_units'],
|
||||
kernel_size=9,
|
||||
padding=4)
|
||||
|
||||
def forward(self,
|
||||
inputs_text_embedding,
|
||||
inputs_emo_embedding,
|
||||
inputs_spk_embedding,
|
||||
masks=None,
|
||||
output_masks=None,
|
||||
duration_targets=None,
|
||||
pitch_targets=None,
|
||||
energy_targets=None):
|
||||
|
||||
batch_size = inputs_text_embedding.size(0)
|
||||
|
||||
variance_predictor_inputs = torch.cat([
|
||||
inputs_text_embedding, inputs_spk_embedding, inputs_emo_embedding
|
||||
], dim=-1) # yapf:disable
|
||||
|
||||
pitch_predictions = self.pitch_predictor(variance_predictor_inputs,
|
||||
masks)
|
||||
energy_predictions = self.energy_predictor(variance_predictor_inputs,
|
||||
masks)
|
||||
|
||||
if pitch_targets is not None:
|
||||
pitch_embeddings = self.pitch_emb(
|
||||
pitch_targets.unsqueeze(1)).transpose(1, 2)
|
||||
else:
|
||||
pitch_embeddings = self.pitch_emb(
|
||||
pitch_predictions.unsqueeze(1)).transpose(1, 2)
|
||||
|
||||
if energy_targets is not None:
|
||||
energy_embeddings = self.energy_emb(
|
||||
energy_targets.unsqueeze(1)).transpose(1, 2)
|
||||
else:
|
||||
energy_embeddings = self.energy_emb(
|
||||
energy_predictions.unsqueeze(1)).transpose(1, 2)
|
||||
|
||||
inputs_text_embedding_aug = inputs_text_embedding + pitch_embeddings + energy_embeddings
|
||||
duration_predictor_cond = torch.cat([
|
||||
inputs_text_embedding_aug, inputs_spk_embedding,
|
||||
inputs_emo_embedding
|
||||
], dim=-1) # yapf:disable
|
||||
if duration_targets is not None:
|
||||
duration_predictor_go_frame = torch.zeros(batch_size, 1).to(
|
||||
inputs_text_embedding.device)
|
||||
duration_predictor_input = torch.cat([
|
||||
duration_predictor_go_frame, duration_targets[:, :-1].float()
|
||||
], dim=-1) # yapf:disable
|
||||
duration_predictor_input = torch.log(duration_predictor_input + 1)
|
||||
log_duration_predictions, _ = self.duration_predictor(
|
||||
duration_predictor_input.unsqueeze(-1),
|
||||
duration_predictor_cond,
|
||||
masks=masks)
|
||||
duration_predictions = torch.exp(log_duration_predictions) - 1
|
||||
else:
|
||||
log_duration_predictions = self.duration_predictor.infer(
|
||||
duration_predictor_cond, masks=masks)
|
||||
duration_predictions = torch.exp(log_duration_predictions) - 1
|
||||
|
||||
if duration_targets is not None:
|
||||
LR_text_outputs, LR_length_rounded = self.length_regulator(
|
||||
inputs_text_embedding_aug,
|
||||
duration_targets,
|
||||
masks=output_masks)
|
||||
LR_position_embeddings = self.dur_position_encoder(
|
||||
duration_targets, masks=output_masks)
|
||||
LR_emo_outputs, _ = self.length_regulator(
|
||||
inputs_emo_embedding, duration_targets, masks=output_masks)
|
||||
LR_spk_outputs, _ = self.length_regulator(
|
||||
inputs_spk_embedding, duration_targets, masks=output_masks)
|
||||
|
||||
else:
|
||||
LR_text_outputs, LR_length_rounded = self.length_regulator(
|
||||
inputs_text_embedding_aug,
|
||||
duration_predictions,
|
||||
masks=output_masks)
|
||||
LR_position_embeddings = self.dur_position_encoder(
|
||||
duration_predictions, masks=output_masks)
|
||||
LR_emo_outputs, _ = self.length_regulator(
|
||||
inputs_emo_embedding, duration_predictions, masks=output_masks)
|
||||
LR_spk_outputs, _ = self.length_regulator(
|
||||
inputs_spk_embedding, duration_predictions, masks=output_masks)
|
||||
|
||||
LR_text_outputs = LR_text_outputs + LR_position_embeddings
|
||||
|
||||
return (LR_text_outputs, LR_emo_outputs, LR_spk_outputs,
|
||||
LR_length_rounded, log_duration_predictions, pitch_predictions,
|
||||
energy_predictions)
|
||||
|
||||
|
||||
class MelPNCADecoder(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(MelPNCADecoder, self).__init__()
|
||||
|
||||
prenet_units = config['am']['decoder_prenet_units']
|
||||
nb_layers = config['am']['decoder_num_layers']
|
||||
nb_heads = config['am']['decoder_num_heads']
|
||||
d_model = config['am']['decoder_num_units']
|
||||
d_head = d_model // nb_heads
|
||||
d_inner = config['am']['decoder_ffn_inner_dim']
|
||||
dropout = config['am']['decoder_dropout']
|
||||
dropout_attn = config['am']['decoder_attention_dropout']
|
||||
dropout_relu = config['am']['decoder_relu_dropout']
|
||||
outputs_per_step = config['am']['outputs_per_step']
|
||||
|
||||
d_mem = config['am'][
|
||||
'encoder_projection_units'] * outputs_per_step + config['am'][
|
||||
'emotion_units'] + config['am']['speaker_units']
|
||||
d_mel = config['am']['num_mels']
|
||||
|
||||
self.d_mel = d_mel
|
||||
self.r = outputs_per_step
|
||||
self.nb_layers = nb_layers
|
||||
|
||||
self.mel_dec = HybridAttentionDecoder(d_mel, prenet_units, nb_layers,
|
||||
d_model, d_mem, nb_heads, d_head,
|
||||
d_inner, dropout, dropout_attn,
|
||||
dropout_relu,
|
||||
d_mel * outputs_per_step)
|
||||
|
||||
def forward(self,
|
||||
memory,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
target=None,
|
||||
mask=None,
|
||||
return_attns=False):
|
||||
batch_size = memory.size(0)
|
||||
go_frame = torch.zeros((batch_size, 1, self.d_mel)).to(memory.device)
|
||||
|
||||
if target is not None:
|
||||
self.mel_dec.reset_state()
|
||||
input = target[:, self.r - 1::self.r, :]
|
||||
input = torch.cat([go_frame, input], dim=1)[:, :-1, :]
|
||||
dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list = self.mel_dec(
|
||||
input,
|
||||
memory,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
mask=mask,
|
||||
return_attns=return_attns)
|
||||
|
||||
else:
|
||||
dec_output = []
|
||||
dec_pnca_attn_x_list = [[] for _ in range(self.nb_layers)]
|
||||
dec_pnca_attn_h_list = [[] for _ in range(self.nb_layers)]
|
||||
self.mel_dec.reset_state()
|
||||
input = go_frame
|
||||
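# Autoregressive inference: decode one output group per step and feed the
# last d_mel channels of the prediction back as the next input.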
for step in range(memory.size(1)):
|
||||
dec_output_step, dec_pnca_attn_x_step, dec_pnca_attn_h_step = self.mel_dec.infer(
|
||||
step,
|
||||
input,
|
||||
memory,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
mask=mask,
|
||||
return_attns=return_attns)
|
||||
input = dec_output_step[:, :, -self.d_mel:]
|
||||
|
||||
dec_output.append(dec_output_step)
|
||||
for layer_id, (pnca_x_attn, pnca_h_attn) in enumerate(
|
||||
zip(dec_pnca_attn_x_step, dec_pnca_attn_h_step)):
|
||||
left = memory.size(1) - pnca_x_attn.size(-1)
|
||||
if (left > 0):
|
||||
padding = torch.zeros(
|
||||
(pnca_x_attn.size(0), 1, left)).to(pnca_x_attn)
|
||||
pnca_x_attn = torch.cat([pnca_x_attn, padding], dim=-1)
|
||||
dec_pnca_attn_x_list[layer_id].append(pnca_x_attn)
|
||||
dec_pnca_attn_h_list[layer_id].append(pnca_h_attn)
|
||||
|
||||
dec_output = torch.cat(dec_output, dim=1)
|
||||
for layer_id in range(self.nb_layers):
|
||||
dec_pnca_attn_x_list[layer_id] = torch.cat(
|
||||
dec_pnca_attn_x_list[layer_id], dim=1)
|
||||
dec_pnca_attn_h_list[layer_id] = torch.cat(
|
||||
dec_pnca_attn_h_list[layer_id], dim=1)
|
||||
|
||||
return dec_output, dec_pnca_attn_x_list, dec_pnca_attn_h_list
|
||||
|
||||
|
||||
class PostNet(nn.Module):
|
||||
|
||||
def __init__(self, config):
|
||||
super(PostNet, self).__init__()
|
||||
|
||||
self.filter_size = config['am']['postnet_filter_size']
|
||||
self.fsmn_num_layers = config['am']['postnet_fsmn_num_layers']
|
||||
self.num_memory_units = config['am']['postnet_num_memory_units']
|
||||
self.ffn_inner_dim = config['am']['postnet_ffn_inner_dim']
|
||||
self.dropout = config['am']['postnet_dropout']
|
||||
self.shift = config['am']['postnet_shift']
|
||||
self.lstm_units = config['am']['postnet_lstm_units']
|
||||
self.num_mels = config['am']['num_mels']
|
||||
|
||||
self.fsmn = FsmnEncoderV2(self.filter_size, self.fsmn_num_layers,
|
||||
self.num_mels, self.num_memory_units,
|
||||
self.ffn_inner_dim, self.dropout, self.shift)
|
||||
self.lstm = nn.LSTM(
|
||||
self.num_memory_units,
|
||||
self.lstm_units,
|
||||
num_layers=1,
|
||||
batch_first=True)
|
||||
self.fc = nn.Linear(self.lstm_units, self.num_mels)
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
postnet_fsmn_output = self.fsmn(x, mask)
|
||||
# The input can also be a packed variable length sequence,
|
||||
# here we just omit it for simplicity due to the mask and the uni-directional LSTM.
|
||||
postnet_lstm_output, _ = self.lstm(postnet_fsmn_output)
|
||||
mel_residual_output = self.fc(postnet_lstm_output)
|
||||
|
||||
return mel_residual_output
|
||||
|
||||
|
||||
def mel_recon_loss_fn(output_lengths,
|
||||
mel_targets,
|
||||
dec_outputs,
|
||||
postnet_outputs=None):
|
||||
mae_loss = nn.L1Loss(reduction='none')
|
||||
|
||||
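# Mask out padded frames so the L1 terms are averaged over valid positions only.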
output_masks = get_mask_from_lengths(
|
||||
output_lengths, max_len=mel_targets.size(1))
|
||||
output_masks = ~output_masks
|
||||
valid_outputs = output_masks.sum()
|
||||
|
||||
mel_loss_ = torch.sum(
|
||||
mae_loss(mel_targets, dec_outputs) * output_masks.unsqueeze(-1)) / (
|
||||
valid_outputs * mel_targets.size(-1))
|
||||
|
||||
if postnet_outputs is not None:
|
||||
mel_loss = torch.sum(
|
||||
mae_loss(mel_targets, postnet_outputs)
|
||||
* output_masks.unsqueeze(-1)) / (
|
||||
valid_outputs * mel_targets.size(-1))
|
||||
else:
|
||||
mel_loss = 0.0
|
||||
|
||||
return mel_loss_, mel_loss
|
||||
|
||||
|
||||
def prosody_recon_loss_fn(input_lengths, duration_targets, pitch_targets,
|
||||
energy_targets, log_duration_predictions,
|
||||
pitch_predictions, energy_predictions):
|
||||
mae_loss = nn.L1Loss(reduction='none')
|
||||
|
||||
input_masks = get_mask_from_lengths(
|
||||
input_lengths, max_len=duration_targets.size(1))
|
||||
input_masks = ~input_masks
|
||||
valid_inputs = input_masks.sum()
|
||||
|
||||
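# Duration loss is computed in the log(1 + d) domain to match the predictor's output scale.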
dur_loss = torch.sum(
|
||||
mae_loss(
|
||||
torch.log(duration_targets.float() + 1), log_duration_predictions)
|
||||
* input_masks) / valid_inputs
|
||||
pitch_loss = torch.sum(
|
||||
mae_loss(pitch_targets, pitch_predictions)
|
||||
* input_masks) / valid_inputs
|
||||
energy_loss = torch.sum(
|
||||
mae_loss(energy_targets, energy_predictions)
|
||||
* input_masks) / valid_inputs
|
||||
|
||||
return dur_loss, pitch_loss, energy_loss
|
||||
|
||||
|
||||
class KanTtsSAMBERT(nn.Module):
|
||||
|
||||
def __init__(self, config, ling_unit_size):
|
||||
super(KanTtsSAMBERT, self).__init__()
|
||||
|
||||
self.text_encoder = TextFftEncoder(config, ling_unit_size)
|
||||
self.spk_tokenizer = nn.Embedding(ling_unit_size['speaker'],
|
||||
config['am']['speaker_units'])
|
||||
self.emo_tokenizer = nn.Embedding(ling_unit_size['emotion'],
|
||||
config['am']['emotion_units'])
|
||||
self.variance_adaptor = VarianceAdaptor(config)
|
||||
self.mel_decoder = MelPNCADecoder(config)
|
||||
self.mel_postnet = PostNet(config)
|
||||
|
||||
def get_lfr_mask_from_lengths(self, lengths, max_len):
|
||||
batch_size = lengths.size(0)
|
||||
# padding according to the outputs_per_step
|
||||
padded_lr_lengths = torch.zeros_like(lengths)
|
||||
for i in range(batch_size):
|
||||
len_item = int(lengths[i].item())
|
||||
padding = self.mel_decoder.r - len_item % self.mel_decoder.r
|
||||
if (padding < self.mel_decoder.r):
|
||||
padded_lr_lengths[i] = (len_item
|
||||
+ padding) // self.mel_decoder.r
|
||||
else:
|
||||
padded_lr_lengths[i] = len_item // self.mel_decoder.r
|
||||
|
||||
return get_mask_from_lengths(
|
||||
padded_lr_lengths, max_len=max_len // self.mel_decoder.r)
|
||||
|
||||
def forward(self,
|
||||
inputs_ling,
|
||||
inputs_emotion,
|
||||
inputs_speaker,
|
||||
input_lengths,
|
||||
output_lengths=None,
|
||||
mel_targets=None,
|
||||
duration_targets=None,
|
||||
pitch_targets=None,
|
||||
energy_targets=None):
|
||||
|
||||
batch_size = inputs_ling.size(0)
|
||||
|
||||
input_masks = get_mask_from_lengths(
|
||||
input_lengths, max_len=inputs_ling.size(1))
|
||||
|
||||
text_hid, enc_sla_attn_lst = self.text_encoder(
|
||||
inputs_ling, input_masks, return_attns=True)
|
||||
|
||||
emo_hid = self.emo_tokenizer(inputs_emotion)
|
||||
spk_hid = self.spk_tokenizer(inputs_speaker)
|
||||
|
||||
if output_lengths is not None:
|
||||
output_masks = get_mask_from_lengths(
|
||||
output_lengths, max_len=mel_targets.size(1))
|
||||
else:
|
||||
output_masks = None
|
||||
|
||||
(LR_text_outputs, LR_emo_outputs, LR_spk_outputs, LR_length_rounded,
|
||||
log_duration_predictions, pitch_predictions,
|
||||
energy_predictions) = self.variance_adaptor(
|
||||
text_hid,
|
||||
emo_hid,
|
||||
spk_hid,
|
||||
masks=input_masks,
|
||||
output_masks=output_masks,
|
||||
duration_targets=duration_targets,
|
||||
pitch_targets=pitch_targets,
|
||||
energy_targets=energy_targets)
|
||||
|
||||
if output_lengths is not None:
|
||||
lfr_masks = self.get_lfr_mask_from_lengths(
|
||||
output_lengths, max_len=LR_text_outputs.size(1))
|
||||
else:
|
||||
output_masks = get_mask_from_lengths(
|
||||
LR_length_rounded, max_len=LR_text_outputs.size(1))
|
||||
lfr_masks = None
|
||||
|
||||
# LFR with the factor of outputs_per_step
|
||||
LFR_text_inputs = LR_text_outputs.contiguous().view(
|
||||
batch_size, -1, self.mel_decoder.r * text_hid.shape[-1])
|
||||
LFR_emo_inputs = LR_emo_outputs.contiguous().view(
|
||||
batch_size, -1,
|
||||
self.mel_decoder.r * emo_hid.shape[-1])[:, :, :emo_hid.shape[-1]]
|
||||
LFR_spk_inputs = LR_spk_outputs.contiguous().view(
|
||||
batch_size, -1,
|
||||
self.mel_decoder.r * spk_hid.shape[-1])[:, :, :spk_hid.shape[-1]]
|
||||
|
||||
memory = torch.cat([LFR_text_inputs, LFR_spk_inputs, LFR_emo_inputs],
|
||||
dim=-1)
|
||||
|
||||
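# The PNCA band width is derived from the longest duration (target or predicted), measured in reduced (LFR) frames.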
if duration_targets is not None:
|
||||
x_band_width = int(
|
||||
duration_targets.float().masked_fill(input_masks, 0).max()
|
||||
/ self.mel_decoder.r + 0.5)
|
||||
h_band_width = x_band_width
|
||||
else:
|
||||
x_band_width = int((torch.exp(log_duration_predictions) - 1).max()
|
||||
/ self.mel_decoder.r + 0.5)
|
||||
h_band_width = x_band_width
|
||||
|
||||
dec_outputs, pnca_x_attn_lst, pnca_h_attn_lst = self.mel_decoder(
|
||||
memory,
|
||||
x_band_width,
|
||||
h_band_width,
|
||||
target=mel_targets,
|
||||
mask=lfr_masks,
|
||||
return_attns=True)
|
||||
|
||||
# De-LFR with the factor of outputs_per_step
|
||||
dec_outputs = dec_outputs.contiguous().view(batch_size, -1,
|
||||
self.mel_decoder.d_mel)
|
||||
|
||||
if output_masks is not None:
|
||||
dec_outputs = dec_outputs.masked_fill(
|
||||
output_masks.unsqueeze(-1), 0)
|
||||
|
||||
postnet_outputs = self.mel_postnet(dec_outputs,
|
||||
output_masks) + dec_outputs
|
||||
if output_masks is not None:
|
||||
postnet_outputs = postnet_outputs.masked_fill(
|
||||
output_masks.unsqueeze(-1), 0)
|
||||
|
||||
res = {
|
||||
'x_band_width': x_band_width,
|
||||
'h_band_width': h_band_width,
|
||||
'enc_slf_attn_lst': enc_sla_attn_lst,
|
||||
'pnca_x_attn_lst': pnca_x_attn_lst,
|
||||
'pnca_h_attn_lst': pnca_h_attn_lst,
|
||||
'dec_outputs': dec_outputs,
|
||||
'postnet_outputs': postnet_outputs,
|
||||
'LR_length_rounded': LR_length_rounded,
|
||||
'log_duration_predictions': log_duration_predictions,
|
||||
'pitch_predictions': pitch_predictions,
|
||||
'energy_predictions': energy_predictions
|
||||
}
|
||||
|
||||
res['LR_text_outputs'] = LR_text_outputs
|
||||
res['LR_emo_outputs'] = LR_emo_outputs
|
||||
res['LR_spk_outputs'] = LR_spk_outputs
|
||||
|
||||
return res
|
||||
@@ -1,3 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
from .utils import * # noqa F403
|
||||
@@ -1,136 +0,0 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import glob
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pylab as plt
|
||||
import torch
|
||||
|
||||
matplotlib.use('Agg')
|
||||
|
||||
|
||||
class AttrDict(dict):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(AttrDict, self).__init__(*args, **kwargs)
|
||||
self.__dict__ = self
|
||||
|
||||
|
||||
def build_env(config, config_name, path):
|
||||
t_path = os.path.join(path, config_name)
|
||||
if config != t_path:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
shutil.copyfile(config, os.path.join(path, config_name))
|
||||
|
||||
|
||||
def plot_spectrogram(spectrogram):
|
||||
fig, ax = plt.subplots(figsize=(10, 2))
|
||||
im = ax.imshow(
|
||||
spectrogram, aspect='auto', origin='lower', interpolation='none')
|
||||
plt.colorbar(im, ax=ax)
|
||||
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def plot_alignment(alignment, info=None):
|
||||
fig, ax = plt.subplots()
|
||||
im = ax.imshow(
|
||||
alignment, aspect='auto', origin='lower', interpolation='none')
|
||||
fig.colorbar(im, ax=ax)
|
||||
xlabel = 'Input timestep'
|
||||
if info is not None:
|
||||
xlabel += '\t' + info
|
||||
plt.xlabel(xlabel)
|
||||
plt.ylabel('Output timestep')
|
||||
fig.canvas.draw()
|
||||
plt.close()
|
||||
|
||||
return fig
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
assert os.path.isfile(filepath)
|
||||
checkpoint_dict = torch.load(filepath, map_location=device)
|
||||
return checkpoint_dict
|
||||
|
||||
|
||||
def save_checkpoint(filepath, obj):
|
||||
torch.save(obj, filepath)
|
||||
|
||||
|
||||
def scan_checkpoint(cp_dir, prefix):
|
||||
pattern = os.path.join(cp_dir, prefix + '????????.pkl')
|
||||
cp_list = glob.glob(pattern)
|
||||
if len(cp_list) == 0:
|
||||
return None
|
||||
return sorted(cp_list)[-1]
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
class ValueWindow():
|
||||
|
||||
def __init__(self, window_size=100):
|
||||
self._window_size = window_size
|
||||
self._values = []
|
||||
|
||||
def append(self, x):
|
||||
self._values = self._values[-(self._window_size - 1):] + [x]
|
||||
|
||||
@property
|
||||
def sum(self):
|
||||
return sum(self._values)
|
||||
|
||||
@property
|
||||
def count(self):
|
||||
return len(self._values)
|
||||
|
||||
@property
|
||||
def average(self):
|
||||
return self.sum / max(1, self.count)
|
||||
|
||||
def reset(self):
|
||||
self._values = []
|
||||
|
||||
|
||||
def get_model_size(model):
|
||||
param_num = sum([p.numel() for p in model.parameters() if p.requires_grad])
|
||||
param_size = param_num * 4 / 1024 / 1024
|
||||
return param_size
|
||||
|
||||
|
||||
def get_grad_norm(model):
|
||||
total_norm = 0
|
||||
params = [
|
||||
p for p in model.parameters() if p.grad is not None and p.requires_grad
|
||||
]
|
||||
for p in params:
|
||||
param_norm = p.grad.detach().data.norm(2)
|
||||
total_norm += param_norm.item()**2
|
||||
total_norm = total_norm**0.5
|
||||
return total_norm
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find('Conv') != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
def get_mask_from_lengths(lengths, max_len=None):
|
||||
batch_size = lengths.shape[0]
|
||||
if max_len is None:
|
||||
max_len = torch.max(lengths).item()
|
||||
|
||||
ids = torch.arange(0, max_len).unsqueeze(0).expand(batch_size,
|
||||
-1).to(lengths.device)
|
||||
mask = ids >= lengths.unsqueeze(1).expand(-1, max_len)
|
||||
|
||||
return mask
|
||||
@@ -2,24 +2,31 @@
|
||||
|
||||
from __future__ import (absolute_import, division, print_function,
|
||||
unicode_literals)
|
||||
import datetime
|
||||
import os
|
||||
import shutil
|
||||
import zipfile
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import yaml
|
||||
|
||||
from modelscope.metainfo import Models
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.models.builder import MODELS
|
||||
from modelscope.utils.audio.audio_utils import TtsTrainType
|
||||
from modelscope.utils.audio.tts_exceptions import (
|
||||
TtsFrontendInitializeFailedException,
|
||||
TtsFrontendLanguageTypeInvalidException, TtsModelConfigurationException,
|
||||
TtsVoiceNotExistsException)
|
||||
from modelscope.utils.constant import Tasks
|
||||
from modelscope.utils.logger import get_logger
|
||||
from .voice import Voice
|
||||
|
||||
__all__ = ['SambertHifigan']
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@MODELS.register_module(
|
||||
Tasks.text_to_speech, module_name=Models.sambert_hifigan)
|
||||
@@ -27,18 +34,12 @@ class SambertHifigan(Model):
|
||||
|
||||
def __init__(self, model_dir, *args, **kwargs):
|
||||
super().__init__(model_dir, *args, **kwargs)
|
||||
if 'am' not in kwargs:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: configuration model field missing am!')
|
||||
if 'vocoder' not in kwargs:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: configuration model field missing vocoder!')
|
||||
if 'lang_type' not in kwargs:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: configuration model field missing lang_type!'
|
||||
)
|
||||
am_cfg = kwargs['am']
|
||||
voc_cfg = kwargs['vocoder']
|
||||
self.__model_dir = model_dir
|
||||
self.__is_train = False
|
||||
if 'is_train' in kwargs:
|
||||
is_train = kwargs['is_train']
|
||||
if isinstance(is_train, bool):
|
||||
self.__is_train = is_train
|
||||
# initialize frontend
|
||||
import ttsfrd
|
||||
frontend = ttsfrd.TtsFrontendEngine()
|
||||
@@ -55,33 +56,155 @@ class SambertHifigan(Model):
|
||||
'modelscope error: language type invalid: {}'.format(
|
||||
kwargs['lang_type']))
|
||||
self.__frontend = frontend
|
||||
zip_file = os.path.join(model_dir, 'voices.zip')
|
||||
self.__voice_path = os.path.join(model_dir, 'voices')
|
||||
with zipfile.ZipFile(zip_file, 'r') as zip_ref:
|
||||
zip_ref.extractall(model_dir)
|
||||
voice_cfg_path = os.path.join(self.__voice_path, 'voices.json')
|
||||
with open(voice_cfg_path, 'r', encoding='utf-8') as f:
|
||||
voice_cfg = json.load(f)
|
||||
if 'voices' not in voice_cfg:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: voices invalid')
|
||||
self.__voice = {}
|
||||
for name in voice_cfg['voices']:
|
||||
voice_path = os.path.join(self.__voice_path, name)
|
||||
if not os.path.exists(voice_path):
|
||||
continue
|
||||
self.__voice[name] = Voice(name, voice_path, am_cfg, voc_cfg)
|
||||
if voice_cfg['voices']:
|
||||
self.__default_voice_name = voice_cfg['voices'][0]
|
||||
self.__voices, self.__voice_cfg = self.load_voice(model_dir)
|
||||
if len(self.__voices) == 0 or len(self.__voice_cfg) == 0:
|
||||
raise TtsVoiceNotExistsException('modelscope error: voices empty')
|
||||
if self.__voice_cfg['voices']:
|
||||
self.__default_voice_name = self.__voice_cfg['voices'][0]
|
||||
else:
|
||||
raise TtsVoiceNotExistsException(
|
||||
'modelscope error: voices is empty in voices.json')
|
||||
|
||||
def load_voice(self, model_dir):
|
||||
voices = {}
|
||||
voices_path = os.path.join(model_dir, 'voices')
|
||||
voices_json_path = os.path.join(voices_path, 'voices.json')
|
||||
if not os.path.exists(voices_path) or not os.path.exists(
|
||||
voices_json_path):
|
||||
return voices, []
|
||||
with open(voices_json_path, 'r', encoding='utf-8') as f:
|
||||
voice_cfg = json.load(f)
|
||||
if 'voices' not in voice_cfg:
|
||||
return voices, []
|
||||
for name in voice_cfg['voices']:
|
||||
voice_path = os.path.join(voices_path, name)
|
||||
if not os.path.exists(voice_path):
|
||||
continue
|
||||
voices[name] = Voice(name, voice_path)
|
||||
return voices, voice_cfg
|
||||
|
||||
def save_voices(self):
|
||||
voices_json_path = os.path.join(self.__model_dir, 'voices',
|
||||
'voices.json')
|
||||
if os.path.exists(voices_json_path):
|
||||
os.remove(voices_json_path)
|
||||
save_voices = {}
|
||||
save_voices['voices'] = []
|
||||
for k in self.__voices.keys():
|
||||
save_voices['voices'].append(k)
|
||||
with open(voices_json_path, 'w', encoding='utf-8') as f:
|
||||
json.dump(save_voices, f)
|
||||
|
||||
def get_voices(self):
|
||||
return self.__voices, self.__voice_cfg
|
||||
|
||||
def create_empty_voice(self, voice_name, audio_config, am_config_path,
|
||||
voc_config_path):
|
||||
voice_name_path = os.path.join(self.__model_dir, 'voices', voice_name)
|
||||
if os.path.exists(voice_name_path):
|
||||
shutil.rmtree(voice_name_path)
|
||||
os.makedirs(voice_name_path, exist_ok=True)
|
||||
if audio_config and os.path.exists(audio_config) and os.path.isfile(
|
||||
audio_config):
|
||||
shutil.copy(audio_config, voice_name_path)
|
||||
voice_am_path = os.path.join(voice_name_path, 'am')
|
||||
voice_voc_path = os.path.join(voice_name_path, 'voc')
|
||||
if am_config_path and os.path.exists(
|
||||
am_config_path) and os.path.isfile(am_config_path):
|
||||
os.makedirs(voice_am_path, exist_ok=True)  # make sure the am dir exists before copying
am_config_name = os.path.join(voice_am_path, 'config.yaml')
|
||||
shutil.copy(am_config_path, am_config_name)
|
||||
if voc_config_path and os.path.exists(
|
||||
voc_config_path) and os.path.isfile(voc_config_path):
|
||||
os.makedirs(voice_voc_path, exist_ok=True)  # the voc config belongs under the voc dir
voc_config_name = os.path.join(voice_voc_path, 'config.yaml')
|
||||
shutil.copy(voc_config_path, voc_config_name)
|
||||
am_ckpt_path = os.path.join(voice_am_path, 'ckpt')
|
||||
voc_ckpt_path = os.path.join(voice_voc_path, 'ckpt')
|
||||
os.makedirs(am_ckpt_path, exist_ok=True)
|
||||
os.makedirs(voc_ckpt_path, exist_ok=True)
|
||||
self.__voices[voice_name] = Voice(
|
||||
voice_name=voice_name,
|
||||
voice_path=voice_name_path,
|
||||
allow_empty=True)
|
||||
|
||||
def get_voice_audio_config_path(self, voice):
|
||||
if voice not in self.__voices:
|
||||
return ''
|
||||
return self.__voices[voice].audio_config
|
||||
|
||||
def get_voice_lang_path(self, voice):
|
||||
if voice not in self.__voices:
|
||||
return ''
|
||||
return self.__voices[voice].lang_dir
|
||||
|
||||
def __synthesis_one_sentences(self, voice_name, text):
|
||||
if voice_name not in self.__voice:
|
||||
if voice_name not in self.__voices:
|
||||
raise TtsVoiceNotExistsException(
|
||||
f'modelscope error: Voice {voice_name} not exists')
|
||||
return self.__voice[voice_name].forward(text)
|
||||
return self.__voices[voice_name].forward(text)
|
||||
|
||||
def train(self,
|
||||
voice,
|
||||
dirs,
|
||||
train_type,
|
||||
configs_path=None,
|
||||
ignore_pretrain=False,
|
||||
create_if_not_exists=False,
|
||||
hparam=None):
|
||||
work_dir = dirs['work_dir']
|
||||
am_dir = dirs['am_tmp_dir']
|
||||
voc_dir = dirs['voc_tmp_dir']
|
||||
data_dir = dirs['data_dir']
|
||||
|
||||
if voice not in self.__voices:
|
||||
if not create_if_not_exists:
|
||||
raise TtsVoiceNotExistsException(
|
||||
f'modelscope error: Voice {voice} not exists')
|
||||
am_config = configs_path.get('am_config', None)
|
||||
voc_config = configs_path.get('voc_config', None)
|
||||
if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type and not am_config:
|
||||
raise TtsTrainingCfgNotExistsException(
|
||||
'training new voice am with empty am_config')
|
||||
if TtsTrainType.TRAIN_TYPE_VOC in train_type and not voc_config:
|
||||
raise TtsTrainingCfgNotExistsException(
|
||||
'training new voice voc with empty voc_config')
|
||||
|
||||
target_voice = self.__voices[voice]
|
||||
am_config_path = target_voice.am_config
|
||||
voc_config_path = target_voice.voc_config
|
||||
if configs_path:
|
||||
am_config = configs_path.get('am_config', None)
|
||||
if am_config:
|
||||
am_config_path = am_config
|
||||
voc_config = configs_path.get('voc_config', None)
|
||||
if voc_config:
|
||||
voc_config_path = voc_config
|
||||
|
||||
logger.info('Start training....')
|
||||
if TtsTrainType.TRAIN_TYPE_SAMBERT in train_type:
|
||||
logger.info('Start SAMBERT training...')
|
||||
totaltime = datetime.datetime.now()
|
||||
hparams = train_type[TtsTrainType.TRAIN_TYPE_SAMBERT]
|
||||
target_voice.train_sambert(work_dir, am_dir, data_dir,
|
||||
am_config_path, ignore_pretrain,
|
||||
hparams)
|
||||
totaltime = datetime.datetime.now() - totaltime
|
||||
logger.info('SAMBERT training spent: {:.2f} hours\n'.format(
|
||||
totaltime.total_seconds() / 3600.0))
|
||||
else:
|
||||
logger.info('skip SAMBERT training...')
|
||||
|
||||
if TtsTrainType.TRAIN_TYPE_VOC in train_type:
|
||||
logger.info('Start HIFIGAN training...')
|
||||
totaltime = datetime.datetime.now()
|
||||
hparams = train_type[TtsTrainType.TRAIN_TYPE_VOC]
|
||||
target_voice.train_hifigan(work_dir, voc_dir, data_dir,
|
||||
voc_config_path, ignore_pretrain,
|
||||
hparams)
|
||||
totaltime = datetime.datetime.now() - totaltime
|
||||
logger.info('HIFIGAN training spent: {:.2f} hours\n'.format(
|
||||
totaltime.total_seconds() / 3600.0))
|
||||
else:
|
||||
logger.info('skip HIFIGAN training...')
|
||||
|
||||
def forward(self, text: str, voice_name: str = None):
|
||||
voice = self.__default_voice_name
|
||||
|
||||
@@ -2,75 +2,121 @@
|
||||
|
||||
import os
|
||||
import pickle as pkl
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from threading import Lock
|
||||
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import yaml
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from modelscope.utils.audio.tts_exceptions import \
|
||||
TtsModelConfigurationException
|
||||
from modelscope import __version__
|
||||
from modelscope.utils.audio.tts_exceptions import (
|
||||
TtsModelConfigurationException, TtsModelNotExistsException)
|
||||
from modelscope.utils.constant import ModelFile, Tasks
|
||||
from .models.datasets.units import KanTtsLinguisticUnit
|
||||
from .models.models.hifigan import Generator
|
||||
from .models.models.sambert import KanTtsSAMBERT
|
||||
from .models.utils import (AttrDict, build_env, init_weights, load_checkpoint,
|
||||
plot_spectrogram, save_checkpoint, scan_checkpoint)
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
from modelscope.models.audio.tts.kantts import (  # isort:skip
|
||||
GAN_Trainer, Generator, KanTtsLinguisticUnit, Sambert_Trainer,
|
||||
criterion_builder, get_am_datasets, get_voc_datasets, model_builder)
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
class Voice:
|
||||
|
||||
def __init__(self, voice_name, voice_path, am_config, voc_config):
|
||||
def __init__(self, voice_name, voice_path, allow_empty=False):
|
||||
self.__voice_name = voice_name
|
||||
self.__voice_path = voice_path
|
||||
self.__am_config = AttrDict(**am_config)
|
||||
self.__voc_config = AttrDict(**voc_config)
|
||||
self.distributed = False
|
||||
self.local_rank = 0
|
||||
am_config_path = os.path.join(
|
||||
os.path.join(voice_path, 'am'), 'config.yaml')
|
||||
voc_config_path = os.path.join(
|
||||
os.path.join(voice_path, 'voc'), 'config.yaml')
|
||||
|
||||
self.audio_config = os.path.join(voice_path, 'audio_config.yaml')
|
||||
self.lang_dir = os.path.join(voice_path, 'dict')
|
||||
self.am_config = am_config_path
|
||||
self.voc_config = voc_config_path
|
||||
|
||||
am_ckpt = os.path.join(os.path.join(voice_path, 'am'), 'ckpt')
|
||||
voc_ckpt = os.path.join(os.path.join(voice_path, 'voc'), 'ckpt')
|
||||
|
||||
self.__am_ckpts = self.scan_ckpt(am_ckpt)
|
||||
self.__voc_ckpts = self.scan_ckpt(voc_ckpt)
|
||||
|
||||
if not os.path.exists(am_config_path):
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: am configuration not found')
|
||||
if not os.path.exists(voc_config_path):
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: voc configuration not found')
|
||||
if not allow_empty:
|
||||
if len(self.__am_ckpts) == 0:
|
||||
raise TtsModelNotExistsException(
|
||||
'modelscope error: am model file not found')
|
||||
if len(self.__voc_ckpts) == 0:
|
||||
raise TtsModelNotExistsException(
|
||||
'modelscope error: voc model file not found')
|
||||
with open(am_config_path, 'r') as f:
|
||||
self.__am_config = yaml.load(f, Loader=yaml.Loader)
|
||||
with open(voc_config_path, 'r') as f:
|
||||
self.__voc_config = yaml.load(f, Loader=yaml.Loader)
|
||||
self.__model_loaded = False
|
||||
self.__lock = Lock()
|
||||
if 'am' not in self.__am_config:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: am configuration invalid')
|
||||
if 'linguistic_unit' not in self.__am_config:
|
||||
raise TtsModelConfigurationException(
|
||||
'modelscope error: am configuration invalid')
|
||||
self.__am_lingustic_unit_config = self.__am_config['linguistic_unit']
|
||||
self.__ling_unit = KanTtsLinguisticUnit(self.__am_config,
|
||||
self.lang_dir)
|
||||
self.__ling_unit_size = self.__ling_unit.get_unit_size()
|
||||
self.__am_config['Model']['KanTtsSAMBERT']['params'].update(
|
||||
self.__ling_unit_size)
|
||||
if torch.cuda.is_available():
|
||||
self.__device = torch.device('cuda')
|
||||
else:
|
||||
self.__device = torch.device('cpu')
|
||||
|
||||
def scan_ckpt(self, ckpt_path):
|
||||
filelist = os.listdir(ckpt_path)
|
||||
if len(filelist) == 0:
|
||||
return {}
|
||||
ckpts = {}
|
||||
for filename in filelist:
|
||||
# checkpoint_X.pth
|
||||
if len(filename) - 15 <= 0:
|
||||
continue
|
||||
if filename[-4:] == '.pth' and filename[0:10] == 'checkpoint':
|
||||
filename_prefix = filename.split('.')[0]
|
||||
idx = int(filename_prefix.split('_')[-1])
|
||||
path = os.path.join(ckpt_path, filename)
|
||||
ckpts[idx] = path
|
||||
od = OrderedDict(sorted(ckpts.items()))
|
||||
return od
|
||||
|
||||
def __load_am(self):
|
||||
local_am_ckpt_path = os.path.join(self.__voice_path, 'am')
|
||||
self.__am_ckpt_path = os.path.join(local_am_ckpt_path,
|
||||
ModelFile.TORCH_MODEL_BIN_FILE)
|
||||
has_mask = True
|
||||
if 'has_mask' in self.__am_lingustic_unit_config:
|
||||
has_mask = self.__am_lingustic_unit_config.has_mask
|
||||
self.__ling_unit = KanTtsLinguisticUnit(
|
||||
self.__am_lingustic_unit_config, self.__voice_path, has_mask)
|
||||
self.__am_net = KanTtsSAMBERT(self.__am_config,
|
||||
self.__ling_unit.get_unit_size()).to(
|
||||
self.__device)
|
||||
state_dict_g = {}
|
||||
try:
|
||||
state_dict_g = load_checkpoint(self.__am_ckpt_path, self.__device)
|
||||
except RuntimeError:
|
||||
with open(self.__am_ckpt_path, 'rb') as f:
|
||||
pth_var_dict = pkl.load(f)
|
||||
state_dict_g['fsnet'] = {
|
||||
k: torch.FloatTensor(v)
|
||||
for k, v in pth_var_dict['fsnet'].items()
|
||||
}
|
||||
self.__am_net.load_state_dict(state_dict_g['fsnet'], strict=False)
|
||||
self.__am_net.eval()
|
||||
self.__am_model, _, _ = model_builder(self.__am_config, self.__device)
|
||||
self.__am = self.__am_model['KanTtsSAMBERT']
|
||||
state_dict = torch.load(self.__am_ckpts[next(
|
||||
reversed(self.__am_ckpts))])
|
||||
self.__am.load_state_dict(state_dict['model'], strict=False)
|
||||
self.__am.eval()
|
||||
|
||||
def __load_vocoder(self):
|
||||
local_voc_ckpy_path = os.path.join(self.__voice_path, 'vocoder')
|
||||
self.__voc_ckpt_path = os.path.join(local_voc_ckpy_path,
|
||||
ModelFile.TORCH_MODEL_BIN_FILE)
|
||||
self.__generator = Generator(self.__voc_config).to(self.__device)
|
||||
state_dict_g = load_checkpoint(self.__voc_ckpt_path, self.__device)
|
||||
self.__generator.load_state_dict(state_dict_g['generator'])
|
||||
self.__generator.eval()
|
||||
self.__generator.remove_weight_norm()
|
||||
self.__voc_model = Generator(
|
||||
**self.__voc_config['Model']['Generator']['params'])
|
||||
states = torch.load(self.__voc_ckpts[next(reversed(self.__voc_ckpts))])
|
||||
self.__voc_model.load_state_dict(states['model']['generator'])
|
||||
if self.__voc_config['Model']['Generator']['params'][
|
||||
'out_channels'] > 1:
|
||||
from .kantts.models.pqmf import PQMF
|
||||
self.__voc_model.pqmf = PQMF()  # attach PQMF for multi-band synthesis (matches the hasattr check in __vocoder_forward)
|
||||
self.__voc_model.remove_weight_norm()
|
||||
self.__voc_model.eval().to(self.__device)
|
||||
|
||||
def __am_forward(self, symbol_seq):
|
||||
with self.__lock:
|
||||
@@ -92,43 +138,283 @@ class Voice:
|
||||
self.__device).unsqueeze(0)
|
||||
inputs_spk = torch.from_numpy(inputs_feat_lst[5]).long().to(
|
||||
self.__device).unsqueeze(0)
|
||||
inputs_len = torch.zeros(1).to(self.__device).long(
|
||||
) + inputs_emo.size(1) - 1 # minus 1 for "~"
|
||||
res = self.__am_net(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
|
||||
inputs_spk[:, :-1], inputs_len)
|
||||
inputs_len = (torch.zeros(1).to(self.__device).long()
|
||||
+ inputs_emo.size(1) - 1) # minus 1 for "~"
|
||||
res = self.__am(inputs_ling[:, :-1, :], inputs_emo[:, :-1],
|
||||
inputs_spk[:, :-1], inputs_len)
|
||||
postnet_outputs = res['postnet_outputs']
|
||||
LR_length_rounded = res['LR_length_rounded']
|
||||
valid_length = int(LR_length_rounded[0].item())
|
||||
postnet_outputs = postnet_outputs[
|
||||
0, :valid_length, :].cpu().numpy()
|
||||
postnet_outputs = postnet_outputs[0, :valid_length, :].cpu()
|
||||
return postnet_outputs
|
||||
|
||||
def __vocoder_forward(self, melspec):
|
||||
dim0 = list(melspec.shape)[-1]
|
||||
if dim0 != self.__voc_config.num_mels:
|
||||
raise TtsVocoderMelspecShapeMismatchException(
|
||||
'modelscope error: input melspec mismatch require {} but {}'.
|
||||
format(self.__voc_config.num_mels, dim0))
|
||||
with torch.no_grad():
|
||||
x = melspec.T
|
||||
x = torch.FloatTensor(x).to(self.__device)
|
||||
if len(x.shape) == 2:
|
||||
x = x.unsqueeze(0)
|
||||
y_g_hat = self.__generator(x)
|
||||
audio = y_g_hat.squeeze()
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype('int16')
|
||||
return audio
|
||||
x = melspec.to(self.__device)
|
||||
x = x.transpose(1, 0).unsqueeze(0)
|
||||
y = self.__voc_model(x)
|
||||
if hasattr(self.__voc_model, 'pqmf'):
|
||||
y = self.__voc_model.synthesis(y)
|
||||
y = y.view(-1).cpu().numpy()
|
||||
return y
|
||||
|
||||
def train_sambert(self,
|
||||
work_dir,
|
||||
stage_dir,
|
||||
data_dir,
|
||||
config_path,
|
||||
ignore_pretrain=False,
|
||||
hparams=dict()):
|
||||
logger.info('TRAIN SAMBERT....')
|
||||
if len(self.__am_ckpts) == 0:
|
||||
raise TtsTrainingInvalidModelException(
|
||||
'resume pretrain but model is empty')
|
||||
|
||||
from_steps = hparams.get('resume_from_steps', -1)
|
||||
if from_steps < 0:
|
||||
from_latest = hparams.get('resume_from_latest', True)
|
||||
else:
|
||||
from_latest = hparams.get('resume_from_latest', False)
|
||||
train_steps = hparams.get('train_steps', 0)
|
||||
|
||||
with open(self.audio_config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
config.update(yaml.load(f, Loader=yaml.Loader))
|
||||
config.update(hparams)
|
||||
|
||||
resume_from = None
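# Resume either from the newest checkpoint or from an explicitly requested step; the chosen file must exist.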
|
||||
if from_latest:
|
||||
from_steps = next(reversed(self.__am_ckpts))
|
||||
resume_from = self.__am_ckpts[from_steps]
|
||||
if not os.path.exists(resume_from):
|
||||
raise TtsTrainingInvalidModelException(
|
||||
f'latest model:{resume_from} not exists')
|
||||
else:
|
||||
if from_steps not in self.__am_ckpts:
|
||||
raise TtsTrainingInvalidModelException(
|
||||
f'no such model from steps:{from_steps}')
|
||||
else:
|
||||
resume_from = self.__am_ckpts[from_steps]
|
||||
|
||||
if train_steps > 0:
    train_max_steps = train_steps + from_steps
else:
    # no extra steps requested: fall back to the step budget from the yaml config
    # (assumed to define train_max_steps)
    train_max_steps = config['train_max_steps']
config['train_max_steps'] = train_max_steps

logger.info(f'TRAINING steps: {train_max_steps}')
|
||||
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime())
|
||||
config['modelscope_version'] = __version__
|
||||
|
||||
with open(os.path.join(stage_dir, 'config.yaml'), 'w') as f:
|
||||
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
|
||||
|
||||
for key, value in config.items():
|
||||
logger.info(f'{key} = {value}')
|
||||
|
||||
fp_enable = config['Model']['KanTtsSAMBERT']['params'].get('FP', False)
|
||||
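# Pick the metafile variant that matches the optional FP flag in the model params.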
meta_file = [
|
||||
os.path.join(
|
||||
d,
|
||||
'raw_metafile.txt' if not fp_enable else 'fprm_metafile.txt')
|
||||
for d in data_dir
|
||||
]
|
||||
|
||||
train_dataset, valid_dataset = get_am_datasets(meta_file, data_dir,
|
||||
self.lang_dir, config,
|
||||
config['allow_cache'])
|
||||
|
||||
logger.info(f'The number of training files = {len(train_dataset)}.')
|
||||
logger.info(f'The number of validation files = {len(valid_dataset)}.')
|
||||
|
||||
sampler = {'train': None, 'valid': None}
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
shuffle=False if self.distributed else True,
|
||||
collate_fn=train_dataset.collate_fn,
|
||||
batch_size=config['batch_size'],
|
||||
num_workers=config['num_workers'],
|
||||
sampler=sampler['train'],
|
||||
pin_memory=config['pin_memory'],
|
||||
)
|
||||
|
||||
valid_dataloader = DataLoader(
|
||||
valid_dataset,
|
||||
shuffle=False if self.distributed else True,
|
||||
collate_fn=valid_dataset.collate_fn,
|
||||
batch_size=config['batch_size'],
|
||||
num_workers=config['num_workers'],
|
||||
sampler=sampler['valid'],
|
||||
pin_memory=config['pin_memory'],
|
||||
)
|
||||
|
||||
ling_unit_size = train_dataset.ling_unit.get_unit_size()
|
||||
|
||||
config['Model']['KanTtsSAMBERT']['params'].update(ling_unit_size)
|
||||
model, optimizer, scheduler = model_builder(config, self.__device,
|
||||
self.local_rank,
|
||||
self.distributed)
|
||||
|
||||
criterion = criterion_builder(config, self.__device)
|
||||
|
||||
trainer = Sambert_Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
criterion=criterion,
|
||||
device=self.__device,
|
||||
sampler=sampler,
|
||||
train_loader=train_dataloader,
|
||||
valid_loader=valid_dataloader,
|
||||
max_steps=train_max_steps,
|
||||
save_dir=stage_dir,
|
||||
save_interval=config['save_interval_steps'],
|
||||
valid_interval=config['eval_interval_steps'],
|
||||
log_interval=config['log_interval'],
|
||||
grad_clip=config['grad_norm'],
|
||||
)
|
||||
|
||||
if resume_from is not None:
|
||||
trainer.load_checkpoint(resume_from, True, True)
|
||||
logger.info(f'Successfully resumed from {resume_from}.')
|
||||
|
||||
try:
|
||||
trainer.train()
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
logger.error(e, exc_info=True)
|
||||
trainer.save_checkpoint(
|
||||
os.path.join(
|
||||
os.path.join(stage_dir, 'ckpt'),
|
||||
f'checkpoint-{trainer.steps}.pth'))
|
||||
logger.info(
|
||||
f'Successfully saved checkpoint @ {trainer.steps}steps.')
|
||||
|
||||
def train_hifigan(self,
|
||||
work_dir,
|
||||
stage_dir,
|
||||
data_dir,
|
||||
config_path,
|
||||
ignore_pretrain=False,
|
||||
hparams=dict()):
|
||||
logger.info('TRAIN HIFIGAN....')
|
||||
if len(self.__voc_ckpts) == 0:
|
||||
raise TtsTrainingInvalidModelException(
|
||||
'resume pretrain but model is empty')
|
||||
|
||||
from_steps = hparams.get('resume_from_steps', -1)
|
||||
if from_steps < 0:
|
||||
from_latest = hparams.get('resume_from_latest', True)
|
||||
else:
|
||||
from_latest = hparams.get('resume_from_latest', False)
|
||||
train_steps = hparams.get('train_steps', 0)
|
||||
|
||||
with open(self.audio_config, 'r') as f:
|
||||
config = yaml.load(f, Loader=yaml.Loader)
|
||||
|
||||
with open(config_path, 'r') as f:
|
||||
config.update(yaml.load(f, Loader=yaml.Loader))
|
||||
config.update(hparams)
|
||||
|
||||
resume_from = None
|
||||
if from_latest:
|
||||
from_steps = next(reversed(self.__voc_ckpts))
|
||||
resume_from = self.__voc_ckpts[from_steps]
|
||||
if not os.path.exists(resume_from):
|
||||
raise TtsTrainingInvalidModelException(
|
||||
f'latest model:{resume_from} not exists')
|
||||
else:
|
||||
if from_steps not in self.__voc_ckpts:
|
||||
raise TtsTrainingInvalidModelException(
|
||||
f'no such model from steps:{from_steps}')
|
||||
else:
|
||||
resume_from = self.__voc_ckpts[from_steps]
|
||||
|
||||
if train_steps > 0:
    train_max_steps = train_steps
else:
    # no extra steps requested: fall back to the step budget from the yaml config
    # (assumed to define train_max_steps)
    train_max_steps = config['train_max_steps']
config['train_max_steps'] = train_max_steps

logger.info(f'TRAINING steps: {train_max_steps}')
|
||||
logger.info(f'resume from: {resume_from}')
|
||||
config['create_time'] = time.strftime('%Y-%m-%d %H:%M:%S',
|
||||
time.localtime())
|
||||
config['modelscope_version'] = __version__
|
||||
|
||||
with open(os.path.join(stage_dir, 'config.yaml'), 'w') as f:
|
||||
yaml.dump(config, f, Dumper=yaml.Dumper, default_flow_style=None)
|
||||
|
||||
for key, value in config.items():
|
||||
logger.info(f'{key} = {value}')
|
||||
|
||||
train_dataset, valid_dataset = get_voc_datasets(config, data_dir)
|
||||
|
||||
logger.info(f'The number of training files = {len(train_dataset)}.')
|
||||
logger.info(f'The number of validation files = {len(valid_dataset)}.')
|
||||
|
||||
sampler = {'train': None, 'valid': None}
|
||||
|
||||
train_dataloader = DataLoader(
|
||||
train_dataset,
|
||||
shuffle=False if self.distributed else True,
|
||||
collate_fn=train_dataset.collate_fn,
|
||||
batch_size=config['batch_size'],
|
||||
num_workers=config['num_workers'],
|
||||
sampler=sampler['train'],
|
||||
pin_memory=config['pin_memory'],
|
||||
)
|
||||
|
||||
valid_dataloader = DataLoader(
|
||||
valid_dataset,
|
||||
shuffle=False if self.distributed else True,
|
||||
collate_fn=valid_dataset.collate_fn,
|
||||
batch_size=config['batch_size'],
|
||||
num_workers=config['num_workers'],
|
||||
sampler=sampler['valid'],
|
||||
pin_memory=config['pin_memory'],
|
||||
)
|
||||
|
||||
model, optimizer, scheduler = model_builder(config, self.__device,
|
||||
self.local_rank,
|
||||
self.distributed)
|
||||
|
||||
criterion = criterion_builder(config, self.__device)
|
||||
trainer = GAN_Trainer(
|
||||
config=config,
|
||||
model=model,
|
||||
optimizer=optimizer,
|
||||
scheduler=scheduler,
|
||||
criterion=criterion,
|
||||
device=self.__device,
|
||||
sampler=sampler,
|
||||
train_loader=train_dataloader,
|
||||
valid_loader=valid_dataloader,
|
||||
max_steps=train_max_steps,
|
||||
save_dir=stage_dir,
|
||||
save_interval=config['save_interval_steps'],
|
||||
valid_interval=config['eval_interval_steps'],
|
||||
log_interval=config['log_interval_steps'],
|
||||
)
|
||||
|
||||
if resume_from is not None:
|
||||
trainer.load_checkpoint(resume_from)
|
||||
logger.info(f'Successfully resumed from {resume_from}.')
|
||||
|
||||
try:
|
||||
trainer.train()
|
||||
except (Exception, KeyboardInterrupt) as e:
|
||||
logger.error(e, exc_info=True)
|
||||
trainer.save_checkpoint(
|
||||
os.path.join(
|
||||
os.path.join(stage_dir, 'ckpt'),
|
||||
f'checkpoint-{trainer.steps}.pth'))
|
||||
logger.info(
|
||||
f'Successfully saved checkpoint @ {trainer.steps}steps.')
|
||||
|
||||
def forward(self, symbol_seq):
|
||||
with self.__lock:
|
||||
if not self.__model_loaded:
|
||||
torch.manual_seed(self.__am_config.seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(self.__am_config.seed)
|
||||
self.__device = torch.device('cuda')
|
||||
else:
|
||||
self.__device = torch.device('cpu')
|
||||
self.__load_am()
|
||||
self.__load_vocoder()
|
||||
self.__model_loaded = True
|
||||
|
||||
@@ -14,6 +14,7 @@ if TYPE_CHECKING:
|
||||
ImageInstanceSegmentationPreprocessor,
|
||||
ImageDenoisePreprocessor)
|
||||
from .kws import WavToLists
|
||||
from .tts import KanttsDataPreprocessor
|
||||
from .multi_modal import (OfaPreprocessor, MPlugPreprocessor)
|
||||
from .nlp import (
|
||||
DocumentSegmentationTransformersPreprocessor,
|
||||
@@ -50,6 +51,7 @@ else:
|
||||
'ImageInstanceSegmentationPreprocessor', 'ImageDenoisePreprocessor'
|
||||
],
|
||||
'kws': ['WavToLists'],
|
||||
'tts': ['KanttsDataPreprocessor'],
|
||||
'multi_modal': ['OfaPreprocessor', 'MPlugPreprocessor'],
|
||||
'nlp': [
|
||||
'DocumentSegmentationTransformersPreprocessor',
|
||||
|
||||
60
modelscope/preprocessors/tts.py
Normal file
@@ -0,0 +1,60 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from modelscope.metainfo import Preprocessors
|
||||
from modelscope.models.audio.tts.kantts.preprocess.data_process import \
|
||||
process_data
|
||||
from modelscope.models.base import Model
|
||||
from modelscope.utils.audio.tts_exceptions import (
|
||||
TtsDataPreprocessorAudioConfigNotExistsException,
|
||||
TtsDataPreprocessorDirNotExistsException)
|
||||
from modelscope.utils.constant import Fields, Frameworks, Tasks
|
||||
from .base import Preprocessor
|
||||
from .builder import PREPROCESSORS
|
||||
|
||||
__all__ = ['KanttsDataPreprocessor']
|
||||
|
||||
|
||||
@PREPROCESSORS.register_module(
|
||||
group_key=Tasks.text_to_speech,
|
||||
module_name=Preprocessors.kantts_data_preprocessor)
|
||||
class KanttsDataPreprocessor(Preprocessor):
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(self,
|
||||
data_dir,
|
||||
output_dir,
|
||||
lang_dir,
|
||||
audio_config_path,
|
||||
speaker_name='F7',
|
||||
target_lang='PinYin',
|
||||
skip_script=False):
|
||||
self.do_data_process(data_dir, output_dir, lang_dir, audio_config_path,
|
||||
speaker_name, target_lang, skip_script)
|
||||
|
||||
def do_data_process(self,
|
||||
datadir,
|
||||
outputdir,
|
||||
langdir,
|
||||
audio_config,
|
||||
speaker_name='F7',
|
||||
targetLang='PinYin',
|
||||
skip_script=False):
|
||||
if not os.path.exists(datadir):
|
||||
raise TtsDataPreprocessorDirNotExistsException(
|
||||
'Preprocessor: dataset dir not exists')
|
||||
if not os.path.exists(outputdir):
|
||||
raise TtsDataPreprocessorDirNotExistsException(
|
||||
'Preprocessor: output dir not exists')
|
||||
if not os.path.exists(audio_config):
|
||||
raise TtsDataPreprocessorAudioConfigNotExistsException(
|
||||
'Preprocessor: audio config not exists')
|
||||
if not os.path.exists(langdir):
|
||||
raise TtsDataPreprocessorDirNotExistsException(
|
||||
'Preprocessor: language dir not exists')
|
||||
process_data(datadir, outputdir, langdir, audio_config, speaker_name,
|
||||
targetLang, skip_script)
|
||||
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .audio.ans_trainer import ANSTrainer
|
||||
from .audio import ANSTrainer, KanttsTrainer
|
||||
from .base import DummyTrainer
|
||||
from .builder import build_trainer
|
||||
from .cv import (ImageInstanceSegmentationTrainer,
|
||||
@@ -18,7 +18,7 @@ if TYPE_CHECKING:
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'audio.ans_trainer': ['ANSTrainer'],
|
||||
'audio': ['ANSTrainer', 'KanttsTrainer'],
|
||||
'base': ['DummyTrainer'],
|
||||
'builder': ['build_trainer'],
|
||||
'cv': [
|
||||
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from modelscope.utils.import_utils import LazyImportModule
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
from .tts_trainer import KanttsTrainer
|
||||
from .ans_trainer import ANSTrainer
|
||||
|
||||
else:
|
||||
_import_structure = {
|
||||
'tts_trainer': ['KanttsTrainer'],
|
||||
'ans_trainer': ['ANSTrainer']
|
||||
}
|
||||
|
||||
import sys
|
||||
|
||||
sys.modules[__name__] = LazyImportModule(
|
||||
__name__,
|
||||
globals()['__file__'],
|
||||
_import_structure,
|
||||
module_spec=__spec__,
|
||||
extra_objects={},
|
||||
)
|
||||
|
||||
235
modelscope/trainers/audio/tts_trainer.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Copyright (c) Alibaba, Inc. and its affiliates.
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import json
|
||||
|
||||
from modelscope.metainfo import Preprocessors, Trainers
|
||||
from modelscope.models.audio.tts import SambertHifigan
|
||||
from modelscope.msdatasets import MsDataset
|
||||
from modelscope.preprocessors.builder import build_preprocessor
|
||||
from modelscope.trainers.base import BaseTrainer
|
||||
from modelscope.trainers.builder import TRAINERS
|
||||
from modelscope.utils.audio.audio_utils import TtsTrainType
|
||||
from modelscope.utils.audio.tts_exceptions import (
|
||||
TtsTrainingCfgNotExistsException, TtsTrainingDatasetInvalidException,
|
||||
TtsTrainingHparamsInvalidException, TtsTrainingInvalidModelException,
|
||||
TtsTrainingWorkDirNotExistsException)
|
||||
from modelscope.utils.constant import (DEFAULT_DATASET_NAMESPACE,
|
||||
DEFAULT_DATASET_REVISION,
|
||||
DEFAULT_MODEL_REVISION, ModelFile,
|
||||
Tasks, TrainerStages)
|
||||
from modelscope.utils.data_utils import to_device
|
||||
from modelscope.utils.logger import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
|
||||
@TRAINERS.register_module(module_name=Trainers.speech_kantts_trainer)
|
||||
class KanttsTrainer(BaseTrainer):
|
||||
DATA_DIR = 'data'
|
||||
AM_TMP_DIR = 'tmp_am'
|
||||
VOC_TMP_DIR = 'tmp_voc'
|
||||
ORIG_MODEL_DIR = 'orig_model'
|
||||
|
||||
def __init__(self,
|
||||
model: str,
|
||||
work_dir: str = None,
|
||||
speaker: str = 'F7',
|
||||
lang_type: str = 'PinYin',
|
||||
cfg_file: Optional[str] = None,
|
||||
train_dataset: Optional[Union[MsDataset, str]] = None,
|
||||
train_dataset_namespace: str = DEFAULT_DATASET_NAMESPACE,
|
||||
train_dataset_revision: str = DEFAULT_DATASET_REVISION,
|
||||
train_type: dict = {
|
||||
TtsTrainType.TRAIN_TYPE_SAMBERT: {},
|
||||
TtsTrainType.TRAIN_TYPE_VOC: {}
|
||||
},
|
||||
preprocess_skip_script=False,
|
||||
model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
|
||||
**kwargs):
|
||||
|
||||
if not work_dir:
|
||||
self.work_dir = tempfile.TemporaryDirectory().name
|
||||
if not os.path.exists(self.work_dir):
|
||||
os.makedirs(self.work_dir)
|
||||
else:
|
||||
self.work_dir = work_dir
|
||||
|
||||
if not os.path.exists(self.work_dir):
|
||||
raise TtsTrainingWorkDirNotExistsException(
|
||||
f'{self.work_dir} not exists')
|
||||
|
||||
self.train_type = dict()
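# Keep only the recognized training stages (SAMBERT acoustic model, vocoder, BERT);
# default to SAMBERT + vocoder when none is given.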
|
||||
if isinstance(train_type, dict):
|
||||
for k, v in train_type.items():
|
||||
if (k == TtsTrainType.TRAIN_TYPE_SAMBERT
|
||||
or k == TtsTrainType.TRAIN_TYPE_VOC
|
||||
or k == TtsTrainType.TRAIN_TYPE_BERT):
|
||||
self.train_type[k] = v
|
||||
|
||||
if len(self.train_type) == 0:
|
||||
logger.info('train type empty, default to sambert and voc')
|
||||
self.train_type[TtsTrainType.TRAIN_TYPE_SAMBERT] = {}
|
||||
self.train_type[TtsTrainType.TRAIN_TYPE_VOC] = {}
|
||||
|
||||
logger.info(f'Set workdir to {self.work_dir}')
|
||||
|
||||
self.data_dir = os.path.join(self.work_dir, self.DATA_DIR)
|
||||
self.am_tmp_dir = os.path.join(self.work_dir, self.AM_TMP_DIR)
|
||||
self.voc_tmp_dir = os.path.join(self.work_dir, self.VOC_TMP_DIR)
|
||||
self.orig_model_dir = os.path.join(self.work_dir, self.ORIG_MODEL_DIR)
|
||||
self.raw_dataset_path = ''
|
||||
self.skip_script = preprocess_skip_script
|
||||
self.audio_config_path = ''
|
||||
self.lang_path = ''
|
||||
self.am_config_path = ''
|
||||
self.voc_config_path = ''
|
||||
|
||||
shutil.rmtree(self.data_dir, ignore_errors=True)
|
||||
shutil.rmtree(self.am_tmp_dir, ignore_errors=True)
|
||||
shutil.rmtree(self.voc_tmp_dir, ignore_errors=True)
|
||||
shutil.rmtree(self.orig_model_dir, ignore_errors=True)
|
||||
|
||||
os.makedirs(self.data_dir)
|
||||
os.makedirs(self.am_tmp_dir)
|
||||
os.makedirs(self.voc_tmp_dir)
|
||||
|
||||
if train_dataset:
|
||||
if isinstance(train_dataset, str):
|
||||
logger.info(f'load {train_dataset_namespace}/{train_dataset}')
|
||||
train_dataset = MsDataset.load(
|
||||
dataset_name=train_dataset,
|
||||
namespace=train_dataset_namespace,
|
||||
version=train_dataset_revision)
|
||||
logger.info(f'train dataset:{train_dataset.config_kwargs}')
|
||||
self.raw_dataset_path = self.load_dataset_raw_path(train_dataset)
|
||||
model_dir = None
|
||||
if os.path.exists(model):
|
||||
model_dir = model
|
||||
else:
|
||||
model_dir = self.get_or_download_model_dir(model, model_revision)
|
||||
shutil.copytree(model_dir, self.orig_model_dir)
|
||||
self.model_dir = self.orig_model_dir
|
||||
|
||||
if not cfg_file:
|
||||
cfg_file = os.path.join(self.model_dir, ModelFile.CONFIGURATION)
|
||||
self.parse_cfg(cfg_file)
|
||||
|
||||
if not os.path.exists(self.raw_dataset_path):
|
||||
raise TtsTrainingDatasetInvalidException(
|
||||
'dataset raw path not exists')
|
||||
|
||||
self.finetune_from_pretrain = False
|
||||
self.speaker = speaker
|
||||
self.lang_type = lang_type
|
||||
self.model = None
|
||||
self.device = kwargs.get('device', 'gpu')
|
||||
self.model = self.get_model(self.model_dir, self.speaker,
|
||||
self.lang_type)
|
||||
if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
|
||||
self.audio_data_preprocessor = build_preprocessor(
|
||||
dict(type=Preprocessors.kantts_data_preprocessor),
|
||||
Tasks.text_to_speech)
|
||||
|
||||
def parse_cfg(self, cfg_file):
|
||||
cur_dir = os.path.dirname(cfg_file)
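# Paths in the configuration file are relative to its own directory; resolve them before use.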
|
||||
with open(cfg_file, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
if 'train' not in config:
|
||||
raise TtsTrainingInvalidModelException(
|
||||
'model not support finetune')
|
||||
if 'audio_config' in config['train']:
|
||||
audio_config = os.path.join(cur_dir,
|
||||
config['train']['audio_config'])
|
||||
if os.path.exists(audio_config):
|
||||
self.audio_config_path = audio_config
|
||||
if 'am_config' in config['train']:
|
||||
am_config = os.path.join(cur_dir, config['train']['am_config'])
|
||||
if os.path.exists(am_config):
|
||||
self.am_config_path = am_config
|
||||
if 'voc_config' in config['train']:
|
||||
voc_config = os.path.join(cur_dir,
|
||||
config['train']['voc_config'])
|
||||
if os.path.exists(voc_config):
|
||||
self.voc_config_path = voc_config
|
||||
if 'language_path' in config['train']:
|
||||
lang_path = os.path.join(cur_dir,
|
||||
config['train']['language_path'])
|
||||
if os.path.exists(lang_path):
|
||||
self.lang_path = lang_path
|
||||
if not self.raw_dataset_path:
|
||||
if 'train_dataset' in config['train']:
|
||||
dataset = config['train']['train_dataset']
|
||||
if 'id' in dataset:
|
||||
namespace = dataset.get('namespace',
|
||||
DEFAULT_DATASET_NAMESPACE)
|
||||
revision = dataset.get('revision',
|
||||
DEFAULT_DATASET_REVISION)
|
||||
ms = MsDataset.load(
|
||||
dataset_name=dataset['id'],
|
||||
namespace=namespace,
|
||||
version=revision)
|
||||
self.raw_dataset_path = self.load_dataset_raw_path(ms)
|
||||
elif 'path' in dataset:
|
||||
self.raw_dataset_path = dataset['path']
|
||||
|
||||
def load_dataset_raw_path(self, dataset: MsDataset):
|
||||
if 'split_config' not in dataset.config_kwargs:
|
||||
raise TtsTrainingDatasetInvalidException(
|
||||
'split_config not found in config_kwargs')
|
||||
if 'train' not in dataset.config_kwargs['split_config']:
|
||||
raise TtsTrainingDatasetInvalidException(
|
||||
'no train split in split_config')
|
||||
return dataset.config_kwargs['split_config']['train']
|
||||
|
||||
def prepare_data(self):
|
||||
if self.audio_data_preprocessor:
|
||||
audio_config = self.audio_config_path
|
||||
if not audio_config or not os.path.exists(audio_config):
|
||||
audio_config = self.model.get_voice_audio_config_path(
|
||||
self.speaker)
|
||||
lang_path = self.lang_path
|
||||
if not lang_path or not os.path.exists(lang_path):
|
||||
lang_path = self.model.get_voice_lang_path(self.speaker)
|
||||
self.audio_data_preprocessor(self.raw_dataset_path, self.data_dir,
|
||||
lang_path, audio_config, self.speaker,
|
||||
self.lang_type, self.skip_script)
|
||||
|
||||
def prepare_text(self):
|
||||
pass
|
||||
|
||||
def get_model(self, model_dir, speaker, lang_type):
|
||||
model = SambertHifigan(
|
||||
model_dir=self.model_dir, lang_type=self.lang_type, is_train=True)
|
||||
return model
|
||||
|
||||
def train(self, *args, **kwargs):
|
||||
if not self.model:
|
||||
raise TtsTrainingInvalidModelException('model is none')
|
||||
ignore_pretrain = False
|
||||
if 'ignore_pretrain' in kwargs:
|
||||
ignore_pretrain = kwargs['ignore_pretrain']
|
||||
|
||||
if TtsTrainType.TRAIN_TYPE_SAMBERT in self.train_type or TtsTrainType.TRAIN_TYPE_VOC in self.train_type:
|
||||
self.prepare_data()
|
||||
if TtsTrainType.TRAIN_TYPE_BERT in self.train_type:
|
||||
self.prepare_text()
|
||||
dir_dict = {
|
||||
'work_dir': self.work_dir,
|
||||
'am_tmp_dir': self.am_tmp_dir,
|
||||
'voc_tmp_dir': self.voc_tmp_dir,
|
||||
'data_dir': self.data_dir
|
||||
}
|
||||
config_dict = {
|
||||
'am_config': self.am_config_path,
|
||||
'voc_config': self.voc_config_path
|
||||
}
|
||||
self.model.train(self.speaker, dir_dict, self.train_type, config_dict,
|
||||
ignore_pretrain)
|
||||
|
||||
def evaluate(self, checkpoint_path: str, *args,
|
||||
**kwargs) -> Dict[str, float]:
|
||||
return {}
|
||||
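For orientation, the snippet below is a minimal sketch of how this trainer is driven end to end. It mirrors the new unit test added at the bottom of this change (same model id, dataset id and train_type layout), so treat it as an illustrative usage sketch rather than additional API surface; the work_dir value and the tiny step counts are placeholders.

from modelscope.metainfo import Trainers
from modelscope.trainers import build_trainer
from modelscope.utils.audio.audio_utils import TtsTrainType

# Per-stage hyperparameters; the tiny step counts are only for smoke testing.
train_info = {
    TtsTrainType.TRAIN_TYPE_SAMBERT: {
        'train_steps': 2,
        'save_interval_steps': 1,
        'eval_interval_steps': 1,
        'log_interval': 1
    },
    TtsTrainType.TRAIN_TYPE_VOC: {
        'train_steps': 2,
        'save_interval_steps': 1,
        'eval_interval_steps': 1,
        'log_interval': 1
    }
}
kwargs = dict(
    model='speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k',
    work_dir='./tts_finetune_work_dir',  # placeholder path
    train_dataset='speech_kantts_opendata',
    train_dataset_namespace='speech_tts',
    train_type=train_info)
trainer = build_trainer(Trainers.speech_kantts_trainer, default_args=kwargs)
trainer.train()
# Checkpoints are written under <work_dir>/tmp_am/ckpt and <work_dir>/tmp_voc/ckpt.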
@@ -9,6 +9,12 @@ from modelscope.fileio.file import HTTPStorage

SEGMENT_LENGTH_TRAIN = 16000


class TtsTrainType(object):
    TRAIN_TYPE_SAMBERT = 'train-type-sambert'
    TRAIN_TYPE_BERT = 'train-type-bert'
    TRAIN_TYPE_VOC = 'train-type-voc'


def to_segment(batch, segment_length=SEGMENT_LENGTH_TRAIN):
    """
    Dataset mapping function to split one audio into segments.
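To make the intent of that mapping concrete, here is a small illustrative sketch of cutting one waveform into fixed-length segments. It is not the body of to_segment (this hunk does not show it); the helper name split_waveform is hypothetical, and the 16000-sample length simply mirrors SEGMENT_LENGTH_TRAIN above.

import numpy as np

SEGMENT_LENGTH_TRAIN = 16000  # same value as the constant above


def split_waveform(waveform, segment_length=SEGMENT_LENGTH_TRAIN):
    """Cut a 1-D waveform into non-overlapping fixed-length segments,
    dropping the trailing remainder (illustrative sketch only)."""
    n_segments = len(waveform) // segment_length
    return [
        waveform[i * segment_length:(i + 1) * segment_length]
        for i in range(n_segments)
    ]


# e.g. a 2.5 s clip at 16 kHz yields two 1 s segments
segments = split_waveform(np.zeros(40000, dtype=np.float32))
assert len(segments) == 2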
@@ -18,6 +18,12 @@ class TtsModelConfigurationException(TtsException):
    pass


class TtsModelNotExistsException(TtsException):
    """
    Raised when the TTS model does not exist.
    """


class TtsVoiceNotExistsException(TtsException):
    """
    Raised when the TTS voice does not exist.
@@ -55,3 +61,58 @@ class TtsVocoderMelspecShapeMismatchException(TtsVocoderException):
    """
    If the vocoder's input melspec shape mismatches, this exception will be raised.
    """


class TtsDataPreprocessorException(TtsException):
    """
    Tts data preprocessing exception.
    """


class TtsDataPreprocessorDirNotExistsException(TtsDataPreprocessorException):
    """
    If any directory does not exist, this exception will be raised.
    """


class TtsDataPreprocessorAudioConfigNotExistsException(
        TtsDataPreprocessorException):
    """
    If the audio config does not exist, this exception will be raised.
    """


class TtsTrainingException(TtsException):
    """
    Tts training exception.
    """


class TtsTrainingHparamsInvalidException(TtsException):
    """
    If the training hparams are invalid, this exception will be raised.
    """


class TtsTrainingWorkDirNotExistsException(TtsTrainingException):
    """
    If the training work dir does not exist, this exception will be raised.
    """


class TtsTrainingCfgNotExistsException(TtsTrainingException):
    """
    If the training cfg does not exist, this exception will be raised.
    """


class TtsTrainingDatasetInvalidException(TtsTrainingException):
    """
    If the dataset is invalid, this exception will be raised.
    """


class TtsTrainingInvalidModelException(TtsTrainingException):
    """
    If the model is invalid or does not exist, this exception will be raised.
    """
@@ -1,22 +1,32 @@
easyasr>=0.0.2
espnet==202204
funasr>=0.1.4
greenlet>=1.1.2
h5py
inflect
jedi>=0.18.1
keras
kwsbp>=0.0.2
librosa
lxml
matplotlib
MinDAEC
msgpack>=1.0.4
nara_wpe
nltk
# tensorflow 1.15 requires numpy<=1.18
numpy<=1.18
parso>=0.8.3
pexpect>=4.8.0
pickleshare>=0.7.5
prompt-toolkit>=3.0.30
# protobuf versions beyond 3.20.0 are not compatible with TensorFlow 1.x and are therefore discouraged.
protobuf>3,<3.21.0
ptflops
ptyprocess>=0.7.0
py_sound_connect>=0.1
pygments>=2.12.0
pysptk>=0.1.15,<0.2.0
pytorch_wavelets
PyWavelets>=1.0.0
scikit-learn
@@ -24,5 +34,7 @@ SoundFile>0.10
sox
torchaudio
tqdm
traitlets>=5.3.0
ttsfrd>=0.0.3
unidecode
wcwidth>=0.2.5
@@ -28,46 +28,69 @@ class TextToSpeechSambertHifigan16kPipelineTest(unittest.TestCase,
        self.task = Tasks.text_to_speech
        self.zhcn_text = '今天北京天气怎么样'
        self.en_text = 'How is the weather in Beijing?'
        self.zhcn_voices = [
            'zhitian_emo', 'zhizhe_emo', 'zhiyan_emo', 'zhibei_emo', 'zhcn'
        ]
        self.zhcn_models = [
            'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k',
            'damo/speech_sambert-hifigan_tts_zhizhe_emo_zh-cn_16k',
            'damo/speech_sambert-hifigan_tts_zhiyan_emo_zh-cn_16k',
            'damo/speech_sambert-hifigan_tts_zhibei_emo_zh-cn_16k',
            'damo/speech_sambert-hifigan_tts_zh-cn_16k'
        ]
        self.en_voices = ['luca', 'luna', 'andy', 'annie', 'engb', 'enus']
        self.en_models = [
            'damo/speech_sambert-hifigan_tts_luca_en-gb_16k',
            'damo/speech_sambert-hifigan_tts_luna_en-gb_16k',
            'damo/speech_sambert-hifigan_tts_andy_en-us_16k',
            'damo/speech_sambert-hifigan_tts_annie_en-us_16k',
            'damo/speech_sambert-hifigan_tts_en-gb_16k',
            'damo/speech_sambert-hifigan_tts_en-us_16k'
        self.test_model_name = [
            'pretrain_16k', 'pretrain_24k', 'zhitian_emo', 'zhizhe_emo',
            'zhiyan_emo', 'zhibei_emo', 'zhcn_16k', 'luca', 'luna', 'andy',
            'annie', 'engb_16k', 'enus_16k'
        ]
        self.test_models = [{
            'model':
            'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k',
            'text': self.zhcn_text
        }, {
            'model':
            'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_24k',
            'text': self.zhcn_text,
            'sample_rate': 24000
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k',
            'text': self.zhcn_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_zhizhe_emo_zh-cn_16k',
            'text': self.zhcn_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_zhiyan_emo_zh-cn_16k',
            'text': self.zhcn_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_zhibei_emo_zh-cn_16k',
            'text': self.zhcn_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_zh-cn_16k',
            'text': self.zhcn_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_luca_en-gb_16k',
            'text': self.en_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_luna_en-gb_16k',
            'text': self.en_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_andy_en-us_16k',
            'text': self.en_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_annie_en-us_16k',
            'text': self.en_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_en-gb_16k',
            'text': self.en_text
        }, {
            'model': 'damo/speech_sambert-hifigan_tts_en-us_16k',
            'text': self.en_text
        }]

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_pipeline(self):
        for i in range(len(self.zhcn_voices)):
            logger.info('test %s' % self.zhcn_voices[i])
        for i in range(len(self.test_models)):
            logger.info('test %s' % self.test_model_name[i])
            sambert_hifigan_tts = pipeline(
                task=self.task, model=self.zhcn_models[i])
                task=self.task, model=self.test_models[i]['model'])
            self.assertTrue(sambert_hifigan_tts is not None)
            output = sambert_hifigan_tts(input=self.zhcn_text)
            output = sambert_hifigan_tts(input=self.test_models[i]['text'])
            self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
            pcm = output[OutputKeys.OUTPUT_PCM]
            write('output_%s.wav' % self.zhcn_voices[i], 16000, pcm)
        for i in range(len(self.en_voices)):
            logger.info('test %s' % self.en_voices[i])
            sambert_hifigan_tts = pipeline(
                task=self.task, model=self.en_models[i])
            self.assertTrue(sambert_hifigan_tts is not None)
            output = sambert_hifigan_tts(input=self.en_text)
            self.assertIsNotNone(output[OutputKeys.OUTPUT_PCM])
            pcm = output[OutputKeys.OUTPUT_PCM]
            write('output_%s.wav' % self.en_voices[i], 16000, pcm)
            sr = 16000
            if 'sample_rate' in self.test_models[i]:
                sr = self.test_models[i]['sample_rate']
            write('output_%s.wav' % self.test_model_name[i], sr, pcm)

    @unittest.skip('demo compatibility test is only enabled on a needed-basis')
    def test_demo_compatibility(self):
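Condensing the loop above, a single standalone synthesis with one of these models looks roughly like the sketch below. The model id, input text, and 16 kHz rate are taken from the test entries (use 24000 for the *_24k pretrain model); the output file name is a placeholder, and the PCM array is written with scipy as in the test.

from scipy.io.wavfile import write

from modelscope.outputs import OutputKeys
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks

tts = pipeline(
    task=Tasks.text_to_speech,
    model='damo/speech_sambert-hifigan_tts_zhitian_emo_zh-cn_16k')
output = tts(input='今天北京天气怎么样')
# The pipeline returns raw PCM samples; write them out as a 16 kHz wav file.
write('output.wav', 16000, output[OutputKeys.OUTPUT_PCM])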
66
tests/trainers/audio/test_tts_trainer.py
Normal file
@@ -0,0 +1,66 @@
# Copyright (c) Alibaba, Inc. and its affiliates.

import os
import shutil
import tempfile
import unittest

from modelscope.metainfo import Trainers
from modelscope.msdatasets import MsDataset
from modelscope.trainers import build_trainer
from modelscope.utils.audio.audio_utils import TtsTrainType
from modelscope.utils.constant import DownloadMode, Fields, Tasks
from modelscope.utils.logger import get_logger
from modelscope.utils.test_utils import test_level

logger = get_logger()


class TestTtsTrainer(unittest.TestCase):

    def setUp(self):
        self.tmp_dir = tempfile.TemporaryDirectory().name
        if not os.path.exists(self.tmp_dir):
            os.makedirs(self.tmp_dir)

        self.model_id = 'speech_tts/speech_sambert-hifigan_tts_zh-cn_multisp_pretrain_16k'
        self.dataset_id = 'speech_kantts_opendata'
        self.dataset_namespace = 'speech_tts'
        self.train_info = {
            TtsTrainType.TRAIN_TYPE_SAMBERT: {
                'train_steps': 2,
                'save_interval_steps': 1,
                'eval_interval_steps': 1,
                'log_interval': 1
            },
            TtsTrainType.TRAIN_TYPE_VOC: {
                'train_steps': 2,
                'save_interval_steps': 1,
                'eval_interval_steps': 1,
                'log_interval': 1
            }
        }

    def tearDown(self):
        shutil.rmtree(self.tmp_dir, ignore_errors=True)
        super().tearDown()

    @unittest.skipUnless(test_level() >= 0, 'skip test in current test level')
    def test_trainer(self):
        kwargs = dict(
            model=self.model_id,
            work_dir=self.tmp_dir,
            train_dataset=self.dataset_id,
            train_dataset_namespace=self.dataset_namespace,
            train_type=self.train_info)
        trainer = build_trainer(
            Trainers.speech_kantts_trainer, default_args=kwargs)
        trainer.train()
        tmp_am = os.path.join(self.tmp_dir, 'tmp_am', 'ckpt')
        tmp_voc = os.path.join(self.tmp_dir, 'tmp_voc', 'ckpt')
        assert os.path.exists(tmp_am)
        assert os.path.exists(tmp_voc)


if __name__ == '__main__':
    unittest.main()